提交 102a5f34 编写于 作者: Y Yu Yang 提交者: GitHub

Feature/remove global scope (#4950)

* Unify `set_feed_variable` to one method

* Move global scope to python, not in C++
上级 9903e49f
......@@ -21,12 +21,12 @@ limitations under the License. */
namespace paddle {
namespace framework {
void SetFeedVariable(const LoDTensor& input, const std::string& var_name,
size_t index) {
void SetFeedVariable(Scope* scope, const LoDTensor& input,
const std::string& var_name, size_t index) {
// If var_name Variable is not found in GlobalScope, a new variable will
// be created.
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
Variable* g_feed_value = GetGlobalScope().Var(var_name);
Variable* g_feed_value = scope->Var(var_name);
auto& feed_inputs =
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
if (index >= feed_inputs.size()) {
......@@ -38,10 +38,11 @@ void SetFeedVariable(const LoDTensor& input, const std::string& var_name,
feed_inputs[index].set_lod(input.lod());
}
LoDTensor& GetFetchVariable(const std::string& var_name, size_t index) {
LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
size_t index) {
// Since we want to fetch LodTensor from a variable, the variable must
// be created alreadly.
Variable* g_fetch_value = GetGlobalScope().FindVar(var_name);
Variable* g_fetch_value = scope.FindVar(var_name);
PADDLE_ENFORCE(g_fetch_value->IsType<FeedFetchList>(),
"Only %s can be invoked by GetFetchVariable",
typeid(FeedFetchList).name());
......
......@@ -72,13 +72,5 @@ void Scope::DeleteScope(Scope* scope) {
delete scope;
}
framework::Scope& GetGlobalScope() {
static framework::Scope* g_scope = nullptr;
if (g_scope == nullptr) {
g_scope = new framework::Scope();
}
return *g_scope;
}
} // namespace framework
} // namespace paddle
......@@ -74,8 +74,5 @@ class Scope {
DISABLE_COPY_AND_ASSIGN(Scope);
};
framework::Scope& GetGlobalScope();
} // namespace framework
} // namespace paddle
......@@ -219,8 +219,7 @@ All parameter, weight, gradient are variables in Paddle.
.def(py::init<>())
.def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); },
py::return_value_policy::reference)
.def("drop_kids", &Scope::DropKids)
.def_static("global_scope", &GetGlobalScope);
.def("drop_kids", &Scope::DropKids);
//! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python.
......@@ -451,10 +450,9 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<framework::Executor>(m, "Executor")
.def(py::init<std::vector<platform::Place> &>())
.def("run",
[](Executor &self, ProgramDescBind *program_bind, int block_id) {
framework::Scope &global_scope = GetGlobalScope();
self.Run(*program_bind->Proto(), &global_scope, block_id);
.def("run", [](Executor &self, ProgramDescBind *program_bind,
Scope *scope, int block_id) {
self.Run(*program_bind->Proto(), scope, block_id);
});
m.def("unique_integer", UniqueIntegerGenerator);
......
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import Block, Program
g_scope = core.Scope()
class Executor(object):
def __init__(self, places):
......@@ -20,10 +22,14 @@ class Executor(object):
feed,
fetch_list,
feed_var_name='feed',
fetch_var_name='fetch'):
fetch_var_name='fetch',
scope=None):
if not isinstance(program, Program):
raise TypeError()
if scope is None:
scope = g_scope
program = program.clone()
global_block = program.global_block()
feed_var = global_block.create_var(
......@@ -38,7 +44,7 @@ class Executor(object):
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
core.set_feed_variable(feed[name], feed_var.name, i)
core.set_feed_variable(scope, feed[name], feed_var.name, i)
fetch_var = global_block.create_var(
name=fetch_var_name,
......@@ -51,8 +57,8 @@ class Executor(object):
outputs={'Out': [fetch_var]},
attrs={'col': i})
self.executor.run(program.desc, 0)
self.executor.run(program.desc, scope, 0)
return [
core.get_fetch_variable(fetch_var_name, i)
core.get_fetch_variable(scope, fetch_var_name, i)
for i in xrange(len(fetch_list))
]
......@@ -5,6 +5,7 @@ import numpy as np
class TestFeedFetch(unittest.TestCase):
def test_feed_fetch(self):
scope = core.Scope()
place = core.CPUPlace()
input_array = np.ones((4, 4, 6)).astype("float32")
input_array[0, 0, 0] = 3
......@@ -12,9 +13,9 @@ class TestFeedFetch(unittest.TestCase):
input_tensor = core.LoDTensor([[0, 2, 4]])
input_tensor.set(input_array, place)
core.set_feed_variable(input_tensor, "feed", 0)
core.set_feed_variable(scope, input_tensor, "feed", 0)
output_tensor = core.get_fetch_variable("feed", 0)
output_tensor = core.get_fetch_variable(scope, "feed", 0)
output_lod = output_tensor.lod()
self.assertEqual(0, output_lod[0][0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册