提交 86451b30 编写于 作者: F fengjiayi

Update

上级 a4278559
......@@ -77,18 +77,44 @@ class Operator(object):
self.desc = desc
self.proto = OpProtoHolder.instance().get_op_proto(type)
self.desc.set_type(type)
if inputs is not None:
for in_proto in self.proto.inputs:
in_argu = inputs[in_proto.name]
if is_str(in_argu):
in_argu = [in_argu]
in_argus = inputs[in_proto.name]
if not isinstance(in_argus, list):
in_argus = [in_argus]
if not in_proto.duplicable and len(in_argus) > 1:
raise ValueError(
"Input %s expects only one input, but %d are given." %
(in_proto.name, len(in_argus)))
in_argu_names = []
for argu in in_argus:
in_argu_names.append(argu.name())
self.desc.set_input(in_proto.name, in_argu_names)
if outputs is not None:
for k, v in outputs.iteritems():
self.proto.set_output(k, v)
for out_proto in self.proto.outputs:
out_argus = outputs[out_proto.name]
if not isinstance(out_argus, list):
out_argus = [out_argus]
if not out_proto.duplicable and len(out_argus) > 1:
raise ValueError(
"Output %s expects only one output, but %d are given." %
(out_proto.name, len(out_argus)))
out_argu_names = []
for argu in out_argus:
out_argu_names.append(argu.name())
self.desc.set_output(out_proto.name, out_argu_names)
if attrs is not None:
# TODO
pass
for attr in self.proto.attrs:
attr_name = attr.name
if not attr_name in attrs:
continue
if not isinstance(attrs[attr_name], Block):
self.desc.set_attr(attr_name, attrs[attr_name])
else:
self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
# TODO: Getters
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册