diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py index 10e5726a85f435b08997083094223ac2a0a15b61..01cd9982dc1c8d9869e59c55d0061abef91919ef 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/framework/framework.py @@ -176,6 +176,18 @@ class Operator(object): proto = OpProtoHolder.instance().get_op_proto(type) if inputs is not None: + given = set() + need = set() + for n in inputs: + given.add(n) + for m in proto.inputs: + need.add(m.name) + if not given == need: + raise ValueError( + "Incorrect setting for input(s) of operator \"%s\". Need: [%s] Given: [%s]" + % (type, ", ".join(str(e) for e in need), ", ".join( + str(e) for e in given))) + for in_proto in proto.inputs: in_argus = inputs[in_proto.name] if not isinstance(in_argus, list): @@ -190,6 +202,18 @@ class Operator(object): self.desc.set_input(in_proto.name, in_argu_names) if outputs is not None: + given = set() + need = set() + for n in outputs: + given.add(n) + for m in proto.outputs: + need.add(m.name) + if not given == need: + raise ValueError( + "Incorrect setting for output(s) of operator \"%s\". Need: [%s] Given: [%s]" + % (type, ", ".join(str(e) for e in need), ", ".join( + str(e) for e in given))) + for out_proto in proto.outputs: out_argus = outputs[out_proto.name] if not isinstance(out_argus, list):