From 102a5f349926539c256afca54108241cc5e313c6 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 19 Oct 2017 19:43:31 -0700 Subject: [PATCH] Feature/remove global scope (#4950) * Unify `set_feed_variable` to one method * Move global scope to python, not in C++ --- paddle/framework/feed_fetch_method.h | 11 ++++++----- paddle/framework/scope.cc | 8 -------- paddle/framework/scope.h | 3 --- paddle/pybind/pybind.cc | 12 +++++------- python/paddle/v2/framework/executor.py | 14 ++++++++++---- .../v2/framework/tests/test_feed_fetch_method.py | 5 +++-- 6 files changed, 24 insertions(+), 29 deletions(-) diff --git a/paddle/framework/feed_fetch_method.h b/paddle/framework/feed_fetch_method.h index 3ef70043d6a..7feacb1e247 100644 --- a/paddle/framework/feed_fetch_method.h +++ b/paddle/framework/feed_fetch_method.h @@ -21,12 +21,12 @@ limitations under the License. */ namespace paddle { namespace framework { -void SetFeedVariable(const LoDTensor& input, const std::string& var_name, - size_t index) { +void SetFeedVariable(Scope* scope, const LoDTensor& input, + const std::string& var_name, size_t index) { // If var_name Variable is not found in GlobalScope, a new variable will // be created. VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index; - Variable* g_feed_value = GetGlobalScope().Var(var_name); + Variable* g_feed_value = scope->Var(var_name); auto& feed_inputs = *(g_feed_value->GetMutable>()); if (index >= feed_inputs.size()) { @@ -38,10 +38,11 @@ void SetFeedVariable(const LoDTensor& input, const std::string& var_name, feed_inputs[index].set_lod(input.lod()); } -LoDTensor& GetFetchVariable(const std::string& var_name, size_t index) { +LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name, + size_t index) { // Since we want to fetch LodTensor from a variable, the variable must // be created alreadly. - Variable* g_fetch_value = GetGlobalScope().FindVar(var_name); + Variable* g_fetch_value = scope.FindVar(var_name); PADDLE_ENFORCE(g_fetch_value->IsType(), "Only %s can be invoked by GetFetchVariable", typeid(FeedFetchList).name()); diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index b8e116c430d..ac3ac649f96 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -72,13 +72,5 @@ void Scope::DeleteScope(Scope* scope) { delete scope; } -framework::Scope& GetGlobalScope() { - static framework::Scope* g_scope = nullptr; - if (g_scope == nullptr) { - g_scope = new framework::Scope(); - } - return *g_scope; -} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index 78ff136ee17..7206b53068b 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -74,8 +74,5 @@ class Scope { DISABLE_COPY_AND_ASSIGN(Scope); }; - -framework::Scope& GetGlobalScope(); - } // namespace framework } // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 94c9706f794..9ef47b88fd0 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -219,8 +219,7 @@ All parameter, weight, gradient are variables in Paddle. .def(py::init<>()) .def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); }, py::return_value_policy::reference) - .def("drop_kids", &Scope::DropKids) - .def_static("global_scope", &GetGlobalScope); + .def("drop_kids", &Scope::DropKids); //! @note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. @@ -451,11 +450,10 @@ All parameter, weight, gradient are variables in Paddle. py::class_(m, "Executor") .def(py::init &>()) - .def("run", - [](Executor &self, ProgramDescBind *program_bind, int block_id) { - framework::Scope &global_scope = GetGlobalScope(); - self.Run(*program_bind->Proto(), &global_scope, block_id); - }); + .def("run", [](Executor &self, ProgramDescBind *program_bind, + Scope *scope, int block_id) { + self.Run(*program_bind->Proto(), scope, block_id); + }); m.def("unique_integer", UniqueIntegerGenerator); diff --git a/python/paddle/v2/framework/executor.py b/python/paddle/v2/framework/executor.py index 1adc10c233e..82b83d4bb6a 100644 --- a/python/paddle/v2/framework/executor.py +++ b/python/paddle/v2/framework/executor.py @@ -1,6 +1,8 @@ import paddle.v2.framework.core as core from paddle.v2.framework.framework import Block, Program +g_scope = core.Scope() + class Executor(object): def __init__(self, places): @@ -20,10 +22,14 @@ class Executor(object): feed, fetch_list, feed_var_name='feed', - fetch_var_name='fetch'): + fetch_var_name='fetch', + scope=None): if not isinstance(program, Program): raise TypeError() + if scope is None: + scope = g_scope + program = program.clone() global_block = program.global_block() feed_var = global_block.create_var( @@ -38,7 +44,7 @@ class Executor(object): inputs={'X': [feed_var]}, outputs={'Out': [out]}, attrs={'col': i}) - core.set_feed_variable(feed[name], feed_var.name, i) + core.set_feed_variable(scope, feed[name], feed_var.name, i) fetch_var = global_block.create_var( name=fetch_var_name, @@ -51,8 +57,8 @@ class Executor(object): outputs={'Out': [fetch_var]}, attrs={'col': i}) - self.executor.run(program.desc, 0) + self.executor.run(program.desc, scope, 0) return [ - core.get_fetch_variable(fetch_var_name, i) + core.get_fetch_variable(scope, fetch_var_name, i) for i in xrange(len(fetch_list)) ] diff --git a/python/paddle/v2/framework/tests/test_feed_fetch_method.py b/python/paddle/v2/framework/tests/test_feed_fetch_method.py index 8b9b44440d6..fbd659ece01 100644 --- a/python/paddle/v2/framework/tests/test_feed_fetch_method.py +++ b/python/paddle/v2/framework/tests/test_feed_fetch_method.py @@ -5,6 +5,7 @@ import numpy as np class TestFeedFetch(unittest.TestCase): def test_feed_fetch(self): + scope = core.Scope() place = core.CPUPlace() input_array = np.ones((4, 4, 6)).astype("float32") input_array[0, 0, 0] = 3 @@ -12,9 +13,9 @@ class TestFeedFetch(unittest.TestCase): input_tensor = core.LoDTensor([[0, 2, 4]]) input_tensor.set(input_array, place) - core.set_feed_variable(input_tensor, "feed", 0) + core.set_feed_variable(scope, input_tensor, "feed", 0) - output_tensor = core.get_fetch_variable("feed", 0) + output_tensor = core.get_fetch_variable(scope, "feed", 0) output_lod = output_tensor.lod() self.assertEqual(0, output_lod[0][0]) -- GitLab