diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py index 6f2a76a9835696430ef9a23630736c5d151d8db2..b9356aaf899829fb000956eb86dfc4b082af0b70 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/graph.py @@ -1,9 +1,44 @@ import paddle.v2.framework.core as core +import paddle.v2.framework.proto.framework_pb2 as framework_pb2 import collections __all__ = ['Block', 'Variable', 'Program', 'Operator'] +def get_all_op_protos(): + """ + Get all registered op proto from PaddlePaddle C++ end. + :return: A list of registered OpProto. + """ + protostrs = core.get_all_op_protos() + ret_values = [] + for pbstr in protostrs: + op_proto = framework_pb2.OpProto.FromString(str(pbstr)) + ret_values.append(op_proto) + return ret_values + + +class OpProtoHolder(object): + @classmethod + def instance(cls): + if not hasattr(cls, '_instance'): + cls._instance = cls() + return cls._instance + + def __init__(self): + assert not hasattr( + self.__class__, + '_instance'), 'Please use `instance()` to get OpProtoHolder opject!' + op_protos = get_all_op_protos() + self.op_proto_map = {} + for proto in op_protos: + sefl.op_proto_map[proto.type] = proto + + def get_op_proto(self, type): + assert type in self.op_proto_map, "Operator with type \"%s\" has not been registered." % type + return self.op_proto_map[type] + + class Variable(object): def __init__(self, block, name=None, shape=None, dtype=None, lod_level=None): @@ -36,27 +71,50 @@ class Variable(object): class Operator(object): - def __init__(self, - block, - proto, - type=None, - inputs=None, - outputs=None, + def __init__(self, block, desc, type, inputs=None, outputs=None, attrs=None): self.block = block - self.proto = proto - if type is not None: - # TODO. - pass + self.desc = desc + self.proto = OpProtoHolder.instance().get_op_proto(type) + self.desc.set_type(type) + if inputs is not None: - # TODO - pass + for in_proto in self.proto.inputs: + in_argus = inputs[in_proto.name] + if not isinstance(in_argus, list): + in_argus = [in_argus] + if not in_proto.duplicable and len(in_argus) > 1: + raise ValueError( + "Input %s expects only one input, but %d are given." % + (in_proto.name, len(in_argus))) + in_argu_names = [] + for argu in in_argus: + in_argu_names.append(argu.name()) + self.desc.set_input(in_proto.name, in_argu_names) + if outputs is not None: - # TODO - pass + for out_proto in self.proto.outputs: + out_argus = outputs[out_proto.name] + if not isinstance(out_argus, list): + out_argus = [out_argus] + if not out_proto.duplicable and len(out_argus) > 1: + raise ValueError( + "Output %s expects only one output, but %d are given." % + (out_proto.name, len(out_argus))) + out_argu_names = [] + for argu in out_argus: + out_argu_names.append(argu.name()) + self.desc.set_output(out_proto.name, out_argu_names) + if attrs is not None: - # TODO - pass + for attr in self.proto.attrs: + attr_name = attr.name + if not attr_name in attrs: + continue + if not isinstance(attrs[attr_name], Block): + self.desc.set_attr(attr_name, attrs[attr_name]) + else: + self.desc.set_block_attr(attr_name, attrs[attr_name].desc) # TODO: Getters @@ -80,8 +138,8 @@ class Block(object): return Variable(self, *args, **kwargs) def append_op(self, *args, **kwargs): - op_proto = self.proto.append_op() - op = Operator(self, op_proto, *args, **kwargs) + op_desc = self.proto.append_op() + op = Operator(self, op_desc, *args, **kwargs) self.ops.append(op) return op