From 53f85df1abbafdd248c06c065beebfa2b5d27b29 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 4 Aug 2017 14:16:12 +0800 Subject: [PATCH] Start doing `python.framework.operator` --- ...ate_op_creation_methods.py => operator.py} | 68 ++++++++++++------- 1 file changed, 42 insertions(+), 26 deletions(-) rename python/paddle/v2/framework/{create_op_creation_methods.py => operator.py} (81%) diff --git a/python/paddle/v2/framework/create_op_creation_methods.py b/python/paddle/v2/framework/operator.py similarity index 81% rename from python/paddle/v2/framework/create_op_creation_methods.py rename to python/paddle/v2/framework/operator.py index b034efffb..d4c34d7fa 100644 --- a/python/paddle/v2/framework/create_op_creation_methods.py +++ b/python/paddle/v2/framework/operator.py @@ -216,38 +216,54 @@ def create_op_creation_method(op_proto): opdesc = method(*args, **kwargs) return core.Operator.create(opdesc.SerializeToString()) - __impl__.__doc__ = get_docstring_from_op_proto(op_proto) - __impl__.all_input_args = [var.name for var in op_proto.inputs] - __impl__.all_output_args = [var.name for var in op_proto.outputs] - __impl__.all_attr_args = [attr.name for attr in op_proto.attrs] - __impl__.all_not_temp_output_args = [ - var.name for var in op_proto.outputs if not var.temporary - ] + return { + 'method': __impl__, + 'name': op_proto.type, + 'all_inputs': [var.name for var in op_proto.inputs], + 'all_outputs': [var.name for var in op_proto.outputs], + 'all_attrs': [attr.name for attr in op_proto.attrs], + 'all_no_temp_outputs': + [var.name for var in op_proto.outputs if not var.temporary] + } + + +class OperatorFactory(object): + def __init__(self): + self.op_methods = dict() + for op_proto in get_all_op_protos(): + method = create_op_creation_method(op_proto) + self.op_methods[method.name] = method - return __impl__ + def __call__(self, *args, **kwargs): + if 'type' in kwargs: + if len(args) != 0: + raise ValueError("All Paddle argument should be key-word " + "argument except type") + t = kwargs.pop('type') + else: + if len(args) != 1: + raise ValueError("All Paddle argument should be key-word " + "argument except type") + t = args[0] + return self.get_op_creation_info(t)['method'](**kwargs) -class OpCreationsHolder(object): - """ - A object will holds all op creation methods. - - Use `op_creations.xxx_op` to access them. - """ - pass + def get_op_creation_info(self, t): + if t not in self.op_methods: + raise ValueError("operator %s is not registered", t) + return self.op_methods.get(t) + def get_op_input_names(self, type): + return self.get_op_creation_info(type)['all_inputs'] -op_creations = OpCreationsHolder() + def get_op_output_names(self, type): + return self.get_op_creation_info(type)['all_outputs'] + def get_op_attr_names(self, type): + return self.get_op_creation_info(type)['all_attrs'] -def __bootstrap__(): - """ - Bootstrap function for this module. It will dynamic create all op creation - methods in runtime. - """ - for op_proto in get_all_op_protos(): - func = create_op_creation_method(op_proto) - func.__name__ = str(op_proto.type) - setattr(op_creations, func.__name__, func) + def get_op_no_temp_output_names(self, type): + return self.get_op_creation_info(type)['all_no_temp_outputs'] -__bootstrap__() +Operator = OperatorFactory() # Default global factory -- GitLab