Skip to content

Commit

Permalink
[Test] add tet for InsertIdentityOnAllTopLevelIO, remove onnx dumps
Browse files Browse the repository at this point in the history
  • Loading branch information
maltanar committed Jan 13, 2025
1 parent 358b58a commit 1aac6b6
Showing 1 changed file with 11 additions and 22 deletions.
33 changes: 11 additions & 22 deletions tests/transformation/test_insert_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.insert import InsertIdentity
from qonnx.transformation.insert import InsertIdentity, InsertIdentityOnAllTopLevelIO


@pytest.fixture
Expand All @@ -49,9 +49,16 @@ def simple_model():
return model


def save_transformed_model(model, test_name):
output_path = f"{test_name}.onnx"
model.save(output_path)
def test_insert_identity_on_all_top_level_io(simple_model):
orig_top_inp_names = [inp.name for inp in simple_model.graph.input]
orig_top_out_names = [out.name for out in simple_model.graph.output]
model = simple_model.transform(InsertIdentityOnAllTopLevelIO())
for inp in orig_top_inp_names:
assert model.find_consumer(inp).op_type == "Identity"
for out in orig_top_out_names:
assert model.find_producer(out).op_type == "Identity"
assert orig_top_inp_names == [inp.name for inp in model.graph.input]
assert orig_top_out_names == [out.name for out in model.graph.output]


def test_insert_identity_before_input(simple_model):
Expand All @@ -63,9 +70,6 @@ def test_insert_identity_before_input(simple_model):
assert identity_node is not None
assert identity_node.op_type == "Identity"

# Save the transformed model
save_transformed_model(model, "test_insert_identity_before_input")


def test_insert_identity_after_input(simple_model):
# Apply the transformation
Expand All @@ -76,9 +80,6 @@ def test_insert_identity_after_input(simple_model):
assert identity_node is not None
assert identity_node.op_type == "Identity"

# Save the transformed model
save_transformed_model(model, "test_insert_identity_after_input")


def test_insert_identity_before_intermediate(simple_model):
# Apply the transformation
Expand All @@ -89,9 +90,6 @@ def test_insert_identity_before_intermediate(simple_model):
assert identity_node is not None
assert identity_node.op_type == "Identity"

# Save the transformed model
save_transformed_model(model, "test_insert_identity_before_intermediate")


def test_insert_identity_after_intermediate(simple_model):
# Apply the transformation
Expand All @@ -102,9 +100,6 @@ def test_insert_identity_after_intermediate(simple_model):
assert identity_node is not None
assert identity_node.op_type == "Identity"

# Save the transformed model
save_transformed_model(model, "test_insert_identity_after_intermediate")


def test_insert_identity_before_output(simple_model):
# Apply the transformation
Expand All @@ -115,9 +110,6 @@ def test_insert_identity_before_output(simple_model):
assert identity_node is not None
assert identity_node.op_type == "Identity"

# Save the transformed model
save_transformed_model(model, "test_insert_identity_before_output")


def test_insert_identity_after_output(simple_model):
# Apply the transformation
Expand All @@ -128,9 +120,6 @@ def test_insert_identity_after_output(simple_model):
assert identity_node is not None
assert identity_node.op_type == "Identity"

# Save the transformed model
save_transformed_model(model, "test_insert_identity_after_output")


def test_tensor_not_found(simple_model):
# Apply the transformation with a non-existent tensor
Expand Down

0 comments on commit 1aac6b6

Please sign in to comment.