提交 e621ff39 编写于 作者: F fengjiayi

Follow comments

上级 afaac789
...@@ -142,15 +142,25 @@ class OpProtoHolder(object): ...@@ -142,15 +142,25 @@ class OpProtoHolder(object):
class Operator(object): class Operator(object):
def __init__(self, block, desc, type, inputs=None, outputs=None, def __init__(self,
block,
desc,
type=None,
inputs=None,
outputs=None,
attrs=None): attrs=None):
self.block = block self.block = block
self.desc = desc self.desc = desc
self.proto = OpProtoHolder.instance().get_op_proto(type) if len(self.desc.type()) != 0:
return
if type is None:
raise ValueError(
"`type` to initilized an Operator can not be None.")
self.desc.set_type(type) self.desc.set_type(type)
proto = OpProtoHolder.instance().get_op_proto(type)
if inputs is not None: if inputs is not None:
for in_proto in self.proto.inputs: for in_proto in proto.inputs:
in_argus = inputs[in_proto.name] in_argus = inputs[in_proto.name]
if not isinstance(in_argus, list): if not isinstance(in_argus, list):
in_argus = [in_argus] in_argus = [in_argus]
...@@ -164,7 +174,7 @@ class Operator(object): ...@@ -164,7 +174,7 @@ class Operator(object):
self.desc.set_input(in_proto.name, in_argu_names) self.desc.set_input(in_proto.name, in_argu_names)
if outputs is not None: if outputs is not None:
for out_proto in self.proto.outputs: for out_proto in proto.outputs:
out_argus = outputs[out_proto.name] out_argus = outputs[out_proto.name]
if not isinstance(out_argus, list): if not isinstance(out_argus, list):
out_argus = [out_argus] out_argus = [out_argus]
...@@ -175,10 +185,11 @@ class Operator(object): ...@@ -175,10 +185,11 @@ class Operator(object):
out_argu_names = [] out_argu_names = []
for argu in out_argus: for argu in out_argus:
out_argu_names.append(argu.name) out_argu_names.append(argu.name)
argu.op = self
self.desc.set_output(out_proto.name, out_argu_names) self.desc.set_output(out_proto.name, out_argu_names)
if attrs is not None: if attrs is not None:
for attr in self.proto.attrs: for attr in proto.attrs:
attr_name = attr.name attr_name = attr.name
if not attr_name in attrs: if not attr_name in attrs:
continue continue
...@@ -187,6 +198,8 @@ class Operator(object): ...@@ -187,6 +198,8 @@ class Operator(object):
else: else:
self.desc.set_block_attr(attr_name, attrs[attr_name].desc) self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
self.desc.infer_shape(self.block.desc)
@property @property
def type(self): def type(self):
return self.desc.type() return self.desc.type()
......
...@@ -6,11 +6,18 @@ import paddle.v2.framework.core as core ...@@ -6,11 +6,18 @@ import paddle.v2.framework.core as core
class TestOperator(unittest.TestCase): class TestOperator(unittest.TestCase):
def test_error_type(self): def test_error_type(self):
block = g_program.create_block() block = g_program.create_block()
try:
block.append_op()
self.assertFail()
except ValueError as v_err:
self.assertEqual(
v_err.message,
"`type` to initilized an Operator can not be None.")
try: try:
block.append_op(type="no_such_op") block.append_op(type="no_such_op")
self.assertFail() self.assertFail()
except AssertionError as err: except AssertionError as a_err:
self.assertEqual(err.message, self.assertEqual(a_err.message,
"Operator \"no_such_op\" has not been registered.") "Operator \"no_such_op\" has not been registered.")
def test_op_desc_creation(self): def test_op_desc_creation(self):
...@@ -37,6 +44,7 @@ class TestOperator(unittest.TestCase): ...@@ -37,6 +44,7 @@ class TestOperator(unittest.TestCase):
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("x_num_col_dims"), 1) self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
self.assertEqual(mul_out.op, mul_op)
def test_mult_input(self): def test_mult_input(self):
block = g_program.current_block() block = g_program.current_block()
...@@ -57,6 +65,7 @@ class TestOperator(unittest.TestCase): ...@@ -57,6 +65,7 @@ class TestOperator(unittest.TestCase):
self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"]) self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"])
self.assertEqual(sum_op.output_names, ["Out"]) self.assertEqual(sum_op.output_names, ["Out"])
self.assertEqual(sum_op.output("Out"), ["sum.out"]) self.assertEqual(sum_op.output("Out"), ["sum.out"])
self.assertEqual(sum_out.op, sum_op)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册