提交 53222cb9 编写于 作者: F fengjiayi

Add OpProtoHolder

上级 77150f1f
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
import collections
__all__ = ['Block', 'Variable', 'Program', 'Operator']
def get_all_op_protos():
"""
Get all registered op proto from PaddlePaddle C++ end.
:return: A list of registered OpProto.
"""
protostrs = core.get_all_op_protos()
ret_values = []
for pbstr in protostrs:
op_proto = framework_pb2.OpProto.FromString(str(pbstr))
ret_values.append(op_proto)
return ret_values
class OpProtoHolder(object):
@classmethod
def instance(cls):
if not hasattr(cls, '_instance'):
cls._instance = cls()
return cls._instance
def __init__(self):
assert not hasattr(
self.__class__,
'_instance'), 'Please use `instance()` to get OpProtoHolder opject!'
op_protos = get_all_op_protos()
self.op_proto_map = {}
for proto in op_protos:
sefl.op_proto_map[proto.type] = proto
def get_op_proto(self, type):
assert type in self.op_proto_map, "Operator with type \"%s\" has not been registered." % type
return self.op_proto_map[type]
class Variable(object):
def __init__(self, block, name=None, shape=None, dtype=None,
lod_level=None):
......@@ -46,14 +81,13 @@ class Operator(object):
self.block = block
self.proto = proto
if type is not None:
# TODO.
pass
self.proto.set_type(type)
if inputs is not None:
# TODO
pass
for k, v in inputs.iteritems():
self.proto.set_input(k, v)
if outputs is not None:
# TODO
pass
for k, v in outputs.iteritems():
self.proto.set_output(k, v)
if attrs is not None:
# TODO
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册