From dbb60572a56096c02acab14cd2783df720010a6b Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Sat, 14 Oct 2017 14:55:11 -0700 Subject: [PATCH] Refine Python operator input/output checks (#4803) --- python/paddle/v2/framework/framework.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py index 10e5726a85f..01cd9982dc1 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/framework/framework.py @@ -176,6 +176,18 @@ class Operator(object): proto = OpProtoHolder.instance().get_op_proto(type) if inputs is not None: + given = set() + need = set() + for n in inputs: + given.add(n) + for m in proto.inputs: + need.add(m.name) + if not given == need: + raise ValueError( + "Incorrect setting for input(s) of operator \"%s\". Need: [%s] Given: [%s]" + % (type, ", ".join(str(e) for e in need), ", ".join( + str(e) for e in given))) + for in_proto in proto.inputs: in_argus = inputs[in_proto.name] if not isinstance(in_argus, list): @@ -190,6 +202,18 @@ class Operator(object): self.desc.set_input(in_proto.name, in_argu_names) if outputs is not None: + given = set() + need = set() + for n in outputs: + given.add(n) + for m in proto.outputs: + need.add(m.name) + if not given == need: + raise ValueError( + "Incorrect setting for output(s) of operator \"%s\". Need: [%s] Given: [%s]" + % (type, ", ".join(str(e) for e in need), ", ".join( + str(e) for e in given))) + for out_proto in proto.outputs: out_argus = outputs[out_proto.name] if not isinstance(out_argus, list): -- GitLab