提交 89d33ff8 编写于 作者: Y Yu Yang

Complete chagne op creation method.

Currently use `Operator("fc", X="x", W='w1', B='b1')` as operator
creation method.

Fix #3198
上级 82f8304f
import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
from default_scope_funcs import new_var, find_var, get_cur_scope
__all__ = ['Network'] # Only expose Network
class NetworkFunctor(object):
"""
Network Op Creation Function. Used internally in this module.
It convert string input to Variable. If it is not created before, just
create in scope.
It is a functor object. means the instances are callable.
:param func: The op creation function which generated in Python.
:param net: The Network instance.
"""
def __init__(self, func, net):
self.func = func
self.net = net
def __call__(self, *args, **kwargs):
if len(args) != 0:
raise ValueError("Paddle must use keyword argument")
inputs = self.func.all_input_args
for ipt in inputs:
if ipt in kwargs:
var = kwargs[ipt]
if isinstance(var, basestring):
tmp = new_var(var)
self.net.var_names[tmp] = var
var = tmp
if not isinstance(var, core.Variable):
raise TypeError(
"Input of op creation must be string or variable")
kwargs[ipt] = self.net.var_names[var]
notemp_outputs = self.func.all_not_temp_output_args
for name in notemp_outputs:
if name not in kwargs:
kwargs[
name] = self.func.__name__ + "@OUT@%d" % core.unique_integer(
)
outputs = self.func.all_output_args
for opt in outputs:
if opt in kwargs:
var = kwargs[opt]
if isinstance(var, basestring):
tmp = new_var(var)
self.net.var_names[tmp] = var
var = tmp
if not isinstance(var, core.Variable):
raise TypeError(
"Output of op creation must be string or variable")
kwargs[opt] = self.net.var_names[var]
op = self.func(**kwargs)
self.net.net.add_op(op)
lst = [find_var(kwargs[opt]) for opt in notemp_outputs]
if len(lst) == 1:
return lst[0]
elif len(lst) == 0:
return None
else:
return lst
class Network(object):
"""
The network concept. It avoid user to manually create operator, create
variable, and combine them into a Net. Just use Network.xxx can create the
operator, create variables in default scope, and add them into `self.net`.
For example:
.. code-block: python
net = Network()
out = net.add_two(X="a", Y="b")
fc_out = net.fc(X="out", W="fc.w")
net.run(...)
"""
def __init__(self):
self.net = core.Net.create()
funcs = (func_name for func_name in dir(op_creations)
if not func_name.startswith("__"))
self.var_names = dict()
# TODO(yuyang18): This code can work, but do not generate a good
# docstring, try to give a better way generate function in runtime
# later.
for func_name in funcs:
func = getattr(op_creations, func_name)
impl = NetworkFunctor(func, self)
setattr(self, func_name, impl.__call__)
self.__complete_add_op__ = False
def infer_shape(self):
self.complete_add_op()
self.net.infer_shape(get_cur_scope())
def run(self, device_context):
self.complete_add_op()
self.net.run(get_cur_scope(), device_context)
def __str__(self):
return str(self.net)
def complete_add_op(self):
if not self.__complete_add_op__:
self.net.complete_add_op()
self.__complete_add_op__ = True
if __name__ == '__main__':
net = Network()
out = net.add_two(X="a", Y="b")
fc_out = net.fc(X=out, W="fc.w", b="fc.b", activation="softmax")
net.complete_add_op()
print net
......@@ -187,6 +187,9 @@ class OperatorFactory(object):
return self.get_op_creation_info(t)['method'](**kwargs)
def types(self):
return self.op_methods.keys()
def get_op_creation_info(self, t):
if t not in self.op_methods:
raise ValueError("operator %s is not registered", t)
......
add_python_test(test_framework
test_protobuf.py
test_scope.py
test_operator.py
test_default_scope_funcs.py
test_op_creation_methods.py
test_net.py
test_tensor.py
test_fc_op.py
......@@ -13,5 +13,5 @@ add_python_test(test_framework
test_sigmoid_op.py
test_softmax_op.py
test_rowwise_add_op.py
test_network.py
gradient_checker.py)
gradient_checker.py
)
import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
from paddle.v2.framework.op import Operator
import numpy
import unittest
......@@ -80,7 +80,7 @@ if __name__ == '__main__':
class GetNumericGradientTest(unittest.TestCase):
def test_add_op(self):
add_op = op_creations.add_two(X="X", Y="Y", Out="Z")
add_op = Operator('add_two', X="X", Y="Y", Out="Z")
x = numpy.random.random((10, 1)).astype("float32")
y = numpy.random.random((10, 1)).astype("float32")
......
import paddle.v2.framework.core as core
import unittest
import numpy
import paddle.v2.framework.create_op_creation_methods as creation
from paddle.v2.framework.op import Operator
class OpTestMeta(type):
......@@ -21,18 +21,14 @@ class OpTestMeta(type):
obj = super(OpTestMeta, cls).__new__(cls, name, bases, attrs)
def test_all(self):
func = getattr(creation.op_creations, self.type, None)
self.assertIsNotNone(func)
scope = core.Scope()
kwargs = dict()
places = []
places.append(core.CPUPlace())
places = [core.CPUPlace()]
if core.is_compile_gpu():
places.append(core.GPUPlace(0))
for place in places:
for in_name in func.all_input_args:
for in_name in Operator.get_op_input_names(self.type):
if hasattr(self, in_name):
kwargs[in_name] = in_name
var = scope.new_var(in_name).get_tensor()
......@@ -42,23 +38,23 @@ class OpTestMeta(type):
else:
kwargs[in_name] = "@EMPTY@"
for out_name in func.all_output_args:
for out_name in Operator.get_op_output_names(self.type):
if hasattr(self, out_name):
kwargs[out_name] = out_name
scope.new_var(out_name).get_tensor()
for attr_name in func.all_attr_args:
for attr_name in Operator.get_op_attr_names(self.type):
if hasattr(self, attr_name):
kwargs[attr_name] = getattr(self, attr_name)
op = func(**kwargs)
op = Operator(self.type, **kwargs)
op.infer_shape(scope)
ctx = core.DeviceContext.create(place)
op.run(scope, ctx)
for out_name in func.all_output_args:
for out_name in Operator.get_op_output_names(self.type):
actual = numpy.array(scope.find_var(out_name).get_tensor())
expect = getattr(self, out_name)
# TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul
......
......@@ -2,7 +2,7 @@ import unittest
import numpy
import paddle.v2.framework.core as core
import paddle.v2.framework.create_op_creation_methods as creation
from paddle.v2.framework.op import Operator
from op_test_util import OpTestMeta
......@@ -19,7 +19,7 @@ class TestAddOp(unittest.TestCase):
class TestAddGradOp(unittest.TestCase):
def test_add_grad(self):
op = creation.op_creations.add_two(X="X", Y="Y", Out="Out")
op = Operator('add_two', X="X", Y="Y", Out="Out")
backward_op = core.Operator.backward(op, set())
self.assertEqual(backward_op.type(), "add_two_grad")
expected = '''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).'''
......
import paddle.v2.framework.core as core
import unittest
import numpy
import paddle.v2.framework.create_op_creation_methods as creation
from paddle.v2.framework.op import Operator
class TestFc(unittest.TestCase):
......@@ -24,7 +24,7 @@ class TestFc(unittest.TestCase):
# Set a real numpy array here.
# x_tensor.set(numpy.array([]))
op = creation.op_creations.fc(X="X", Y="Y", W="W")
op = Operator("fc", X="X", Y="Y", W="W")
for out in op.outputs():
if scope.find_var(out) is None:
......
import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
from paddle.v2.framework.op import Operator
import unittest
class TestNet(unittest.TestCase):
def test_net_all(self):
net = core.Net.create()
op1 = op_creations.add_two(X="X", Y="Y", Out="Out")
op1 = Operator("add_two", X="X", Y="Y", Out="Out")
net.add_op(op1)
net2 = core.Net.create()
net2.add_op(op_creations.fc(X="X", W="w", Y="fc.out"))
net2.add_op(Operator("fc", X="X", W="w", Y="fc.out"))
net2.complete_add_op(True)
net.add_op(net2)
net.complete_add_op(True)
......
from paddle.v2.framework.network import Network
import paddle.v2.framework.core as core
import unittest
class TestNet(unittest.TestCase):
def test_net_all(self):
net = Network()
out = net.add_two(X="X", Y="Y")
fc_out = net.fc(X=out, W="w")
net.complete_add_op()
self.assertTrue(isinstance(fc_out, core.Variable))
self.assertEqual(
'''Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@0).
Op(fc), inputs:(add_two@OUT@0, w, @EMPTY@), outputs:(fc@OUT@1, @TEMP@fc@0).
Op(mul), inputs:(add_two@OUT@0, w), outputs:(@TEMP@fc@0).
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc@OUT@1).
''', str(net))
net2 = Network()
tmp = net2.add_two(X="X", Y="Y")
self.assertTrue(isinstance(tmp, core.Variable))
net2.complete_add_op()
self.assertEqual(
'''Op(plain_net), inputs:(X, Y), outputs:(add_two@OUT@2).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@2).
''', str(net2))
if __name__ == '__main__':
unittest.main()
......@@ -2,7 +2,7 @@ import unittest
import numpy as np
import paddle.v2.framework.core as core
import paddle.v2.framework.create_op_creation_methods as creation
from paddle.v2.framework.op import Operator
from op_test_util import OpTestMeta
......@@ -25,7 +25,7 @@ class TestSoftmaxOp(unittest.TestCase):
class TestSoftmaxGradOp(unittest.TestCase):
def test_softmax_grad(self):
op = creation.op_creations.softmax(X="X", Y="Y")
op = Operator('softmax', X="X", Y="Y")
backward_op = core.Operator.backward(op, set())
self.assertEqual(backward_op.type(), "softmax_grad")
expected = '''Op(softmax_grad), inputs:(X, Y, Y@GRAD), outputs:(X@GRAD).'''
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册