From b9f8ae849451622573f02c3fc6fa54cdfc00924b Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 25 Nov 2019 10:44:38 +0800 Subject: [PATCH] Add global value getter setter (#21285) * add global value getter setter, test=develop * fix error messages, test=develop --- paddle/fluid/pybind/CMakeLists.txt | 1 + .../pybind/global_value_getter_setter.cc | 168 ++++++++++++++++++ .../fluid/pybind/global_value_getter_setter.h | 25 +++ paddle/fluid/pybind/pybind.cc | 6 +- .../test_mkldnn_int8_quantization_strategy.py | 2 +- .../paddle/fluid/tests/unittests/op_test.py | 6 +- .../test_global_var_getter_setter.py | 62 +++++++ 7 files changed, 262 insertions(+), 8 deletions(-) create mode 100644 paddle/fluid/pybind/global_value_getter_setter.cc create mode 100644 paddle/fluid/pybind/global_value_getter_setter.h create mode 100644 python/paddle/fluid/tests/unittests/test_global_var_getter_setter.py diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 9563e7b6fe..836f15e18a 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -15,6 +15,7 @@ set(PYBIND_SRCS exception.cc protobuf.cc const_value.cc + global_value_getter_setter.cc reader_py.cc fleet_wrapper_py.cc box_helper_py.cc diff --git a/paddle/fluid/pybind/global_value_getter_setter.cc b/paddle/fluid/pybind/global_value_getter_setter.cc new file mode 100644 index 0000000000..108764f9bf --- /dev/null +++ b/paddle/fluid/pybind/global_value_getter_setter.cc @@ -0,0 +1,168 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pybind/global_value_getter_setter.h" +#include +#include +#include +#include +#include +#include +#include "gflags/gflags.h" +#include "paddle/fluid/framework/python_headers.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" +#include "paddle/fluid/platform/macros.h" +#include "pybind11/stl.h" + +DECLARE_double(eager_delete_tensor_gb); +DECLARE_bool(use_mkldnn); +DECLARE_bool(use_ngraph); + +namespace paddle { +namespace pybind { + +namespace py = pybind11; + +class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { + DISABLE_COPY_AND_ASSIGN(GlobalVarGetterSetterRegistry); + + GlobalVarGetterSetterRegistry() = default; + + public: + using Setter = std::function; + using Getter = std::function; + + static const GlobalVarGetterSetterRegistry &Instance() { return instance_; } + + static GlobalVarGetterSetterRegistry *MutableInstance() { return &instance_; } + + void RegisterGetter(const std::string &name, Getter func) { + PADDLE_ENFORCE_EQ( + getters_.count(name), 0, + platform::errors::AlreadyExists( + "Getter of global variable %s has been registered", name)); + PADDLE_ENFORCE_NOT_NULL(func, platform::errors::InvalidArgument( + "Getter of %s should not be null", name)); + getters_[name] = std::move(func); + } + + void RegisterSetter(const std::string &name, Setter func) { + PADDLE_ENFORCE_EQ( + HasGetterMethod(name), true, + platform::errors::NotFound( + "Cannot register setter for %s before register getter", name)); + + PADDLE_ENFORCE_EQ( + setters_.count(name), 0, + platform::errors::AlreadyExists( + "Setter of global variable %s has been registered", name)); + PADDLE_ENFORCE_NOT_NULL(func, platform::errors::InvalidArgument( + "Setter of %s should not be null", name)); + setters_[name] = std::move(func); + } + + const Getter &GetterMethod(const std::string &name) const { + PADDLE_ENFORCE_EQ( + HasGetterMethod(name), true, + platform::errors::NotFound("Cannot find global variable %s", name)); + return getters_.at(name); + } + + py::object GetOrReturnDefaultValue(const std::string &name, + const py::object &default_value) { + if (HasGetterMethod(name)) { + return GetterMethod(name)(); + } else { + return default_value; + } + } + + py::object Get(const std::string &name) { return GetterMethod(name)(); } + + const Setter &SetterMethod(const std::string &name) const { + PADDLE_ENFORCE_EQ( + HasSetterMethod(name), true, + platform::errors::NotFound("Global variable %s is not writable", name)); + return setters_.at(name); + } + + void Set(const std::string &name, const py::object &value) const { + SetterMethod(name)(value); + } + + bool HasGetterMethod(const std::string &name) const { + return getters_.count(name) > 0; + } + + bool HasSetterMethod(const std::string &name) const { + return setters_.count(name) > 0; + } + + std::unordered_set Keys() const { + std::unordered_set keys; + keys.reserve(getters_.size()); + for (auto &pair : getters_) { + keys.insert(pair.first); + } + return keys; + } + + private: + static GlobalVarGetterSetterRegistry instance_; + + std::unordered_map getters_; + std::unordered_map setters_; +}; + +GlobalVarGetterSetterRegistry GlobalVarGetterSetterRegistry::instance_; + +static void RegisterGlobalVarGetterSetter(); + +void BindGlobalValueGetterSetter(pybind11::module *module) { + RegisterGlobalVarGetterSetter(); + + py::class_(*module, + "GlobalVarGetterSetterRegistry") + .def("__getitem__", &GlobalVarGetterSetterRegistry::Get) + .def("__setitem__", &GlobalVarGetterSetterRegistry::Set) + .def("__contains__", &GlobalVarGetterSetterRegistry::HasGetterMethod) + .def("keys", &GlobalVarGetterSetterRegistry::Keys) + .def("get", &GlobalVarGetterSetterRegistry::GetOrReturnDefaultValue, + py::arg("key"), py::arg("default") = py::cast(Py_None)); + + module->def("globals", &GlobalVarGetterSetterRegistry::Instance, + py::return_value_policy::reference); +} + +#define REGISTER_GLOBAL_VAR_GETTER_ONLY(var) \ + GlobalVarGetterSetterRegistry::MutableInstance()->RegisterGetter( \ + #var, []() -> py::object { return py::cast(var); }) + +#define REGISTER_GLOBAL_VAR_SETTER_ONLY(var) \ + GlobalVarGetterSetterRegistry::MutableInstance()->RegisterSetter( \ + #var, [](const py::object &obj) { var = py::cast(obj); }) + +#define REGISTER_GLOBAL_VAR_GETTER_SETTER(var) \ + REGISTER_GLOBAL_VAR_GETTER_ONLY(var); \ + REGISTER_GLOBAL_VAR_SETTER_ONLY(var) + +static void RegisterGlobalVarGetterSetter() { + REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_use_mkldnn); + REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_use_ngraph); + REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_eager_delete_tensor_gb); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/global_value_getter_setter.h b/paddle/fluid/pybind/global_value_getter_setter.h new file mode 100644 index 0000000000..86a27f7286 --- /dev/null +++ b/paddle/fluid/pybind/global_value_getter_setter.h @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "pybind11/pybind11.h" + +namespace paddle { +namespace pybind { + +void BindGlobalValueGetterSetter(pybind11::module *module); + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 17ee63349a..86635f15fc 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -61,6 +61,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/data_set_py.h" #include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/fleet_wrapper_py.h" +#include "paddle/fluid/pybind/global_value_getter_setter.h" #include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/inference_api.h" #include "paddle/fluid/pybind/ir.h" @@ -1049,10 +1050,6 @@ All parameter, weight, gradient are variables in Paddle. m.def("has_infer_inplace", [](const std::string op_type) { return framework::OpInfoMap::Instance().Get(op_type).HasInferInplace(); }); - m.def("get_flags_use_mkldnn", []() { return FLAGS_use_mkldnn; }); -#ifdef PADDLE_WITH_NGRAPH - m.def("get_flags_use_ngraph", []() { return FLAGS_use_ngraph; }); -#endif m.def("prune", [](const ProgramDesc &origin, const std::set &feeded_var_names, @@ -1405,6 +1402,7 @@ All parameter, weight, gradient are variables in Paddle. BindVarDsec(&m); BindOpDesc(&m); BindConstValue(&m); + BindGlobalValueGetterSetter(&m); py::class_(m, "LodRankTable") .def("items", [](framework::LoDRankTable &table) { diff --git a/python/paddle/fluid/contrib/slim/tests/test_mkldnn_int8_quantization_strategy.py b/python/paddle/fluid/contrib/slim/tests/test_mkldnn_int8_quantization_strategy.py index d41ea34907..600880d792 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_mkldnn_int8_quantization_strategy.py +++ b/python/paddle/fluid/contrib/slim/tests/test_mkldnn_int8_quantization_strategy.py @@ -162,7 +162,7 @@ class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase): fetch_targets] = fluid.io.load_inference_model( model_path, exe, 'model', 'params') - use_mkldnn = fluid.core.get_flags_use_mkldnn() + use_mkldnn = fluid.core.globals()["FLAGS_use_mkldnn"] if (use_mkldnn): graph = IrGraph( core.Graph(inference_program.desc), for_test=True) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 7a626984d5..723fd48f28 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -755,7 +755,7 @@ class OpTest(unittest.TestCase): else: # TODO(zhiqiu): enhance inplace_grad test for ops (sum and activation) using mkldnn/ngraph # skip op that use_mkldnn and use_ngraph currently - flags_use_mkldnn = fluid.core.get_flags_use_mkldnn() + flags_use_mkldnn = fluid.core.globals()["FLAGS_use_mkldnn"] attrs_use_mkldnn = hasattr( self, 'attrs') and bool(self.attrs.get('use_mkldnn', False)) @@ -765,7 +765,7 @@ class OpTest(unittest.TestCase): ) continue use_ngraph = fluid.core.is_compiled_with_ngraph( - ) and fluid.core.get_flags_use_ngraph() + ) and fluid.core.globals()["FLAGS_use_ngraph"] if use_ngraph: warnings.warn( "check inplace_grad for ops using ngraph is not supported" @@ -977,7 +977,7 @@ class OpTest(unittest.TestCase): places = [fluid.CPUPlace()] cpu_only = self._cpu_only if hasattr(self, '_cpu_only') else False use_ngraph = fluid.core.is_compiled_with_ngraph( - ) and fluid.core.get_flags_use_ngraph() + ) and fluid.core.globals()['FLAGS_use_ngraph'] if use_ngraph: cpu_only = True if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type)\ diff --git a/python/paddle/fluid/tests/unittests/test_global_var_getter_setter.py b/python/paddle/fluid/tests/unittests/test_global_var_getter_setter.py new file mode 100644 index 0000000000..f6ad1b3082 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_global_var_getter_setter.py @@ -0,0 +1,62 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import unittest as unittest + + +class VarInfo(object): + def __init__(self, var_name, var_type, writable): + self.name = var_name + self.type = var_type + self.writable = writable + + +class TestGlobalVarGetterSetter(unittest.TestCase): + def test_main(self): + var_infos = [ + VarInfo("FLAGS_use_mkldnn", bool, False), + VarInfo("FLAGS_use_ngraph", bool, False), + VarInfo("FLAGS_eager_delete_tensor_gb", float, True), + ] + + g = fluid.core.globals() + for var in var_infos: + self.assertTrue(var.name in g) + self.assertTrue(var.name in g.keys()) + value1 = g[var.name] + value2 = g.get(var.name, None) + self.assertTrue(value1 is not None) + self.assertEqual(value1, value2) + self.assertEqual(type(value1), var.type) + self.assertEqual(type(value2), var.type) + + if var.writable: + g[var.name] = -1 + else: + try: + g[var.name] = False + self.assertTrue(False) + except: + self.assertTrue(True) + + name = "__any_non_exist_name__" + self.assertFalse(name in g) + self.assertFalse(name in g.keys()) + self.assertTrue(g.get(name, None) is None) + self.assertEquals(g.get(name, -1), -1) + + +if __name__ == '__main__': + unittest.main() -- GitLab