提交 dbb60572 编写于 作者: F fengjiayi 提交者: GitHub

Refine Python operator input/output checks (#4803)

上级 5d9ce046
...@@ -176,6 +176,18 @@ class Operator(object): ...@@ -176,6 +176,18 @@ class Operator(object):
proto = OpProtoHolder.instance().get_op_proto(type) proto = OpProtoHolder.instance().get_op_proto(type)
if inputs is not None: if inputs is not None:
given = set()
need = set()
for n in inputs:
given.add(n)
for m in proto.inputs:
need.add(m.name)
if not given == need:
raise ValueError(
"Incorrect setting for input(s) of operator \"%s\". Need: [%s] Given: [%s]"
% (type, ", ".join(str(e) for e in need), ", ".join(
str(e) for e in given)))
for in_proto in proto.inputs: for in_proto in proto.inputs:
in_argus = inputs[in_proto.name] in_argus = inputs[in_proto.name]
if not isinstance(in_argus, list): if not isinstance(in_argus, list):
...@@ -190,6 +202,18 @@ class Operator(object): ...@@ -190,6 +202,18 @@ class Operator(object):
self.desc.set_input(in_proto.name, in_argu_names) self.desc.set_input(in_proto.name, in_argu_names)
if outputs is not None: if outputs is not None:
given = set()
need = set()
for n in outputs:
given.add(n)
for m in proto.outputs:
need.add(m.name)
if not given == need:
raise ValueError(
"Incorrect setting for output(s) of operator \"%s\". Need: [%s] Given: [%s]"
% (type, ", ".join(str(e) for e in need), ", ".join(
str(e) for e in given)))
for out_proto in proto.outputs: for out_proto in proto.outputs:
out_argus = outputs[out_proto.name] out_argus = outputs[out_proto.name]
if not isinstance(out_argus, list): if not isinstance(out_argus, list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册