From e621ff39e5be2a51acfaa4a36c8fd2bb7315821d Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 10 Oct 2017 18:04:54 -0700 Subject: [PATCH] Follow comments --- python/paddle/v2/framework/graph.py | 23 +++++++++++++++---- .../v2/framework/tests/test_operator_desc.py | 13 +++++++++-- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py index e4052b15049..52c2f9a05cf 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/graph.py @@ -142,15 +142,25 @@ class OpProtoHolder(object): class Operator(object): - def __init__(self, block, desc, type, inputs=None, outputs=None, + def __init__(self, + block, + desc, + type=None, + inputs=None, + outputs=None, attrs=None): self.block = block self.desc = desc - self.proto = OpProtoHolder.instance().get_op_proto(type) + if len(self.desc.type()) != 0: + return + if type is None: + raise ValueError( + "`type` to initilized an Operator can not be None.") self.desc.set_type(type) + proto = OpProtoHolder.instance().get_op_proto(type) if inputs is not None: - for in_proto in self.proto.inputs: + for in_proto in proto.inputs: in_argus = inputs[in_proto.name] if not isinstance(in_argus, list): in_argus = [in_argus] @@ -164,7 +174,7 @@ class Operator(object): self.desc.set_input(in_proto.name, in_argu_names) if outputs is not None: - for out_proto in self.proto.outputs: + for out_proto in proto.outputs: out_argus = outputs[out_proto.name] if not isinstance(out_argus, list): out_argus = [out_argus] @@ -175,10 +185,11 @@ class Operator(object): out_argu_names = [] for argu in out_argus: out_argu_names.append(argu.name) + argu.op = self self.desc.set_output(out_proto.name, out_argu_names) if attrs is not None: - for attr in self.proto.attrs: + for attr in proto.attrs: attr_name = attr.name if not attr_name in attrs: continue @@ -187,6 +198,8 @@ class Operator(object): else: self.desc.set_block_attr(attr_name, attrs[attr_name].desc) + self.desc.infer_shape(self.block.desc) + @property def type(self): return self.desc.type() diff --git a/python/paddle/v2/framework/tests/test_operator_desc.py b/python/paddle/v2/framework/tests/test_operator_desc.py index c46c030b2e6..b9021ffc22e 100644 --- a/python/paddle/v2/framework/tests/test_operator_desc.py +++ b/python/paddle/v2/framework/tests/test_operator_desc.py @@ -6,11 +6,18 @@ import paddle.v2.framework.core as core class TestOperator(unittest.TestCase): def test_error_type(self): block = g_program.create_block() + try: + block.append_op() + self.assertFail() + except ValueError as v_err: + self.assertEqual( + v_err.message, + "`type` to initilized an Operator can not be None.") try: block.append_op(type="no_such_op") self.assertFail() - except AssertionError as err: - self.assertEqual(err.message, + except AssertionError as a_err: + self.assertEqual(a_err.message, "Operator \"no_such_op\" has not been registered.") def test_op_desc_creation(self): @@ -37,6 +44,7 @@ class TestOperator(unittest.TestCase): self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) self.assertEqual(mul_op.attr("x_num_col_dims"), 1) + self.assertEqual(mul_out.op, mul_op) def test_mult_input(self): block = g_program.current_block() @@ -57,6 +65,7 @@ class TestOperator(unittest.TestCase): self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"]) self.assertEqual(sum_op.output_names, ["Out"]) self.assertEqual(sum_op.output("Out"), ["sum.out"]) + self.assertEqual(sum_out.op, sum_op) if __name__ == '__main__': -- GitLab