diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py index a3dbd0cc8939489e864769b0fef40d129ba78800..81c8c3fed844614fbb22dccca685fa84c0683dba 100644 --- a/python/paddle/v2/framework/op.py +++ b/python/paddle/v2/framework/op.py @@ -145,6 +145,16 @@ class OpDescCreationMethod(object): return False +class OpInfo(object): + def __init__(self, name, method, inputs, outputs, attrs, no_temp_outputs): + self.name = name + self.method = method + self.inputs = inputs + self.outputs = outputs + self.attrs = attrs + self.no_temp_outputs = no_temp_outputs + + def create_op_creation_method(op_proto): """ Generate op creation method for an OpProto @@ -155,15 +165,15 @@ def create_op_creation_method(op_proto): opdesc = method(*args, **kwargs) return core.Operator.create(opdesc.SerializeToString()) - return { - 'method': __impl__, - 'name': op_proto.type, - 'all_inputs': [var.name for var in op_proto.inputs], - 'all_outputs': [var.name for var in op_proto.outputs], - 'all_attrs': [attr.name for attr in op_proto.attrs], - 'all_no_temp_outputs': - [var.name for var in op_proto.outputs if not var.temporary] - } + return OpInfo( + method=__impl__, + name=op_proto.type, + inputs=[var.name for var in op_proto.inputs], + outputs=[var.name for var in op_proto.outputs], + attrs=[attr.name for attr in op_proto.attrs], + no_temp_outputs=[ + var.name for var in op_proto.outputs if not var.temporary + ]) class OperatorFactory(object): @@ -185,27 +195,27 @@ class OperatorFactory(object): "argument except type") t = args[0] - return self.get_op_creation_info(t)['method'](**kwargs) + return self.get_op_info(t).method(**kwargs) def types(self): return self.op_methods.keys() - def get_op_creation_info(self, t): + def get_op_info(self, t): if t not in self.op_methods: raise ValueError("operator %s is not registered", t) return self.op_methods.get(t) def get_op_input_names(self, type): - return self.get_op_creation_info(type)['all_inputs'] + return self.get_op_info(type).inputs def get_op_output_names(self, type): - return self.get_op_creation_info(type)['all_outputs'] + return self.get_op_info(type).outputs def get_op_attr_names(self, type): - return self.get_op_creation_info(type)['all_attrs'] + return self.get_op_info(type).attrs def get_op_no_temp_output_names(self, type): - return self.get_op_creation_info(type)['all_no_temp_outputs'] + return self.get_op_info(type).no_temp_outputs Operator = OperatorFactory() # Default global factory diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 2c648d22f35a1492b0a6848825888e41075b30d4..d01e005aca6d411e31d19d0f0cf5f4db097ec15b 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -20,4 +20,4 @@ py_test(gradient_checker SRCS gradient_checker.py) py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) -py_test(test_operator SRCS test_operator.py +py_test(test_operator SRCS test_operator.py)