From 90fa6db6bd193100b67ed62b37615ac2f38e4011 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 11 Oct 2017 10:27:05 -0700 Subject: [PATCH] Add infer_shape to Python Operator and fix bugs --- python/paddle/v2/framework/graph.py | 1 + python/paddle/v2/framework/tests/test_operator_desc.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py index 52c2f9a05cf..2afbd0c8315 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/graph.py @@ -198,6 +198,7 @@ class Operator(object): else: self.desc.set_block_attr(attr_name, attrs[attr_name].desc) + self.desc.check_attrs() self.desc.infer_shape(self.block.desc) @property diff --git a/python/paddle/v2/framework/tests/test_operator_desc.py b/python/paddle/v2/framework/tests/test_operator_desc.py index b9021ffc22e..ec6c6bc1833 100644 --- a/python/paddle/v2/framework/tests/test_operator_desc.py +++ b/python/paddle/v2/framework/tests/test_operator_desc.py @@ -40,10 +40,14 @@ class TestOperator(unittest.TestCase): self.assertEqual(mul_op.input("Y"), ["mul.y"]) self.assertEqual(mul_op.output_names, ["Out"]) self.assertEqual(mul_op.output("Out"), ["mul.out"]) - self.assertEqual(mul_op.attr_names, ["x_num_col_dims"]) + self.assertEqual( + set(mul_op.attr_names), set(["x_num_col_dims", "y_num_col_dims"])) self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) self.assertEqual(mul_op.attr("x_num_col_dims"), 1) + self.assertEqual(mul_op.has_attr("y_num_col_dims"), True) + self.assertEqual(mul_op.attr_type("y_num_col_dims"), core.AttrType.INT) + self.assertEqual(mul_op.attr("y_num_col_dims"), 1) self.assertEqual(mul_out.op, mul_op) def test_mult_input(self): -- GitLab