未验证 提交 b9f8ae84 编写于 作者: Z Zeng Jinle 提交者: GitHub

Add global value getter setter (#21285)

* add global value getter setter, test=develop

* fix error messages, test=develop
上级 b19e1a1b
......@@ -15,6 +15,7 @@ set(PYBIND_SRCS
exception.cc
protobuf.cc
const_value.cc
global_value_getter_setter.cc
reader_py.cc
fleet_wrapper_py.cc
box_helper_py.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/pybind/global_value_getter_setter.h"
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
#include "pybind11/stl.h"
DECLARE_double(eager_delete_tensor_gb);
DECLARE_bool(use_mkldnn);
DECLARE_bool(use_ngraph);
namespace paddle {
namespace pybind {
namespace py = pybind11;
class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
DISABLE_COPY_AND_ASSIGN(GlobalVarGetterSetterRegistry);
GlobalVarGetterSetterRegistry() = default;
public:
using Setter = std::function<void(const py::object &)>;
using Getter = std::function<py::object()>;
static const GlobalVarGetterSetterRegistry &Instance() { return instance_; }
static GlobalVarGetterSetterRegistry *MutableInstance() { return &instance_; }
void RegisterGetter(const std::string &name, Getter func) {
PADDLE_ENFORCE_EQ(
getters_.count(name), 0,
platform::errors::AlreadyExists(
"Getter of global variable %s has been registered", name));
PADDLE_ENFORCE_NOT_NULL(func, platform::errors::InvalidArgument(
"Getter of %s should not be null", name));
getters_[name] = std::move(func);
}
void RegisterSetter(const std::string &name, Setter func) {
PADDLE_ENFORCE_EQ(
HasGetterMethod(name), true,
platform::errors::NotFound(
"Cannot register setter for %s before register getter", name));
PADDLE_ENFORCE_EQ(
setters_.count(name), 0,
platform::errors::AlreadyExists(
"Setter of global variable %s has been registered", name));
PADDLE_ENFORCE_NOT_NULL(func, platform::errors::InvalidArgument(
"Setter of %s should not be null", name));
setters_[name] = std::move(func);
}
const Getter &GetterMethod(const std::string &name) const {
PADDLE_ENFORCE_EQ(
HasGetterMethod(name), true,
platform::errors::NotFound("Cannot find global variable %s", name));
return getters_.at(name);
}
py::object GetOrReturnDefaultValue(const std::string &name,
const py::object &default_value) {
if (HasGetterMethod(name)) {
return GetterMethod(name)();
} else {
return default_value;
}
}
py::object Get(const std::string &name) { return GetterMethod(name)(); }
const Setter &SetterMethod(const std::string &name) const {
PADDLE_ENFORCE_EQ(
HasSetterMethod(name), true,
platform::errors::NotFound("Global variable %s is not writable", name));
return setters_.at(name);
}
void Set(const std::string &name, const py::object &value) const {
SetterMethod(name)(value);
}
bool HasGetterMethod(const std::string &name) const {
return getters_.count(name) > 0;
}
bool HasSetterMethod(const std::string &name) const {
return setters_.count(name) > 0;
}
std::unordered_set<std::string> Keys() const {
std::unordered_set<std::string> keys;
keys.reserve(getters_.size());
for (auto &pair : getters_) {
keys.insert(pair.first);
}
return keys;
}
private:
static GlobalVarGetterSetterRegistry instance_;
std::unordered_map<std::string, Getter> getters_;
std::unordered_map<std::string, Setter> setters_;
};
GlobalVarGetterSetterRegistry GlobalVarGetterSetterRegistry::instance_;
static void RegisterGlobalVarGetterSetter();
void BindGlobalValueGetterSetter(pybind11::module *module) {
RegisterGlobalVarGetterSetter();
py::class_<GlobalVarGetterSetterRegistry>(*module,
"GlobalVarGetterSetterRegistry")
.def("__getitem__", &GlobalVarGetterSetterRegistry::Get)
.def("__setitem__", &GlobalVarGetterSetterRegistry::Set)
.def("__contains__", &GlobalVarGetterSetterRegistry::HasGetterMethod)
.def("keys", &GlobalVarGetterSetterRegistry::Keys)
.def("get", &GlobalVarGetterSetterRegistry::GetOrReturnDefaultValue,
py::arg("key"), py::arg("default") = py::cast<py::none>(Py_None));
module->def("globals", &GlobalVarGetterSetterRegistry::Instance,
py::return_value_policy::reference);
}
#define REGISTER_GLOBAL_VAR_GETTER_ONLY(var) \
GlobalVarGetterSetterRegistry::MutableInstance()->RegisterGetter( \
#var, []() -> py::object { return py::cast(var); })
#define REGISTER_GLOBAL_VAR_SETTER_ONLY(var) \
GlobalVarGetterSetterRegistry::MutableInstance()->RegisterSetter( \
#var, [](const py::object &obj) { var = py::cast<decltype(var)>(obj); })
#define REGISTER_GLOBAL_VAR_GETTER_SETTER(var) \
REGISTER_GLOBAL_VAR_GETTER_ONLY(var); \
REGISTER_GLOBAL_VAR_SETTER_ONLY(var)
static void RegisterGlobalVarGetterSetter() {
REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_use_mkldnn);
REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_use_ngraph);
REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_eager_delete_tensor_gb);
}
} // namespace pybind
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
namespace paddle {
namespace pybind {
void BindGlobalValueGetterSetter(pybind11::module *module);
} // namespace pybind
} // namespace paddle
......@@ -61,6 +61,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/fleet_wrapper_py.h"
#include "paddle/fluid/pybind/global_value_getter_setter.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h"
#include "paddle/fluid/pybind/ir.h"
......@@ -1049,10 +1050,6 @@ All parameter, weight, gradient are variables in Paddle.
m.def("has_infer_inplace", [](const std::string op_type) {
return framework::OpInfoMap::Instance().Get(op_type).HasInferInplace();
});
m.def("get_flags_use_mkldnn", []() { return FLAGS_use_mkldnn; });
#ifdef PADDLE_WITH_NGRAPH
m.def("get_flags_use_ngraph", []() { return FLAGS_use_ngraph; });
#endif
m.def("prune", [](const ProgramDesc &origin,
const std::set<std::string> &feeded_var_names,
......@@ -1405,6 +1402,7 @@ All parameter, weight, gradient are variables in Paddle.
BindVarDsec(&m);
BindOpDesc(&m);
BindConstValue(&m);
BindGlobalValueGetterSetter(&m);
py::class_<framework::LoDRankTable>(m, "LodRankTable")
.def("items", [](framework::LoDRankTable &table) {
......
......@@ -162,7 +162,7 @@ class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase):
fetch_targets] = fluid.io.load_inference_model(
model_path, exe, 'model', 'params')
use_mkldnn = fluid.core.get_flags_use_mkldnn()
use_mkldnn = fluid.core.globals()["FLAGS_use_mkldnn"]
if (use_mkldnn):
graph = IrGraph(
core.Graph(inference_program.desc), for_test=True)
......
......@@ -755,7 +755,7 @@ class OpTest(unittest.TestCase):
else:
# TODO(zhiqiu): enhance inplace_grad test for ops (sum and activation) using mkldnn/ngraph
# skip op that use_mkldnn and use_ngraph currently
flags_use_mkldnn = fluid.core.get_flags_use_mkldnn()
flags_use_mkldnn = fluid.core.globals()["FLAGS_use_mkldnn"]
attrs_use_mkldnn = hasattr(
self,
'attrs') and bool(self.attrs.get('use_mkldnn', False))
......@@ -765,7 +765,7 @@ class OpTest(unittest.TestCase):
)
continue
use_ngraph = fluid.core.is_compiled_with_ngraph(
) and fluid.core.get_flags_use_ngraph()
) and fluid.core.globals()["FLAGS_use_ngraph"]
if use_ngraph:
warnings.warn(
"check inplace_grad for ops using ngraph is not supported"
......@@ -977,7 +977,7 @@ class OpTest(unittest.TestCase):
places = [fluid.CPUPlace()]
cpu_only = self._cpu_only if hasattr(self, '_cpu_only') else False
use_ngraph = fluid.core.is_compiled_with_ngraph(
) and fluid.core.get_flags_use_ngraph()
) and fluid.core.globals()['FLAGS_use_ngraph']
if use_ngraph:
cpu_only = True
if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type)\
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest as unittest
class VarInfo(object):
def __init__(self, var_name, var_type, writable):
self.name = var_name
self.type = var_type
self.writable = writable
class TestGlobalVarGetterSetter(unittest.TestCase):
def test_main(self):
var_infos = [
VarInfo("FLAGS_use_mkldnn", bool, False),
VarInfo("FLAGS_use_ngraph", bool, False),
VarInfo("FLAGS_eager_delete_tensor_gb", float, True),
]
g = fluid.core.globals()
for var in var_infos:
self.assertTrue(var.name in g)
self.assertTrue(var.name in g.keys())
value1 = g[var.name]
value2 = g.get(var.name, None)
self.assertTrue(value1 is not None)
self.assertEqual(value1, value2)
self.assertEqual(type(value1), var.type)
self.assertEqual(type(value2), var.type)
if var.writable:
g[var.name] = -1
else:
try:
g[var.name] = False
self.assertTrue(False)
except:
self.assertTrue(True)
name = "__any_non_exist_name__"
self.assertFalse(name in g)
self.assertFalse(name in g.keys())
self.assertTrue(g.get(name, None) is None)
self.assertEquals(g.get(name, -1), -1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册