From 53222cb9c3b748b1fc6fcb72f5f122a1c68f50e3 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 9 Oct 2017 17:31:05 -0700 Subject: [PATCH] Add OpProtoHolder --- python/paddle/v2/framework/graph.py | 46 +++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py index 6f2a76a983..296fab8fed 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/graph.py @@ -1,9 +1,44 @@ import paddle.v2.framework.core as core +import paddle.v2.framework.proto.framework_pb2 as framework_pb2 import collections __all__ = ['Block', 'Variable', 'Program', 'Operator'] +def get_all_op_protos(): + """ + Get all registered op proto from PaddlePaddle C++ end. + :return: A list of registered OpProto. + """ + protostrs = core.get_all_op_protos() + ret_values = [] + for pbstr in protostrs: + op_proto = framework_pb2.OpProto.FromString(str(pbstr)) + ret_values.append(op_proto) + return ret_values + + +class OpProtoHolder(object): + @classmethod + def instance(cls): + if not hasattr(cls, '_instance'): + cls._instance = cls() + return cls._instance + + def __init__(self): + assert not hasattr( + self.__class__, + '_instance'), 'Please use `instance()` to get OpProtoHolder opject!' + op_protos = get_all_op_protos() + self.op_proto_map = {} + for proto in op_protos: + sefl.op_proto_map[proto.type] = proto + + def get_op_proto(self, type): + assert type in self.op_proto_map, "Operator with type \"%s\" has not been registered." % type + return self.op_proto_map[type] + + class Variable(object): def __init__(self, block, name=None, shape=None, dtype=None, lod_level=None): @@ -46,14 +81,13 @@ class Operator(object): self.block = block self.proto = proto if type is not None: - # TODO. - pass + self.proto.set_type(type) if inputs is not None: - # TODO - pass + for k, v in inputs.iteritems(): + self.proto.set_input(k, v) if outputs is not None: - # TODO - pass + for k, v in outputs.iteritems(): + self.proto.set_output(k, v) if attrs is not None: # TODO pass -- GitLab