提交 102a5f34 编写于 作者: Y Yu Yang 提交者: GitHub

Feature/remove global scope (#4950)

* Unify `set_feed_variable` to one method

* Move global scope to python, not in C++
上级 9903e49f
...@@ -21,12 +21,12 @@ limitations under the License. */ ...@@ -21,12 +21,12 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void SetFeedVariable(const LoDTensor& input, const std::string& var_name, void SetFeedVariable(Scope* scope, const LoDTensor& input,
size_t index) { const std::string& var_name, size_t index) {
// If var_name Variable is not found in GlobalScope, a new variable will // If var_name Variable is not found in GlobalScope, a new variable will
// be created. // be created.
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index; VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
Variable* g_feed_value = GetGlobalScope().Var(var_name); Variable* g_feed_value = scope->Var(var_name);
auto& feed_inputs = auto& feed_inputs =
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>()); *(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
if (index >= feed_inputs.size()) { if (index >= feed_inputs.size()) {
...@@ -38,10 +38,11 @@ void SetFeedVariable(const LoDTensor& input, const std::string& var_name, ...@@ -38,10 +38,11 @@ void SetFeedVariable(const LoDTensor& input, const std::string& var_name,
feed_inputs[index].set_lod(input.lod()); feed_inputs[index].set_lod(input.lod());
} }
LoDTensor& GetFetchVariable(const std::string& var_name, size_t index) { LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
size_t index) {
// Since we want to fetch LodTensor from a variable, the variable must // Since we want to fetch LodTensor from a variable, the variable must
// be created alreadly. // be created alreadly.
Variable* g_fetch_value = GetGlobalScope().FindVar(var_name); Variable* g_fetch_value = scope.FindVar(var_name);
PADDLE_ENFORCE(g_fetch_value->IsType<FeedFetchList>(), PADDLE_ENFORCE(g_fetch_value->IsType<FeedFetchList>(),
"Only %s can be invoked by GetFetchVariable", "Only %s can be invoked by GetFetchVariable",
typeid(FeedFetchList).name()); typeid(FeedFetchList).name());
......
...@@ -72,13 +72,5 @@ void Scope::DeleteScope(Scope* scope) { ...@@ -72,13 +72,5 @@ void Scope::DeleteScope(Scope* scope) {
delete scope; delete scope;
} }
framework::Scope& GetGlobalScope() {
static framework::Scope* g_scope = nullptr;
if (g_scope == nullptr) {
g_scope = new framework::Scope();
}
return *g_scope;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -74,8 +74,5 @@ class Scope { ...@@ -74,8 +74,5 @@ class Scope {
DISABLE_COPY_AND_ASSIGN(Scope); DISABLE_COPY_AND_ASSIGN(Scope);
}; };
framework::Scope& GetGlobalScope();
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -219,8 +219,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -219,8 +219,7 @@ All parameter, weight, gradient are variables in Paddle.
.def(py::init<>()) .def(py::init<>())
.def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); }, .def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); },
py::return_value_policy::reference) py::return_value_policy::reference)
.def("drop_kids", &Scope::DropKids) .def("drop_kids", &Scope::DropKids);
.def_static("global_scope", &GetGlobalScope);
//! @note: Be careful! PyBind will return std::string as an unicode, not //! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python. //! Python str. If you want a str object, you should cast them in Python.
...@@ -451,10 +450,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -451,10 +450,9 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<framework::Executor>(m, "Executor") py::class_<framework::Executor>(m, "Executor")
.def(py::init<std::vector<platform::Place> &>()) .def(py::init<std::vector<platform::Place> &>())
.def("run", .def("run", [](Executor &self, ProgramDescBind *program_bind,
[](Executor &self, ProgramDescBind *program_bind, int block_id) { Scope *scope, int block_id) {
framework::Scope &global_scope = GetGlobalScope(); self.Run(*program_bind->Proto(), scope, block_id);
self.Run(*program_bind->Proto(), &global_scope, block_id);
}); });
m.def("unique_integer", UniqueIntegerGenerator); m.def("unique_integer", UniqueIntegerGenerator);
......
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
from paddle.v2.framework.framework import Block, Program from paddle.v2.framework.framework import Block, Program
g_scope = core.Scope()
class Executor(object): class Executor(object):
def __init__(self, places): def __init__(self, places):
...@@ -20,10 +22,14 @@ class Executor(object): ...@@ -20,10 +22,14 @@ class Executor(object):
feed, feed,
fetch_list, fetch_list,
feed_var_name='feed', feed_var_name='feed',
fetch_var_name='fetch'): fetch_var_name='fetch',
scope=None):
if not isinstance(program, Program): if not isinstance(program, Program):
raise TypeError() raise TypeError()
if scope is None:
scope = g_scope
program = program.clone() program = program.clone()
global_block = program.global_block() global_block = program.global_block()
feed_var = global_block.create_var( feed_var = global_block.create_var(
...@@ -38,7 +44,7 @@ class Executor(object): ...@@ -38,7 +44,7 @@ class Executor(object):
inputs={'X': [feed_var]}, inputs={'X': [feed_var]},
outputs={'Out': [out]}, outputs={'Out': [out]},
attrs={'col': i}) attrs={'col': i})
core.set_feed_variable(feed[name], feed_var.name, i) core.set_feed_variable(scope, feed[name], feed_var.name, i)
fetch_var = global_block.create_var( fetch_var = global_block.create_var(
name=fetch_var_name, name=fetch_var_name,
...@@ -51,8 +57,8 @@ class Executor(object): ...@@ -51,8 +57,8 @@ class Executor(object):
outputs={'Out': [fetch_var]}, outputs={'Out': [fetch_var]},
attrs={'col': i}) attrs={'col': i})
self.executor.run(program.desc, 0) self.executor.run(program.desc, scope, 0)
return [ return [
core.get_fetch_variable(fetch_var_name, i) core.get_fetch_variable(scope, fetch_var_name, i)
for i in xrange(len(fetch_list)) for i in xrange(len(fetch_list))
] ]
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
class TestFeedFetch(unittest.TestCase): class TestFeedFetch(unittest.TestCase):
def test_feed_fetch(self): def test_feed_fetch(self):
scope = core.Scope()
place = core.CPUPlace() place = core.CPUPlace()
input_array = np.ones((4, 4, 6)).astype("float32") input_array = np.ones((4, 4, 6)).astype("float32")
input_array[0, 0, 0] = 3 input_array[0, 0, 0] = 3
...@@ -12,9 +13,9 @@ class TestFeedFetch(unittest.TestCase): ...@@ -12,9 +13,9 @@ class TestFeedFetch(unittest.TestCase):
input_tensor = core.LoDTensor([[0, 2, 4]]) input_tensor = core.LoDTensor([[0, 2, 4]])
input_tensor.set(input_array, place) input_tensor.set(input_array, place)
core.set_feed_variable(input_tensor, "feed", 0) core.set_feed_variable(scope, input_tensor, "feed", 0)
output_tensor = core.get_fetch_variable("feed", 0) output_tensor = core.get_fetch_variable(scope, "feed", 0)
output_lod = output_tensor.lod() output_lod = output_tensor.lod()
self.assertEqual(0, output_lod[0][0]) self.assertEqual(0, output_lod[0][0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册