提交 90fa6db6 编写于 作者: F fengjiayi

Add infer_shape to Python Operator and fix bugs

上级 bf26cc53
...@@ -198,6 +198,7 @@ class Operator(object): ...@@ -198,6 +198,7 @@ class Operator(object):
else: else:
self.desc.set_block_attr(attr_name, attrs[attr_name].desc) self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
self.desc.check_attrs()
self.desc.infer_shape(self.block.desc) self.desc.infer_shape(self.block.desc)
@property @property
......
...@@ -40,10 +40,14 @@ class TestOperator(unittest.TestCase): ...@@ -40,10 +40,14 @@ class TestOperator(unittest.TestCase):
self.assertEqual(mul_op.input("Y"), ["mul.y"]) self.assertEqual(mul_op.input("Y"), ["mul.y"])
self.assertEqual(mul_op.output_names, ["Out"]) self.assertEqual(mul_op.output_names, ["Out"])
self.assertEqual(mul_op.output("Out"), ["mul.out"]) self.assertEqual(mul_op.output("Out"), ["mul.out"])
self.assertEqual(mul_op.attr_names, ["x_num_col_dims"]) self.assertEqual(
set(mul_op.attr_names), set(["x_num_col_dims", "y_num_col_dims"]))
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("x_num_col_dims"), 1) self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
self.assertEqual(mul_op.has_attr("y_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("y_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("y_num_col_dims"), 1)
self.assertEqual(mul_out.op, mul_op) self.assertEqual(mul_out.op, mul_op)
def test_mult_input(self): def test_mult_input(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册