提交 5a7a3c4d 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #3237 from reyoung/feature/change_op_creation

Feature/change op creation
import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
from default_scope_funcs import new_var, find_var, get_cur_scope
__all__ = ['Network'] # Only expose Network
class NetworkFunctor(object):
"""
Network Op Creation Function. Used internally in this module.
It convert string input to Variable. If it is not created before, just
create in scope.
It is a functor object. means the instances are callable.
:param func: The op creation function which generated in Python.
:param net: The Network instance.
"""
def __init__(self, func, net):
self.func = func
self.net = net
def __call__(self, *args, **kwargs):
if len(args) != 0:
raise ValueError("Paddle must use keyword argument")
inputs = self.func.all_input_args
for ipt in inputs:
if ipt in kwargs:
var = kwargs[ipt]
if isinstance(var, basestring):
tmp = new_var(var)
self.net.var_names[tmp] = var
var = tmp
if not isinstance(var, core.Variable):
raise TypeError(
"Input of op creation must be string or variable")
kwargs[ipt] = self.net.var_names[var]
notemp_outputs = self.func.all_not_temp_output_args
for name in notemp_outputs:
if name not in kwargs:
kwargs[
name] = self.func.__name__ + "@OUT@%d" % core.unique_integer(
)
outputs = self.func.all_output_args
for opt in outputs:
if opt in kwargs:
var = kwargs[opt]
if isinstance(var, basestring):
tmp = new_var(var)
self.net.var_names[tmp] = var
var = tmp
if not isinstance(var, core.Variable):
raise TypeError(
"Output of op creation must be string or variable")
kwargs[opt] = self.net.var_names[var]
op = self.func(**kwargs)
self.net.net.add_op(op)
lst = [find_var(kwargs[opt]) for opt in notemp_outputs]
if len(lst) == 1:
return lst[0]
elif len(lst) == 0:
return None
else:
return lst
class Network(object):
"""
The network concept. It avoid user to manually create operator, create
variable, and combine them into a Net. Just use Network.xxx can create the
operator, create variables in default scope, and add them into `self.net`.
For example:
.. code-block: python
net = Network()
out = net.add_two(X="a", Y="b")
fc_out = net.fc(X="out", W="fc.w")
net.run(...)
"""
def __init__(self):
self.net = core.Net.create()
funcs = (func_name for func_name in dir(op_creations)
if not func_name.startswith("__"))
self.var_names = dict()
# TODO(yuyang18): This code can work, but do not generate a good
# docstring, try to give a better way generate function in runtime
# later.
for func_name in funcs:
func = getattr(op_creations, func_name)
impl = NetworkFunctor(func, self)
setattr(self, func_name, impl.__call__)
self.__complete_add_op__ = False
def infer_shape(self):
self.complete_add_op()
self.net.infer_shape(get_cur_scope())
def run(self, device_context):
self.complete_add_op()
self.net.run(get_cur_scope(), device_context)
def __str__(self):
return str(self.net)
def complete_add_op(self):
if not self.__complete_add_op__:
self.net.complete_add_op()
self.__complete_add_op__ = True
if __name__ == '__main__':
net = Network()
out = net.add_two(X="a", Y="b")
fc_out = net.fc(X=out, W="fc.w", b="fc.b", activation="softmax")
net.complete_add_op()
print net
...@@ -2,7 +2,6 @@ import paddle.v2.framework.core as core ...@@ -2,7 +2,6 @@ import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2 import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2 import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
import paddle.v2.framework.proto.attribute_pb2 as attribute_pb2 import paddle.v2.framework.proto.attribute_pb2 as attribute_pb2
import cStringIO
def get_all_op_protos(): def get_all_op_protos():
...@@ -146,64 +145,14 @@ class OpDescCreationMethod(object): ...@@ -146,64 +145,14 @@ class OpDescCreationMethod(object):
return False return False
def get_docstring_from_op_proto(op_proto): class OpInfo(object):
""" def __init__(self, name, method, inputs, outputs, attrs, no_temp_outputs):
Generate docstring from a OpProto self.name = name
:param op_proto: a OpProto instance. self.method = method
:type op_proto: op_proto_pb2.OpProto self.inputs = inputs
:return: docstring self.outputs = outputs
""" self.attrs = attrs
if not isinstance(op_proto, op_proto_pb2.OpProto): self.no_temp_outputs = no_temp_outputs
raise TypeError("Input must be OpProto")
f = cStringIO.StringIO()
f.write(op_proto.comment)
f.write("\n")
def __append_param__(name, comment, type):
# Maybe replace the following line with template engine is better.
f.write(":param ")
f.write(name)
f.write(": ")
f.write(comment)
f.write("\n")
f.write(":type ")
f.write(name)
f.write(": ")
f.write(type)
f.write("\n")
for ipt in op_proto.inputs:
__append_param__(ipt.name, ipt.comment, "list | basestr"
if ipt.multiple else "basestr")
temp_var_prefix = \
"This is a temporary variable. It does not have to set by user. "
for opt in op_proto.outputs:
__append_param__(opt.name, opt.comment if not opt.temporary else
temp_var_prefix + opt.comment, "list | basestr"
if opt.multiple else "basestr")
for attr in op_proto.attrs:
attr_type = None
if attr.type == attribute_pb2.INT:
attr_type = "int"
elif attr.type == attribute_pb2.FLOAT:
attr_type = "float"
elif attr.type == attribute_pb2.STRING:
attr_type = "basestr"
elif attr.type == attribute_pb2.INTS:
attr_type = "list of int"
elif attr.type == attribute_pb2.FLOATS:
attr_type = "list of float"
elif attr.type == attribute_pb2.STRINGS:
attr_type = "list of basestr"
if attr_type is None:
raise RuntimeError("Not supported attribute type " + attr.type)
__append_param__(attr.name, attr.comment, attr_type)
return f.getvalue()
def create_op_creation_method(op_proto): def create_op_creation_method(op_proto):
...@@ -216,38 +165,57 @@ def create_op_creation_method(op_proto): ...@@ -216,38 +165,57 @@ def create_op_creation_method(op_proto):
opdesc = method(*args, **kwargs) opdesc = method(*args, **kwargs)
return core.Operator.create(opdesc.SerializeToString()) return core.Operator.create(opdesc.SerializeToString())
__impl__.__doc__ = get_docstring_from_op_proto(op_proto) return OpInfo(
__impl__.all_input_args = [var.name for var in op_proto.inputs] method=__impl__,
__impl__.all_output_args = [var.name for var in op_proto.outputs] name=op_proto.type,
__impl__.all_attr_args = [attr.name for attr in op_proto.attrs] inputs=[var.name for var in op_proto.inputs],
__impl__.all_not_temp_output_args = [ outputs=[var.name for var in op_proto.outputs],
var.name for var in op_proto.outputs if not var.temporary attrs=[attr.name for attr in op_proto.attrs],
] no_temp_outputs=[
var.name for var in op_proto.outputs if not var.temporary
])
return __impl__
class OperatorFactory(object):
def __init__(self):
self.op_methods = dict()
for op_proto in get_all_op_protos():
method = create_op_creation_method(op_proto)
self.op_methods[method.name] = method
class OpCreationsHolder(object): def __call__(self, *args, **kwargs):
""" if 'type' in kwargs:
A object will holds all op creation methods. if len(args) != 0:
raise ValueError("All Paddle argument should be key-word "
Use `op_creations.xxx_op` to access them. "argument except type")
""" t = kwargs.pop('type')
pass else:
if len(args) != 1:
raise ValueError("All Paddle argument should be key-word "
"argument except type")
t = args[0]
return self.get_op_info(t).method(**kwargs)
op_creations = OpCreationsHolder() def types(self):
return self.op_methods.keys()
def get_op_info(self, t):
if t not in self.op_methods:
raise ValueError("operator %s is not registered", t)
return self.op_methods.get(t)
def __bootstrap__(): def get_op_input_names(self, type):
""" return self.get_op_info(type).inputs
Bootstrap function for this module. It will dynamic create all op creation
methods in runtime. def get_op_output_names(self, type):
""" return self.get_op_info(type).outputs
for op_proto in get_all_op_protos():
func = create_op_creation_method(op_proto) def get_op_attr_names(self, type):
func.__name__ = str(op_proto.type) return self.get_op_info(type).attrs
setattr(op_creations, func.__name__, func)
def get_op_no_temp_output_names(self, type):
return self.get_op_info(type).no_temp_outputs
__bootstrap__() Operator = OperatorFactory() # Default global factory
...@@ -6,7 +6,6 @@ py_test(test_scope SRCS test_scope.py) ...@@ -6,7 +6,6 @@ py_test(test_scope SRCS test_scope.py)
py_test(test_tensor SRCS test_tensor.py) py_test(test_tensor SRCS test_tensor.py)
py_test(test_mul_op SRCS test_mul_op.py) py_test(test_mul_op SRCS test_mul_op.py)
py_test(test_network SRCS test_network.py)
py_test(test_mean_op SRCS test_mean_op.py) py_test(test_mean_op SRCS test_mean_op.py)
py_test(test_protobuf SRCS test_protobuf.py) py_test(test_protobuf SRCS test_protobuf.py)
...@@ -20,4 +19,4 @@ py_test(gradient_checker SRCS gradient_checker.py) ...@@ -20,4 +19,4 @@ py_test(gradient_checker SRCS gradient_checker.py)
py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py)
py_test(test_op_creation_methods SRCS test_op_creation_methods.py) py_test(test_operator SRCS test_operator.py)
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations from paddle.v2.framework.op import Operator
import numpy import numpy
import unittest import unittest
...@@ -80,7 +80,7 @@ if __name__ == '__main__': ...@@ -80,7 +80,7 @@ if __name__ == '__main__':
class GetNumericGradientTest(unittest.TestCase): class GetNumericGradientTest(unittest.TestCase):
def test_add_op(self): def test_add_op(self):
add_op = op_creations.add_two(X="X", Y="Y", Out="Z") add_op = Operator('add_two', X="X", Y="Y", Out="Z")
x = numpy.random.random((10, 1)).astype("float32") x = numpy.random.random((10, 1)).astype("float32")
y = numpy.random.random((10, 1)).astype("float32") y = numpy.random.random((10, 1)).astype("float32")
......
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
import unittest import unittest
import numpy import numpy
import paddle.v2.framework.create_op_creation_methods as creation from paddle.v2.framework.op import Operator
class OpTestMeta(type): class OpTestMeta(type):
...@@ -21,18 +21,14 @@ class OpTestMeta(type): ...@@ -21,18 +21,14 @@ class OpTestMeta(type):
obj = super(OpTestMeta, cls).__new__(cls, name, bases, attrs) obj = super(OpTestMeta, cls).__new__(cls, name, bases, attrs)
def test_all(self): def test_all(self):
func = getattr(creation.op_creations, self.type, None)
self.assertIsNotNone(func)
scope = core.Scope() scope = core.Scope()
kwargs = dict() kwargs = dict()
places = [] places = [core.CPUPlace()]
places.append(core.CPUPlace())
if core.is_compile_gpu() and core.Operator.support_gpu(self.type): if core.is_compile_gpu() and core.Operator.support_gpu(self.type):
places.append(core.GPUPlace(0)) places.append(core.GPUPlace(0))
for place in places: for place in places:
for in_name in func.all_input_args: for in_name in Operator.get_op_input_names(self.type):
if hasattr(self, "inputs") and in_name in self.inputs: if hasattr(self, "inputs") and in_name in self.inputs:
kwargs[in_name] = in_name kwargs[in_name] = in_name
var = scope.new_var(in_name).get_tensor() var = scope.new_var(in_name).get_tensor()
...@@ -42,7 +38,7 @@ class OpTestMeta(type): ...@@ -42,7 +38,7 @@ class OpTestMeta(type):
else: else:
kwargs[in_name] = "@EMPTY@" kwargs[in_name] = "@EMPTY@"
for out_name in func.all_output_args: for out_name in Operator.get_op_output_names(self.type):
if not hasattr(self, "outputs"): if not hasattr(self, "outputs"):
raise ValueError( raise ValueError(
"The test op must set self.outputs dict.") "The test op must set self.outputs dict.")
...@@ -52,18 +48,18 @@ class OpTestMeta(type): ...@@ -52,18 +48,18 @@ class OpTestMeta(type):
kwargs[out_name] = out_name kwargs[out_name] = out_name
scope.new_var(out_name).get_tensor() scope.new_var(out_name).get_tensor()
for attr_name in func.all_attr_args: for attr_name in Operator.get_op_attr_names(self.type):
if hasattr(self, "attrs") and attr_name in self.attrs: if hasattr(self, "attrs") and attr_name in self.attrs:
kwargs[attr_name] = self.attrs[attr_name] kwargs[attr_name] = self.attrs[attr_name]
op = func(**kwargs) op = Operator(self.type, **kwargs)
op.infer_shape(scope) op.infer_shape(scope)
ctx = core.DeviceContext.create(place) ctx = core.DeviceContext.create(place)
op.run(scope, ctx) op.run(scope, ctx)
for out_name in func.all_output_args: for out_name in Operator.get_op_output_names(self.type):
actual = numpy.array(scope.find_var(out_name).get_tensor()) actual = numpy.array(scope.find_var(out_name).get_tensor())
expect = self.outputs[out_name] expect = self.outputs[out_name]
self.assertTrue( self.assertTrue(
......
...@@ -2,7 +2,7 @@ import unittest ...@@ -2,7 +2,7 @@ import unittest
import numpy import numpy
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
import paddle.v2.framework.create_op_creation_methods as creation from paddle.v2.framework.op import Operator
from op_test_util import OpTestMeta from op_test_util import OpTestMeta
...@@ -21,7 +21,7 @@ class TestAddOp(unittest.TestCase): ...@@ -21,7 +21,7 @@ class TestAddOp(unittest.TestCase):
class TestAddGradOp(unittest.TestCase): class TestAddGradOp(unittest.TestCase):
def test_add_grad(self): def test_add_grad(self):
op = creation.op_creations.add_two(X="X", Y="Y", Out="Out") op = Operator('add_two', X="X", Y="Y", Out="Out")
backward_op = core.Operator.backward(op, set()) backward_op = core.Operator.backward(op, set())
self.assertEqual(backward_op.type(), "add_two_grad") self.assertEqual(backward_op.type(), "add_two_grad")
expected = '''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).''' expected = '''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).'''
......
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
import unittest import unittest
import numpy import numpy
import paddle.v2.framework.create_op_creation_methods as creation from paddle.v2.framework.op import Operator
class TestFc(unittest.TestCase): class TestFc(unittest.TestCase):
...@@ -24,7 +24,7 @@ class TestFc(unittest.TestCase): ...@@ -24,7 +24,7 @@ class TestFc(unittest.TestCase):
# Set a real numpy array here. # Set a real numpy array here.
# x_tensor.set(numpy.array([])) # x_tensor.set(numpy.array([]))
op = creation.op_creations.fc(X="X", Y="Y", W="W") op = Operator("fc", X="X", Y="Y", W="W")
for out in op.outputs(): for out in op.outputs():
if scope.find_var(out) is None: if scope.find_var(out) is None:
......
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations from paddle.v2.framework.op import Operator
import unittest import unittest
class TestNet(unittest.TestCase): class TestNet(unittest.TestCase):
def test_net_all(self): def test_net_all(self):
net = core.Net.create() net = core.Net.create()
op1 = op_creations.add_two(X="X", Y="Y", Out="Out") op1 = Operator("add_two", X="X", Y="Y", Out="Out")
net.add_op(op1) net.add_op(op1)
net2 = core.Net.create() net2 = core.Net.create()
net2.add_op(op_creations.fc(X="X", W="w", Y="fc.out")) net2.add_op(Operator("fc", X="X", W="w", Y="fc.out"))
net2.complete_add_op(True) net2.complete_add_op(True)
net.add_op(net2) net.add_op(net2)
net.complete_add_op(True) net.complete_add_op(True)
......
from paddle.v2.framework.network import Network
import paddle.v2.framework.core as core
import unittest
class TestNet(unittest.TestCase):
def test_net_all(self):
net = Network()
out = net.add_two(X="X", Y="Y")
fc_out = net.fc(X=out, W="w")
net.complete_add_op()
self.assertTrue(isinstance(fc_out, core.Variable))
self.assertEqual(
'''Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@0).
Op(fc), inputs:(add_two@OUT@0, w, @EMPTY@), outputs:(fc@OUT@1, @TEMP@fc@0).
Op(mul), inputs:(add_two@OUT@0, w), outputs:(@TEMP@fc@0).
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc@OUT@1).
''', str(net))
net2 = Network()
tmp = net2.add_two(X="X", Y="Y")
self.assertTrue(isinstance(tmp, core.Variable))
net2.complete_add_op()
self.assertEqual(
'''Op(plain_net), inputs:(X, Y), outputs:(add_two@OUT@2).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@2).
''', str(net2))
if __name__ == '__main__':
unittest.main()
import unittest import unittest
import paddle.v2.framework.create_op_creation_methods as creation import paddle.v2.framework.op as op
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2 import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2 import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
...@@ -8,7 +8,7 @@ import paddle.v2.framework.proto.attribute_pb2 as attribute_pb2 ...@@ -8,7 +8,7 @@ import paddle.v2.framework.proto.attribute_pb2 as attribute_pb2
class TestGetAllProtos(unittest.TestCase): class TestGetAllProtos(unittest.TestCase):
def test_all(self): def test_all(self):
all_protos = creation.get_all_op_protos() all_protos = op.get_all_op_protos()
self.assertNotEqual(0, len(all_protos)) self.assertNotEqual(0, len(all_protos))
for each in all_protos: for each in all_protos:
...@@ -17,25 +17,25 @@ class TestGetAllProtos(unittest.TestCase): ...@@ -17,25 +17,25 @@ class TestGetAllProtos(unittest.TestCase):
class TestOpDescCreationMethod(unittest.TestCase): class TestOpDescCreationMethod(unittest.TestCase):
def test_plain_input_output(self): def test_plain_input_output(self):
op = op_proto_pb2.OpProto() op_proto = op_proto_pb2.OpProto()
op.type = "test" op_proto.type = "test"
ipt = op.inputs.add() ipt = op_proto.inputs.add()
ipt.name = "X" ipt.name = "X"
ipt.comment = "not matter" ipt.comment = "not matter"
ipt = op.inputs.add() ipt = op_proto.inputs.add()
ipt.name = "Y" ipt.name = "Y"
ipt.comment = "not matter" ipt.comment = "not matter"
opt = op.outputs.add() opt = op_proto.outputs.add()
opt.name = "Z" opt.name = "Z"
opt.comment = "not matter" opt.comment = "not matter"
op.comment = "not matter" op_proto.comment = "not matter"
self.assertTrue(op.IsInitialized()) self.assertTrue(op_proto.IsInitialized())
method = creation.OpDescCreationMethod(op) method = op.OpDescCreationMethod(op_proto)
output = method(X="a", Y="b", Z="c") output = method(X="a", Y="b", Z="c")
expected = op_desc_pb2.OpDesc() expected = op_desc_pb2.OpDesc()
...@@ -45,29 +45,29 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -45,29 +45,29 @@ class TestOpDescCreationMethod(unittest.TestCase):
self.assertEqual(expected, output) self.assertEqual(expected, output)
def test_multiple_input_plain_output(self): def test_multiple_input_plain_output(self):
op = op_proto_pb2.OpProto() op_proto = op_proto_pb2.OpProto()
op.type = "fc" op_proto.type = "fc"
ipt = op.inputs.add() ipt = op_proto.inputs.add()
ipt.name = "X" ipt.name = "X"
ipt.comment = "" ipt.comment = ""
ipt.multiple = True ipt.multiple = True
ipt = op.inputs.add() ipt = op_proto.inputs.add()
ipt.name = "W" ipt.name = "W"
ipt.comment = "" ipt.comment = ""
ipt.multiple = True ipt.multiple = True
ipt = op.inputs.add() ipt = op_proto.inputs.add()
ipt.name = "b" ipt.name = "b"
ipt.comment = "" ipt.comment = ""
out = op.outputs.add() out = op_proto.outputs.add()
out.name = "Y" out.name = "Y"
out.comment = "" out.comment = ""
op.comment = "" op_proto.comment = ""
self.assertTrue(op.IsInitialized()) self.assertTrue(op_proto.IsInitialized())
method = creation.OpDescCreationMethod(op) method = op.OpDescCreationMethod(op_proto)
generated1 = method(X="x", W="w", b="b", Y="y") generated1 = method(X="x", W="w", b="b", Y="y")
expected1 = op_desc_pb2.OpDesc() expected1 = op_desc_pb2.OpDesc()
...@@ -93,14 +93,14 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -93,14 +93,14 @@ class TestOpDescCreationMethod(unittest.TestCase):
self.assertEqual(expected2, generated2) self.assertEqual(expected2, generated2)
def test_attrs(self): def test_attrs(self):
op = op_proto_pb2.OpProto() op_proto = op_proto_pb2.OpProto()
op.type = "test" op_proto.type = "test"
ipt = op.inputs.add() ipt = op_proto.inputs.add()
ipt.name = 'X' ipt.name = 'X'
ipt.comment = "" ipt.comment = ""
def __add_attr__(name, type): def __add_attr__(name, type):
attr = op.attrs.add() attr = op_proto.attrs.add()
attr.name = name attr.name = name
attr.comment = "" attr.comment = ""
attr.type = type attr.type = type
...@@ -112,10 +112,10 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -112,10 +112,10 @@ class TestOpDescCreationMethod(unittest.TestCase):
__add_attr__("floats_attr", attribute_pb2.FLOATS) __add_attr__("floats_attr", attribute_pb2.FLOATS)
__add_attr__("strings_attr", attribute_pb2.STRINGS) __add_attr__("strings_attr", attribute_pb2.STRINGS)
op.comment = "" op_proto.comment = ""
self.assertTrue(op.IsInitialized()) self.assertTrue(op_proto.IsInitialized())
method = creation.OpDescCreationMethod(op) method = op.OpDescCreationMethod(op_proto)
generated = method( generated = method(
X="a", X="a",
...@@ -162,23 +162,23 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -162,23 +162,23 @@ class TestOpDescCreationMethod(unittest.TestCase):
self.assertEqual(expected, generated) self.assertEqual(expected, generated)
def test_input_temporary_output(self): def test_input_temporary_output(self):
op = op_proto_pb2.OpProto() op_proto = op_proto_pb2.OpProto()
op.type = "test" op_proto.type = "test"
out = op.outputs.add() out = op_proto.outputs.add()
out.name = "OUT" out.name = "OUT"
out.comment = "" out.comment = ""
out = op.outputs.add() out = op_proto.outputs.add()
out.name = "TMP" out.name = "TMP"
out.comment = "" out.comment = ""
out.temporary = True out.temporary = True
out = op.outputs.add() out = op_proto.outputs.add()
out.name = "OUT2" out.name = "OUT2"
out.comment = "" out.comment = ""
op.comment = "" op_proto.comment = ""
method = creation.OpDescCreationMethod(op) method = op.OpDescCreationMethod(op_proto)
generated = method(OUT="a", OUT2="b") generated = method(OUT="a", OUT2="b")
desc = op_desc_pb2.OpDesc() desc = op_desc_pb2.OpDesc()
desc.outputs.extend(["a", core.var_names.temp(), "b"]) desc.outputs.extend(["a", core.var_names.temp(), "b"])
...@@ -190,60 +190,9 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -190,60 +190,9 @@ class TestOpDescCreationMethod(unittest.TestCase):
self.assertEqual(generated, desc) self.assertEqual(generated, desc)
class TestOpCreationDocStr(unittest.TestCase):
def test_all(self):
op = op_proto_pb2.OpProto()
op.type = "test"
op.comment = """Test Op.
This op is used for unit test, not a real op.
"""
a = op.inputs.add()
a.name = "a"
a.comment = "Input a for test op"
a.multiple = True
b = op.inputs.add()
b.name = "b"
b.comment = "Input b for test op"
self.assertTrue(op.IsInitialized())
o1 = op.outputs.add()
o1.name = "output"
o1.comment = "The output of test op"
o2 = op.outputs.add()
o2.name = "temp output"
o2.comment = "The temporary output of test op"
o2.temporary = True
test_str = op.attrs.add()
test_str.name = "str_attr"
test_str.type = attribute_pb2.STRING
test_str.comment = "A string attribute for test op"
actual = creation.get_docstring_from_op_proto(op)
expected_docstring = '''Test Op.
This op is used for unit test, not a real op.
:param a: Input a for test op
:type a: list | basestr
:param b: Input b for test op
:type b: basestr
:param output: The output of test op
:type output: basestr
:param temp output: This is a temporary variable. It does not have to set by user. The temporary output of test op
:type temp output: basestr
:param str_attr: A string attribute for test op
:type str_attr: basestr
'''
self.assertEqual(expected_docstring, actual)
class TestOpCreations(unittest.TestCase): class TestOpCreations(unittest.TestCase):
def test_all(self): def test_all(self):
add_op = creation.op_creations.add_two(X="a", Y="b", Out="z") add_op = op.Operator("add_two", X="a", Y="b", Out="z")
self.assertIsNotNone(add_op) self.assertIsNotNone(add_op)
# Invoke C++ DebugString() # Invoke C++ DebugString()
self.assertEqual('Op(add_two), inputs:(a, b), outputs:(z).', self.assertEqual('Op(add_two), inputs:(a, b), outputs:(z).',
......
...@@ -2,7 +2,7 @@ import unittest ...@@ -2,7 +2,7 @@ import unittest
import numpy as np import numpy as np
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
import paddle.v2.framework.create_op_creation_methods as creation from paddle.v2.framework.op import Operator
from op_test_util import OpTestMeta from op_test_util import OpTestMeta
...@@ -27,7 +27,7 @@ class TestSoftmaxOp(unittest.TestCase): ...@@ -27,7 +27,7 @@ class TestSoftmaxOp(unittest.TestCase):
class TestSoftmaxGradOp(unittest.TestCase): class TestSoftmaxGradOp(unittest.TestCase):
def test_softmax_grad(self): def test_softmax_grad(self):
op = creation.op_creations.softmax(X="X", Y="Y") op = Operator('softmax', X="X", Y="Y")
backward_op = core.Operator.backward(op, set()) backward_op = core.Operator.backward(op, set())
self.assertEqual(backward_op.type(), "softmax_grad") self.assertEqual(backward_op.type(), "softmax_grad")
expected = '''Op(softmax_grad), inputs:(X, Y, Y@GRAD), outputs:(X@GRAD).''' expected = '''Op(softmax_grad), inputs:(X, Y, Y@GRAD), outputs:(X@GRAD).'''
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册