diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 62539c10760d1deabed46e2c388b3f9efc7cfb2f..43c52957a1e3d21866a0d68e40cb0f764c53b937 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -56,17 +56,6 @@ void ExposeOperator(ClassType& m) { .def("__str__", &ClassType::type::DebugString); } -template -void ExposeOperator(ClassType& m) { - m.def("infer_shape", &ClassType::type::InferShape) - .def("run", &ClassType::type::Run) - .def("outputs", - [](const typename ClassType::type& op) -> std::vector { - return op.outputs_; - }) - .def("__str__", &ClassType::type::DebugString); -} - PYBIND11_PLUGIN(core) { py::module m("core", "C++ core of PaddlePaddle"); diff --git a/python/paddle/v2/framework/network.py b/python/paddle/v2/framework/network.py index bde48851a0959c720424ae9e4f53515046820232..ade7e4b85e67c1a8509b9b7c3bb421aac0fe36de 100644 --- a/python/paddle/v2/framework/network.py +++ b/python/paddle/v2/framework/network.py @@ -2,13 +2,28 @@ import paddle.v2.framework.core as core from paddle.v2.framework.create_op_creation_methods import op_creations from default_scope_funcs import create_var, get_var, get_cur_scope +__all__ = ['Network'] # Only expose Network + class NetworkFunctor(object): + """ + Network Op Creation Function. Used internally in this module. + It convert string input to Variable. If it is not created before, just + create in scope. + + It is a functor object. means the instances are callable. + + :param func: The op creation function which generated in Python. + :param net: The Network instance. + """ + def __init__(self, func, net): self.func = func self.net = net - def __call__(self, **kwargs): + def __call__(self, *args, **kwargs): + if len(args) != 0: + raise ValueError("Paddle must use keyword argument") inputs = self.func.all_input_args for ipt in inputs: if ipt in kwargs: @@ -58,6 +73,22 @@ class NetworkFunctor(object): class Network(object): + """ + The network concept. It avoid user to manually create operator, create + variable, and combine them into a Net. Just use Network.xxx can create the + operator, create variables in default scope, and add them into `self.net`. + + For example: + + .. code-block: python + + net = Network() + out = net.add_two(X="a", Y="b") + fc_out = net.fc(X="out", W="fc.w") + + net.run(...) + """ + def __init__(self): self.net = core.Net.create() funcs = (func_name for func_name in dir(op_creations) @@ -65,6 +96,9 @@ class Network(object): self.generate_idx = 0 self.var_name_map = dict() + # TODO(yuyang18): This code can work, but do not generate a good + # docstring, try to give a better way generate function in runtime + # later. for func_name in funcs: func = getattr(op_creations, func_name) impl = NetworkFunctor(func, self) @@ -92,5 +126,5 @@ if __name__ == '__main__': net = Network() out = net.add_two(X="a", Y="b") fc_out = net.fc(X=out, W="fc.w", b="fc.b", activation="softmax") - - print str(net) + net.complete_add_op() + print net diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index b3eb2ef8a8966318fe33ca8b3032a4120c73909f..7d1229a34c50548db69c5a720f0bf333bd0dd58d 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -12,4 +12,5 @@ add_python_test(test_framework test_mul_op.py test_sigmoid_op.py test_softmax_op.py - test_rowwise_add_op.py) + test_rowwise_add_op.py + test_network.py) diff --git a/python/paddle/v2/framework/tests/test_network.py b/python/paddle/v2/framework/tests/test_network.py new file mode 100644 index 0000000000000000000000000000000000000000..310290e34b0d5056de6eb719a84352078057e419 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_network.py @@ -0,0 +1,23 @@ +from paddle.v2.framework.network import Network +import paddle.v2.framework.core as core +import unittest + + +class TestNet(unittest.TestCase): + def test_net_all(self): + net = Network() + out = net.add_two(X="X", Y="Y") + fc_out = net.fc(X=out, W="w") + net.complete_add_op() + self.assertTrue(isinstance(fc_out, core.Variable)) + self.assertEqual( + '''Op(naive_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1). + Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@0). + Op(fc), inputs:(add_two@OUT@0, w, @EMPTY@), outputs:(fc@OUT@1, @TEMP@fc@0). + Op(mul), inputs:(add_two@OUT@0, w), outputs:(@TEMP@fc@0). + Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc@OUT@1). +''', str(net)) + + +if __name__ == '__main__': + unittest.main()