提交 9f816352 编写于 作者: Y Yu Yang

Follow comments

上级 5d074c91
...@@ -145,6 +145,16 @@ class OpDescCreationMethod(object): ...@@ -145,6 +145,16 @@ class OpDescCreationMethod(object):
return False return False
class OpInfo(object):
def __init__(self, name, method, inputs, outputs, attrs, no_temp_outputs):
self.name = name
self.method = method
self.inputs = inputs
self.outputs = outputs
self.attrs = attrs
self.no_temp_outputs = no_temp_outputs
def create_op_creation_method(op_proto): def create_op_creation_method(op_proto):
""" """
Generate op creation method for an OpProto Generate op creation method for an OpProto
...@@ -155,15 +165,15 @@ def create_op_creation_method(op_proto): ...@@ -155,15 +165,15 @@ def create_op_creation_method(op_proto):
opdesc = method(*args, **kwargs) opdesc = method(*args, **kwargs)
return core.Operator.create(opdesc.SerializeToString()) return core.Operator.create(opdesc.SerializeToString())
return { return OpInfo(
'method': __impl__, method=__impl__,
'name': op_proto.type, name=op_proto.type,
'all_inputs': [var.name for var in op_proto.inputs], inputs=[var.name for var in op_proto.inputs],
'all_outputs': [var.name for var in op_proto.outputs], outputs=[var.name for var in op_proto.outputs],
'all_attrs': [attr.name for attr in op_proto.attrs], attrs=[attr.name for attr in op_proto.attrs],
'all_no_temp_outputs': no_temp_outputs=[
[var.name for var in op_proto.outputs if not var.temporary] var.name for var in op_proto.outputs if not var.temporary
} ])
class OperatorFactory(object): class OperatorFactory(object):
...@@ -185,27 +195,27 @@ class OperatorFactory(object): ...@@ -185,27 +195,27 @@ class OperatorFactory(object):
"argument except type") "argument except type")
t = args[0] t = args[0]
return self.get_op_creation_info(t)['method'](**kwargs) return self.get_op_info(t).method(**kwargs)
def types(self): def types(self):
return self.op_methods.keys() return self.op_methods.keys()
def get_op_creation_info(self, t): def get_op_info(self, t):
if t not in self.op_methods: if t not in self.op_methods:
raise ValueError("operator %s is not registered", t) raise ValueError("operator %s is not registered", t)
return self.op_methods.get(t) return self.op_methods.get(t)
def get_op_input_names(self, type): def get_op_input_names(self, type):
return self.get_op_creation_info(type)['all_inputs'] return self.get_op_info(type).inputs
def get_op_output_names(self, type): def get_op_output_names(self, type):
return self.get_op_creation_info(type)['all_outputs'] return self.get_op_info(type).outputs
def get_op_attr_names(self, type): def get_op_attr_names(self, type):
return self.get_op_creation_info(type)['all_attrs'] return self.get_op_info(type).attrs
def get_op_no_temp_output_names(self, type): def get_op_no_temp_output_names(self, type):
return self.get_op_creation_info(type)['all_no_temp_outputs'] return self.get_op_info(type).no_temp_outputs
Operator = OperatorFactory() # Default global factory Operator = OperatorFactory() # Default global factory
...@@ -20,4 +20,4 @@ py_test(gradient_checker SRCS gradient_checker.py) ...@@ -20,4 +20,4 @@ py_test(gradient_checker SRCS gradient_checker.py)
py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py)
py_test(test_operator SRCS test_operator.py py_test(test_operator SRCS test_operator.py)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册