提交 7464bd29 编写于 作者: X Xin Pan

polish

test=develop
上级 35e6b5e1
...@@ -612,6 +612,7 @@ class Operator(core.OpBase): ...@@ -612,6 +612,7 @@ class Operator(core.OpBase):
return True return True
return False return False
self.inputs = []
if inputs is not None: if inputs is not None:
for in_proto in proto.inputs: for in_proto in proto.inputs:
found = find_name(inputs, in_proto.name) found = find_name(inputs, in_proto.name)
...@@ -638,6 +639,13 @@ class Operator(core.OpBase): ...@@ -638,6 +639,13 @@ class Operator(core.OpBase):
else: else:
self.desc.set_input(in_proto.name, []) self.desc.set_input(in_proto.name, [])
for inp in inputs.values():
if isinstance(inp, Variable):
self.inputs.append(inp)
elif isinstance(inp, list) or isinstance(inp, tuple):
self.inputs.extend(inp[:])
self.outputs = []
if outputs is not None: if outputs is not None:
given = set() given = set()
need = set() need = set()
...@@ -666,20 +674,11 @@ class Operator(core.OpBase): ...@@ -666,20 +674,11 @@ class Operator(core.OpBase):
arg.op = self arg.op = self
self.desc.set_output(out_proto.name, out_arg_names) self.desc.set_output(out_proto.name, out_arg_names)
input_vars = [] for out in outputs.values():
for inp in inputs.values(): if isinstance(out, Variable):
if isinstance(inp, Variable): self.outputs.append(out)
input_vars.append(inp) elif isinstance(out, list) or isinstance(out, tuple):
elif isinstance(inp, list) or isinstance(inp, tuple): self.outputs.extend(out[:])
input_vars.extend(inp[:])
self.inputs = input_vars
output_vars = []
for out in outputs.values():
if isinstance(out, Variable):
output_vars.append(out)
elif isinstance(out, list) or isinstance(out, tuple):
output_vars.extend(out[:])
self.outputs = output_vars
if op_attrs is not None: if op_attrs is not None:
if not isinstance(op_attrs, dict): if not isinstance(op_attrs, dict):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册