From afaac7896edbb42cdaed9619727e24917a250272 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 10 Oct 2017 16:57:31 -0700 Subject: [PATCH] Refine code --- python/paddle/v2/framework/graph.py | 68 +++++++++---------- .../v2/framework/tests/test_operator_desc.py | 5 +- 2 files changed, 36 insertions(+), 37 deletions(-) diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py index c53869c88..e4052b150 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/graph.py @@ -7,40 +7,6 @@ import copy __all__ = ['Block', 'Variable', 'Program', 'Operator'] -def get_all_op_protos(): - """ - Get all registered op proto from PaddlePaddle C++ end. - :return: A list of registered OpProto. - """ - protostrs = core.get_all_op_protos() - ret_values = [] - for pbstr in protostrs: - op_proto = framework_pb2.OpProto.FromString(str(pbstr)) - ret_values.append(op_proto) - return ret_values - - -class OpProtoHolder(object): - @classmethod - def instance(cls): - if not hasattr(cls, '_instance'): - cls._instance = cls() - return cls._instance - - def __init__(self): - assert not hasattr( - self.__class__, - '_instance'), 'Please use `instance()` to get OpProtoHolder opject!' - op_protos = get_all_op_protos() - self.op_proto_map = {} - for proto in op_protos: - self.op_proto_map[proto.type] = proto - - def get_op_proto(self, type): - assert type in self.op_proto_map, "Operator with type \"%s\" has not been registered." % type - return self.op_proto_map[type] - - class Variable(object): def __init__(self, block, @@ -141,6 +107,40 @@ class Variable(object): raise ValueError("Not supported numpy dtype " + str(dtype)) +def get_all_op_protos(): + """ + Get all registered op proto from PaddlePaddle C++ end. + :return: A list of registered OpProto. + """ + protostrs = core.get_all_op_protos() + ret_values = [] + for pbstr in protostrs: + op_proto = framework_pb2.OpProto.FromString(str(pbstr)) + ret_values.append(op_proto) + return ret_values + + +class OpProtoHolder(object): + @classmethod + def instance(cls): + if not hasattr(cls, '_instance'): + cls._instance = cls() + return cls._instance + + def __init__(self): + assert not hasattr( + self.__class__, + '_instance'), 'Please use `instance()` to get OpProtoHolder opject!' + op_protos = get_all_op_protos() + self.op_proto_map = {} + for proto in op_protos: + self.op_proto_map[proto.type] = proto + + def get_op_proto(self, type): + assert type in self.op_proto_map, "Operator \"%s\" has not been registered." % type + return self.op_proto_map[type] + + class Operator(object): def __init__(self, block, desc, type, inputs=None, outputs=None, attrs=None): diff --git a/python/paddle/v2/framework/tests/test_operator_desc.py b/python/paddle/v2/framework/tests/test_operator_desc.py index 62f3a05d1..c46c030b2 100644 --- a/python/paddle/v2/framework/tests/test_operator_desc.py +++ b/python/paddle/v2/framework/tests/test_operator_desc.py @@ -10,9 +10,8 @@ class TestOperator(unittest.TestCase): block.append_op(type="no_such_op") self.assertFail() except AssertionError as err: - self.assertEqual( - err.message, - "Operator with type \"no_such_op\" has not been registered.") + self.assertEqual(err.message, + "Operator \"no_such_op\" has not been registered.") def test_op_desc_creation(self): block = g_program.current_block() -- GitLab