From 86451b3064a320ae65f0ba9c23517e2852df0491 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 10 Oct 2017 11:01:57 -0700 Subject: [PATCH] Update --- python/paddle/v2/framework/graph.py | 40 ++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py index c752f5e7ed..b9356aaf89 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/graph.py @@ -77,18 +77,44 @@ class Operator(object): self.desc = desc self.proto = OpProtoHolder.instance().get_op_proto(type) self.desc.set_type(type) + if inputs is not None: for in_proto in self.proto.inputs: - in_argu = inputs[in_proto.name] - if is_str(in_argu): - in_argu = [in_argu] + in_argus = inputs[in_proto.name] + if not isinstance(in_argus, list): + in_argus = [in_argus] + if not in_proto.duplicable and len(in_argus) > 1: + raise ValueError( + "Input %s expects only one input, but %d are given." % + (in_proto.name, len(in_argus))) + in_argu_names = [] + for argu in in_argus: + in_argu_names.append(argu.name()) + self.desc.set_input(in_proto.name, in_argu_names) if outputs is not None: - for k, v in outputs.iteritems(): - self.proto.set_output(k, v) + for out_proto in self.proto.outputs: + out_argus = outputs[out_proto.name] + if not isinstance(out_argus, list): + out_argus = [out_argus] + if not out_proto.duplicable and len(out_argus) > 1: + raise ValueError( + "Output %s expects only one output, but %d are given." % + (out_proto.name, len(out_argus))) + out_argu_names = [] + for argu in out_argus: + out_argu_names.append(argu.name()) + self.desc.set_output(out_proto.name, out_argu_names) + if attrs is not None: - # TODO - pass + for attr in self.proto.attrs: + attr_name = attr.name + if not attr_name in attrs: + continue + if not isinstance(attrs[attr_name], Block): + self.desc.set_attr(attr_name, attrs[attr_name]) + else: + self.desc.set_block_attr(attr_name, attrs[attr_name].desc) # TODO: Getters -- GitLab