提交 9f816352 编写于 作者: Y Yu Yang

Follow comments

上级 5d074c91
......@@ -145,6 +145,16 @@ class OpDescCreationMethod(object):
return False
class OpInfo(object):
def __init__(self, name, method, inputs, outputs, attrs, no_temp_outputs):
self.name = name
self.method = method
self.inputs = inputs
self.outputs = outputs
self.attrs = attrs
self.no_temp_outputs = no_temp_outputs
def create_op_creation_method(op_proto):
"""
Generate op creation method for an OpProto
......@@ -155,15 +165,15 @@ def create_op_creation_method(op_proto):
opdesc = method(*args, **kwargs)
return core.Operator.create(opdesc.SerializeToString())
return {
'method': __impl__,
'name': op_proto.type,
'all_inputs': [var.name for var in op_proto.inputs],
'all_outputs': [var.name for var in op_proto.outputs],
'all_attrs': [attr.name for attr in op_proto.attrs],
'all_no_temp_outputs':
[var.name for var in op_proto.outputs if not var.temporary]
}
return OpInfo(
method=__impl__,
name=op_proto.type,
inputs=[var.name for var in op_proto.inputs],
outputs=[var.name for var in op_proto.outputs],
attrs=[attr.name for attr in op_proto.attrs],
no_temp_outputs=[
var.name for var in op_proto.outputs if not var.temporary
])
class OperatorFactory(object):
......@@ -185,27 +195,27 @@ class OperatorFactory(object):
"argument except type")
t = args[0]
return self.get_op_creation_info(t)['method'](**kwargs)
return self.get_op_info(t).method(**kwargs)
def types(self):
return self.op_methods.keys()
def get_op_creation_info(self, t):
def get_op_info(self, t):
if t not in self.op_methods:
raise ValueError("operator %s is not registered", t)
return self.op_methods.get(t)
def get_op_input_names(self, type):
return self.get_op_creation_info(type)['all_inputs']
return self.get_op_info(type).inputs
def get_op_output_names(self, type):
return self.get_op_creation_info(type)['all_outputs']
return self.get_op_info(type).outputs
def get_op_attr_names(self, type):
return self.get_op_creation_info(type)['all_attrs']
return self.get_op_info(type).attrs
def get_op_no_temp_output_names(self, type):
return self.get_op_creation_info(type)['all_no_temp_outputs']
return self.get_op_info(type).no_temp_outputs
Operator = OperatorFactory() # Default global factory
......@@ -20,4 +20,4 @@ py_test(gradient_checker SRCS gradient_checker.py)
py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py)
py_test(test_operator SRCS test_operator.py
py_test(test_operator SRCS test_operator.py)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册