提交 53f85df1 编写于 作者: Y Yu Yang

Start doing `python.framework.operator`

上级 636d46a1
...@@ -216,38 +216,54 @@ def create_op_creation_method(op_proto): ...@@ -216,38 +216,54 @@ def create_op_creation_method(op_proto):
opdesc = method(*args, **kwargs) opdesc = method(*args, **kwargs)
return core.Operator.create(opdesc.SerializeToString()) return core.Operator.create(opdesc.SerializeToString())
__impl__.__doc__ = get_docstring_from_op_proto(op_proto) return {
__impl__.all_input_args = [var.name for var in op_proto.inputs] 'method': __impl__,
__impl__.all_output_args = [var.name for var in op_proto.outputs] 'name': op_proto.type,
__impl__.all_attr_args = [attr.name for attr in op_proto.attrs] 'all_inputs': [var.name for var in op_proto.inputs],
__impl__.all_not_temp_output_args = [ 'all_outputs': [var.name for var in op_proto.outputs],
var.name for var in op_proto.outputs if not var.temporary 'all_attrs': [attr.name for attr in op_proto.attrs],
] 'all_no_temp_outputs':
[var.name for var in op_proto.outputs if not var.temporary]
return __impl__ }
class OperatorFactory(object):
def __init__(self):
self.op_methods = dict()
for op_proto in get_all_op_protos():
method = create_op_creation_method(op_proto)
self.op_methods[method.name] = method
def __call__(self, *args, **kwargs):
if 'type' in kwargs:
if len(args) != 0:
raise ValueError("All Paddle argument should be key-word "
"argument except type")
t = kwargs.pop('type')
else:
if len(args) != 1:
raise ValueError("All Paddle argument should be key-word "
"argument except type")
t = args[0]
class OpCreationsHolder(object): return self.get_op_creation_info(t)['method'](**kwargs)
"""
A object will holds all op creation methods.
Use `op_creations.xxx_op` to access them. def get_op_creation_info(self, t):
""" if t not in self.op_methods:
pass raise ValueError("operator %s is not registered", t)
return self.op_methods.get(t)
def get_op_input_names(self, type):
return self.get_op_creation_info(type)['all_inputs']
op_creations = OpCreationsHolder() def get_op_output_names(self, type):
return self.get_op_creation_info(type)['all_outputs']
def get_op_attr_names(self, type):
return self.get_op_creation_info(type)['all_attrs']
def __bootstrap__(): def get_op_no_temp_output_names(self, type):
""" return self.get_op_creation_info(type)['all_no_temp_outputs']
Bootstrap function for this module. It will dynamic create all op creation
methods in runtime.
"""
for op_proto in get_all_op_protos():
func = create_op_creation_method(op_proto)
func.__name__ = str(op_proto.type)
setattr(op_creations, func.__name__, func)
__bootstrap__() Operator = OperatorFactory() # Default global factory
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册