提交 e621ff39 编写于 作者: F fengjiayi

Follow comments

上级 afaac789
......@@ -142,15 +142,25 @@ class OpProtoHolder(object):
class Operator(object):
def __init__(self, block, desc, type, inputs=None, outputs=None,
def __init__(self,
block,
desc,
type=None,
inputs=None,
outputs=None,
attrs=None):
self.block = block
self.desc = desc
self.proto = OpProtoHolder.instance().get_op_proto(type)
if len(self.desc.type()) != 0:
return
if type is None:
raise ValueError(
"`type` to initilized an Operator can not be None.")
self.desc.set_type(type)
proto = OpProtoHolder.instance().get_op_proto(type)
if inputs is not None:
for in_proto in self.proto.inputs:
for in_proto in proto.inputs:
in_argus = inputs[in_proto.name]
if not isinstance(in_argus, list):
in_argus = [in_argus]
......@@ -164,7 +174,7 @@ class Operator(object):
self.desc.set_input(in_proto.name, in_argu_names)
if outputs is not None:
for out_proto in self.proto.outputs:
for out_proto in proto.outputs:
out_argus = outputs[out_proto.name]
if not isinstance(out_argus, list):
out_argus = [out_argus]
......@@ -175,10 +185,11 @@ class Operator(object):
out_argu_names = []
for argu in out_argus:
out_argu_names.append(argu.name)
argu.op = self
self.desc.set_output(out_proto.name, out_argu_names)
if attrs is not None:
for attr in self.proto.attrs:
for attr in proto.attrs:
attr_name = attr.name
if not attr_name in attrs:
continue
......@@ -187,6 +198,8 @@ class Operator(object):
else:
self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
self.desc.infer_shape(self.block.desc)
@property
def type(self):
return self.desc.type()
......
......@@ -6,11 +6,18 @@ import paddle.v2.framework.core as core
class TestOperator(unittest.TestCase):
def test_error_type(self):
block = g_program.create_block()
try:
block.append_op()
self.assertFail()
except ValueError as v_err:
self.assertEqual(
v_err.message,
"`type` to initilized an Operator can not be None.")
try:
block.append_op(type="no_such_op")
self.assertFail()
except AssertionError as err:
self.assertEqual(err.message,
except AssertionError as a_err:
self.assertEqual(a_err.message,
"Operator \"no_such_op\" has not been registered.")
def test_op_desc_creation(self):
......@@ -37,6 +44,7 @@ class TestOperator(unittest.TestCase):
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
self.assertEqual(mul_out.op, mul_op)
def test_mult_input(self):
block = g_program.current_block()
......@@ -57,6 +65,7 @@ class TestOperator(unittest.TestCase):
self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"])
self.assertEqual(sum_op.output_names, ["Out"])
self.assertEqual(sum_op.output("Out"), ["sum.out"])
self.assertEqual(sum_out.op, sum_op)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册