提交 53f85df1 编写于 作者: Y Yu Yang

Start doing `python.framework.operator`

上级 636d46a1
......@@ -216,38 +216,54 @@ def create_op_creation_method(op_proto):
opdesc = method(*args, **kwargs)
return core.Operator.create(opdesc.SerializeToString())
__impl__.__doc__ = get_docstring_from_op_proto(op_proto)
__impl__.all_input_args = [var.name for var in op_proto.inputs]
__impl__.all_output_args = [var.name for var in op_proto.outputs]
__impl__.all_attr_args = [attr.name for attr in op_proto.attrs]
__impl__.all_not_temp_output_args = [
var.name for var in op_proto.outputs if not var.temporary
]
return {
'method': __impl__,
'name': op_proto.type,
'all_inputs': [var.name for var in op_proto.inputs],
'all_outputs': [var.name for var in op_proto.outputs],
'all_attrs': [attr.name for attr in op_proto.attrs],
'all_no_temp_outputs':
[var.name for var in op_proto.outputs if not var.temporary]
}
class OperatorFactory(object):
def __init__(self):
self.op_methods = dict()
for op_proto in get_all_op_protos():
method = create_op_creation_method(op_proto)
self.op_methods[method.name] = method
return __impl__
def __call__(self, *args, **kwargs):
if 'type' in kwargs:
if len(args) != 0:
raise ValueError("All Paddle argument should be key-word "
"argument except type")
t = kwargs.pop('type')
else:
if len(args) != 1:
raise ValueError("All Paddle argument should be key-word "
"argument except type")
t = args[0]
return self.get_op_creation_info(t)['method'](**kwargs)
class OpCreationsHolder(object):
"""
A object will holds all op creation methods.
Use `op_creations.xxx_op` to access them.
"""
pass
def get_op_creation_info(self, t):
if t not in self.op_methods:
raise ValueError("operator %s is not registered", t)
return self.op_methods.get(t)
def get_op_input_names(self, type):
return self.get_op_creation_info(type)['all_inputs']
op_creations = OpCreationsHolder()
def get_op_output_names(self, type):
return self.get_op_creation_info(type)['all_outputs']
def get_op_attr_names(self, type):
return self.get_op_creation_info(type)['all_attrs']
def __bootstrap__():
"""
Bootstrap function for this module. It will dynamic create all op creation
methods in runtime.
"""
for op_proto in get_all_op_protos():
func = create_op_creation_method(op_proto)
func.__name__ = str(op_proto.type)
setattr(op_creations, func.__name__, func)
def get_op_no_temp_output_names(self, type):
return self.get_op_creation_info(type)['all_no_temp_outputs']
__bootstrap__()
Operator = OperatorFactory() # Default global factory
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册