提交 afaac789 编写于 作者: F fengjiayi

Refine code

上级 906f5e8a
...@@ -7,40 +7,6 @@ import copy ...@@ -7,40 +7,6 @@ import copy
__all__ = ['Block', 'Variable', 'Program', 'Operator'] __all__ = ['Block', 'Variable', 'Program', 'Operator']
def get_all_op_protos():
"""
Get all registered op proto from PaddlePaddle C++ end.
:return: A list of registered OpProto.
"""
protostrs = core.get_all_op_protos()
ret_values = []
for pbstr in protostrs:
op_proto = framework_pb2.OpProto.FromString(str(pbstr))
ret_values.append(op_proto)
return ret_values
class OpProtoHolder(object):
@classmethod
def instance(cls):
if not hasattr(cls, '_instance'):
cls._instance = cls()
return cls._instance
def __init__(self):
assert not hasattr(
self.__class__,
'_instance'), 'Please use `instance()` to get OpProtoHolder opject!'
op_protos = get_all_op_protos()
self.op_proto_map = {}
for proto in op_protos:
self.op_proto_map[proto.type] = proto
def get_op_proto(self, type):
assert type in self.op_proto_map, "Operator with type \"%s\" has not been registered." % type
return self.op_proto_map[type]
class Variable(object): class Variable(object):
def __init__(self, def __init__(self,
block, block,
...@@ -141,6 +107,40 @@ class Variable(object): ...@@ -141,6 +107,40 @@ class Variable(object):
raise ValueError("Not supported numpy dtype " + str(dtype)) raise ValueError("Not supported numpy dtype " + str(dtype))
def get_all_op_protos():
"""
Get all registered op proto from PaddlePaddle C++ end.
:return: A list of registered OpProto.
"""
protostrs = core.get_all_op_protos()
ret_values = []
for pbstr in protostrs:
op_proto = framework_pb2.OpProto.FromString(str(pbstr))
ret_values.append(op_proto)
return ret_values
class OpProtoHolder(object):
@classmethod
def instance(cls):
if not hasattr(cls, '_instance'):
cls._instance = cls()
return cls._instance
def __init__(self):
assert not hasattr(
self.__class__,
'_instance'), 'Please use `instance()` to get OpProtoHolder opject!'
op_protos = get_all_op_protos()
self.op_proto_map = {}
for proto in op_protos:
self.op_proto_map[proto.type] = proto
def get_op_proto(self, type):
assert type in self.op_proto_map, "Operator \"%s\" has not been registered." % type
return self.op_proto_map[type]
class Operator(object): class Operator(object):
def __init__(self, block, desc, type, inputs=None, outputs=None, def __init__(self, block, desc, type, inputs=None, outputs=None,
attrs=None): attrs=None):
......
...@@ -10,9 +10,8 @@ class TestOperator(unittest.TestCase): ...@@ -10,9 +10,8 @@ class TestOperator(unittest.TestCase):
block.append_op(type="no_such_op") block.append_op(type="no_such_op")
self.assertFail() self.assertFail()
except AssertionError as err: except AssertionError as err:
self.assertEqual( self.assertEqual(err.message,
err.message, "Operator \"no_such_op\" has not been registered.")
"Operator with type \"no_such_op\" has not been registered.")
def test_op_desc_creation(self): def test_op_desc_creation(self):
block = g_program.current_block() block = g_program.current_block()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册