From 5c530ea827ff87ce901721d2c5a3c2dd3971e7e0 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 19 Dec 2017 17:30:09 +0800 Subject: [PATCH] export const value to python --- paddle/pybind/CMakeLists.txt | 2 +- paddle/pybind/const_value.cc | 29 +++++++++++++++++++ paddle/pybind/const_value.h | 26 +++++++++++++++++ paddle/pybind/pybind.cc | 2 ++ python/paddle/v2/fluid/framework.py | 18 ++++++++++-- .../v2/fluid/tests/test_batch_norm_op.py | 1 + .../paddle/v2/fluid/tests/test_const_value.py | 11 +++++++ python/paddle/v2/fluid/tests/test_operator.py | 2 +- python/paddle/v2/fluid/tests/test_program.py | 8 ++--- .../v2/fluid/tests/test_recurrent_op.py | 4 +-- 10 files changed, 91 insertions(+), 12 deletions(-) create mode 100644 paddle/pybind/const_value.cc create mode 100644 paddle/pybind/const_value.h create mode 100644 python/paddle/v2/fluid/tests/test_const_value.py diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 1fb69de90d..6afed7eec7 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,6 +1,6 @@ if(WITH_PYTHON) cc_library(paddle_pybind SHARED - SRCS pybind.cc exception.cc protobuf.cc + SRCS pybind.cc exception.cc protobuf.cc const_value.cc DEPS pybind python backward proto_desc paddle_memory executor prune init ${GLOB_OP_LIB}) endif(WITH_PYTHON) diff --git a/paddle/pybind/const_value.cc b/paddle/pybind/const_value.cc new file mode 100644 index 0000000000..b13ad42ea2 --- /dev/null +++ b/paddle/pybind/const_value.cc @@ -0,0 +1,29 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "const_value.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace pybind { + +void BindConstValue(pybind11::module& m) { + m.def("kEmptyVarName", [] { return framework::kEmptyVarName; }); + m.def("kTempVarName", [] { return framework::kTempVarName; }); + m.def("kGradVarSuffix", [] { return framework::kGradVarSuffix; }); + m.def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; }); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/pybind/const_value.h b/paddle/pybind/const_value.h new file mode 100644 index 0000000000..3d57c972a9 --- /dev/null +++ b/paddle/pybind/const_value.h @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/platform/enforce.h" +#include "pybind11/pybind11.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { +extern void BindConstValue(pybind11::module& m); +} // namespace pybind +} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 4248db34c6..4a82f1596e 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -30,6 +30,7 @@ limitations under the License. */ #include "paddle/operators/net_op.h" #include "paddle/platform/enforce.h" #include "paddle/platform/place.h" +#include "paddle/pybind/const_value.h" #include "paddle/pybind/exception.h" #include "paddle/pybind/pybind.h" #include "paddle/pybind/tensor_py.h" @@ -431,6 +432,7 @@ All parameter, weight, gradient are variables in Paddle. BindBlockDesc(m); BindVarDsec(m); BindOpDesc(m); + BindConstValue(m); py::class_(m, "LodRankTable") .def("items", [](framework::LoDRankTable &table) { diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index bf0cd275b6..8deb6aaf7a 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -1,10 +1,10 @@ import collections +import contextlib import numpy as np -from . import core + import proto.framework_pb2 as framework_pb2 -import google.protobuf.message -import contextlib +from . import core __all__ = [ 'Block', 'Variable', 'Program', 'Operator', 'default_startup_program', @@ -12,6 +12,18 @@ __all__ = [ 'switch_main_program' ] +EMPTY_VAR_NAME = core.kEmptyVarName() +TEMP_VAR_NAME = core.kTempVarName() +GRAD_VAR_SUFFIX = core.kGradVarSuffix() +ZERO_VAR_SUFFIX = core.kZeroVarSuffix() + + +def grad_var_name(var_name): + """ + return gradient name for a certain var name + """ + return var_name + GRAD_VAR_SUFFIX + def unique_name(prefix): """ diff --git a/python/paddle/v2/fluid/tests/test_batch_norm_op.py b/python/paddle/v2/fluid/tests/test_batch_norm_op.py index e766a68c0e..a0c2da113e 100644 --- a/python/paddle/v2/fluid/tests/test_batch_norm_op.py +++ b/python/paddle/v2/fluid/tests/test_batch_norm_op.py @@ -3,6 +3,7 @@ import numpy as np from op_test import OpTest import paddle.v2.fluid.core as core from paddle.v2.fluid.op import Operator +from paddle.v2.fluid.framewor import grad_var_name def grad_var_name(var_name): diff --git a/python/paddle/v2/fluid/tests/test_const_value.py b/python/paddle/v2/fluid/tests/test_const_value.py new file mode 100644 index 0000000000..fd034f55e7 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_const_value.py @@ -0,0 +1,11 @@ +import unittest +import paddle.v2.fluid.framework as framework + + +class ConditionalBlock(unittest.TestCase): + def test_const_value(self): + self.assertEqual(framework.GRAD_VAR_SUFFIX, "@GRAD") + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_operator.py b/python/paddle/v2/fluid/tests/test_operator.py index 4aa022ef90..c059a2b88b 100644 --- a/python/paddle/v2/fluid/tests/test_operator.py +++ b/python/paddle/v2/fluid/tests/test_operator.py @@ -1,6 +1,6 @@ import unittest + import paddle.v2.fluid.op as op -import paddle.v2.fluid.core as core import paddle.v2.fluid.proto.framework_pb2 as framework_pb2 diff --git a/python/paddle/v2/fluid/tests/test_program.py b/python/paddle/v2/fluid/tests/test_program.py index e6da0b2be7..447c746aac 100644 --- a/python/paddle/v2/fluid/tests/test_program.py +++ b/python/paddle/v2/fluid/tests/test_program.py @@ -1,7 +1,7 @@ from __future__ import print_function import unittest -from paddle.v2.fluid.framework import Program, default_main_program, program_guard +from paddle.v2.fluid.framework import Program, default_main_program, program_guard, grad_var_name import paddle.v2.fluid.layers as layers main_program = default_main_program() @@ -109,12 +109,10 @@ class TestProgram(unittest.TestCase): self.assertEqual(add_op.idx, 1) param_to_grad = prog.append_backward(mean_out, set()) - def grad_name(name): - return name + "@GRAD" - for var_name in ("mul.x", "mul.y", "mul.out", "add.y", "add.out", "mean.out"): - self.assertEqual(param_to_grad[var_name][0], grad_name(var_name)) + self.assertEqual(param_to_grad[var_name][0], + grad_var_name(var_name)) self.assertEqual(param_to_grad[var_name][1], 0) expect_ops = [ diff --git a/python/paddle/v2/fluid/tests/test_recurrent_op.py b/python/paddle/v2/fluid/tests/test_recurrent_op.py index 694ff0d8dd..e38c763ddb 100644 --- a/python/paddle/v2/fluid/tests/test_recurrent_op.py +++ b/python/paddle/v2/fluid/tests/test_recurrent_op.py @@ -1,7 +1,7 @@ import unittest import paddle.v2.fluid.layers as layers -from paddle.v2.fluid.framework import Program +from paddle.v2.fluid.framework import Program, grad_var_name from paddle.v2.fluid.executor import Executor from paddle.v2.fluid.backward import append_backward_ops import numpy as np @@ -164,7 +164,7 @@ class RecurrentOpTest1(unittest.TestCase): for x in self.data_field } fetch_list = [ - self.main_program.global_block().var(x + "@GRAD") + self.main_program.global_block().var(grad_var_name(x)) for x in self.data_field ] -- GitLab