提交 0ab678e9 编写于 作者: Y Yu Yang

Add unittest for network

上级 c14f3e8f
...@@ -56,17 +56,6 @@ void ExposeOperator(ClassType& m) { ...@@ -56,17 +56,6 @@ void ExposeOperator(ClassType& m) {
.def("__str__", &ClassType::type::DebugString); .def("__str__", &ClassType::type::DebugString);
} }
template <typename ClassType>
void ExposeOperator(ClassType& m) {
m.def("infer_shape", &ClassType::type::InferShape)
.def("run", &ClassType::type::Run)
.def("outputs",
[](const typename ClassType::type& op) -> std::vector<std::string> {
return op.outputs_;
})
.def("__str__", &ClassType::type::DebugString);
}
PYBIND11_PLUGIN(core) { PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of PaddlePaddle"); py::module m("core", "C++ core of PaddlePaddle");
......
...@@ -2,13 +2,28 @@ import paddle.v2.framework.core as core ...@@ -2,13 +2,28 @@ import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations from paddle.v2.framework.create_op_creation_methods import op_creations
from default_scope_funcs import create_var, get_var, get_cur_scope from default_scope_funcs import create_var, get_var, get_cur_scope
__all__ = ['Network'] # Only expose Network
class NetworkFunctor(object): class NetworkFunctor(object):
"""
Network Op Creation Function. Used internally in this module.
It convert string input to Variable. If it is not created before, just
create in scope.
It is a functor object. means the instances are callable.
:param func: The op creation function which generated in Python.
:param net: The Network instance.
"""
def __init__(self, func, net): def __init__(self, func, net):
self.func = func self.func = func
self.net = net self.net = net
def __call__(self, **kwargs): def __call__(self, *args, **kwargs):
if len(args) != 0:
raise ValueError("Paddle must use keyword argument")
inputs = self.func.all_input_args inputs = self.func.all_input_args
for ipt in inputs: for ipt in inputs:
if ipt in kwargs: if ipt in kwargs:
...@@ -58,6 +73,22 @@ class NetworkFunctor(object): ...@@ -58,6 +73,22 @@ class NetworkFunctor(object):
class Network(object): class Network(object):
"""
The network concept. It avoid user to manually create operator, create
variable, and combine them into a Net. Just use Network.xxx can create the
operator, create variables in default scope, and add them into `self.net`.
For example:
.. code-block: python
net = Network()
out = net.add_two(X="a", Y="b")
fc_out = net.fc(X="out", W="fc.w")
net.run(...)
"""
def __init__(self): def __init__(self):
self.net = core.Net.create() self.net = core.Net.create()
funcs = (func_name for func_name in dir(op_creations) funcs = (func_name for func_name in dir(op_creations)
...@@ -65,6 +96,9 @@ class Network(object): ...@@ -65,6 +96,9 @@ class Network(object):
self.generate_idx = 0 self.generate_idx = 0
self.var_name_map = dict() self.var_name_map = dict()
# TODO(yuyang18): This code can work, but do not generate a good
# docstring, try to give a better way generate function in runtime
# later.
for func_name in funcs: for func_name in funcs:
func = getattr(op_creations, func_name) func = getattr(op_creations, func_name)
impl = NetworkFunctor(func, self) impl = NetworkFunctor(func, self)
...@@ -92,5 +126,5 @@ if __name__ == '__main__': ...@@ -92,5 +126,5 @@ if __name__ == '__main__':
net = Network() net = Network()
out = net.add_two(X="a", Y="b") out = net.add_two(X="a", Y="b")
fc_out = net.fc(X=out, W="fc.w", b="fc.b", activation="softmax") fc_out = net.fc(X=out, W="fc.w", b="fc.b", activation="softmax")
net.complete_add_op()
print str(net) print net
...@@ -12,4 +12,5 @@ add_python_test(test_framework ...@@ -12,4 +12,5 @@ add_python_test(test_framework
test_mul_op.py test_mul_op.py
test_sigmoid_op.py test_sigmoid_op.py
test_softmax_op.py test_softmax_op.py
test_rowwise_add_op.py) test_rowwise_add_op.py
test_network.py)
from paddle.v2.framework.network import Network
import paddle.v2.framework.core as core
import unittest
class TestNet(unittest.TestCase):
def test_net_all(self):
net = Network()
out = net.add_two(X="X", Y="Y")
fc_out = net.fc(X=out, W="w")
net.complete_add_op()
self.assertTrue(isinstance(fc_out, core.Variable))
self.assertEqual(
'''Op(naive_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@0).
Op(fc), inputs:(add_two@OUT@0, w, @EMPTY@), outputs:(fc@OUT@1, @TEMP@fc@0).
Op(mul), inputs:(add_two@OUT@0, w), outputs:(@TEMP@fc@0).
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc@OUT@1).
''', str(net))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册