Commit 91f5d2b9 authored by qijun

follow comments and create local_scope inside executor run method

Parent e8a678e1
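Read together, the two executor hunks below leave `Executor::Run` in roughly the following shape. This is a sketch assembled only from the diff lines shown on this page; context the diff elides is marked with `...` comments.

```cpp
// Sketch of Executor::Run after this change (assembled from the visible diff lines).
void Executor::Run(const ProgramDesc& pdesc, Scope* scope) {
  // ... earlier checks elided in the diff ...
  auto& block = pdesc.blocks(0);
  auto& device = device_contexts_[0];

  // Each Run now creates its own child scope (previously a commented-out TODO).
  Scope& local_scope = scope->NewScope();

  for (auto& var : block.vars()) {
    scope->NewVar(var.name());
  }
  // ... elided ...
  for (auto& op_desc : block.ops()) {
    auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
    std::cout << op->DebugString() << std::endl;
    // Operators now run against the local scope rather than *scope.
    op->Run(local_scope, *device);
  }
  // TODO(tonyyang-svail): need to test gpu device
}
```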
@@ -56,9 +56,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) {
   auto& block = pdesc.blocks(0);
   auto& device = device_contexts_[0];
-  // TODO(tonyyang-svail):
-  //  - runs on a new local scope
-  //  Scope& local_scope = scope->NewScope();
+  Scope& local_scope = scope->NewScope();
   for (auto& var : block.vars()) {
     scope->NewVar(var.name());
@@ -67,7 +65,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) {
   for (auto& op_desc : block.ops()) {
     auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
     std::cout << op->DebugString() << std::endl;
-    op->Run(*scope, *device);
+    op->Run(local_scope, *device);
   }
   // TODO(tonyyang-svail): need to test gpu device
...
@@ -131,7 +131,7 @@ template <typename T>
 void set_feed_variable(const std::vector<std::vector<T>>& inputs) {
   typedef std::vector<paddle::framework::Tensor> FeedInputs;
   // Tensors in feed value variable will only be in CPUPlace
-  Variable* g_feed_value = GetScope()->FindVar("feed_value");
+  Variable* g_feed_value = GetGlobalScope()->FindVar("feed_value");
   FeedInputs& feed_inputs = *(g_feed_value->GetMutable<FeedInputs>());
   auto size = inputs.size();
   feed_inputs.resize(size);
@@ -146,7 +146,7 @@ template <typename T>
 std::vector<std::vector<T>> get_fetch_variable() {
   typedef std::vector<paddle::framework::Tensor> FetchOutputs;
   // Tensors in fetch value variable will only be in CPUPlace
-  Variable* g_fetch_value = GetScope()->FindVar("fetch_value");
+  Variable* g_fetch_value = GetGlobalScope()->FindVar("fetch_value");
   FetchOutputs& fetch_outputs = *(g_fetch_value->GetMutable<FetchOutputs>());
   auto size = fetch_outputs.size();
@@ -252,7 +252,7 @@ TEST_F(ExecutorTesterRandom, CPU) {
   paddle::memory::Used(cpu_place);
   Executor* executor = new Executor(places);
-  executor->Run(pdesc_, GetScope());
+  executor->Run(pdesc_, GetGlobalScope());
   std::vector<std::vector<float>> result = get_fetch_variable<float>();
   for (auto& vec : result) {
     for (auto& num : vec) {
@@ -281,7 +281,7 @@ TEST_F(ExecutorTesterFeed, CPU) {
     // need to set feed variable before Executor::Run
     std::cout << "start mini-batch " << i << std::endl;
     set_feed_variable<float>(inputs_);
-    executor->Run(pdesc_, GetScope());
+    executor->Run(pdesc_, GetGlobalScope());
     std::vector<std::vector<float>> result = get_fetch_variable<float>();
     for (auto& vec : result) {
       for (auto& num : vec) {
@@ -309,7 +309,7 @@ TEST_F(ExecutorTesterRandom, GPU) {
   paddle::memory::Used(gpu_place);
   Executor* executor = new Executor(places);
-  executor->Run(pdesc_, GetScope());
+  executor->Run(pdesc_, GetGlobalScope());
   delete executor;
 }
@@ -333,7 +333,7 @@ TEST_F(ExecutorTesterFeed, GPU) {
     // need to set feed variable before Executor::Run
     std::cout << "start mini-batch " << i << std::endl;
     set_feed_variable<float>(inputs_);
-    executor->Run(pdesc_, GetScope());
+    executor->Run(pdesc_, GetGlobalScope());
     std::vector<std::vector<float>> result = get_fetch_variable<float>();
     for (auto& vec : result) {
       for (auto& num : vec) {
...
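Taken together, the test hunks above exercise the feed/run/fetch round trip through the global scope. A minimal sketch of that per-mini-batch sequence, using only calls that appear in the diff (setup of `executor`, `pdesc_`, and `inputs_` is elided on this page and assumed here):

```cpp
// One mini-batch: copy inputs into the global "feed_value" variable,
// run the program, then read results back from "fetch_value".
set_feed_variable<float>(inputs_);
executor->Run(pdesc_, GetGlobalScope());
std::vector<std::vector<float>> result = get_fetch_variable<float>();
```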
@@ -66,7 +66,7 @@ void Scope::DropKids() {
 std::once_flag feed_variable_flag;
-framework::Scope* GetScope() {
+framework::Scope* GetGlobalScope() {
   static std::unique_ptr<framework::Scope> g_scope{nullptr};
   std::call_once(feed_variable_flag, [&]() {
     g_scope.reset(new framework::Scope());
...
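The renamed accessor keeps the lazily-initialized singleton pattern shown above. For reference, a minimal self-contained sketch of that pattern in standard C++ (the `Scope` stub and names are illustrative, not Paddle's actual definitions):

```cpp
#include <memory>
#include <mutex>

struct Scope {};  // stand-in for framework::Scope

// Returns the single global Scope, constructing it exactly once on first use.
Scope* GetGlobalScope() {
  static std::once_flag init_flag;
  static std::unique_ptr<Scope> g_scope;
  std::call_once(init_flag, [] { g_scope.reset(new Scope()); });
  return g_scope.get();
}

int main() {
  Scope* a = GetGlobalScope();
  Scope* b = GetGlobalScope();
  return a == b ? 0 : 1;  // both calls yield the same instance
}
```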
@@ -73,7 +73,7 @@ class Scope {
   DISABLE_COPY_AND_ASSIGN(Scope);
 };
-framework::Scope* GetScope();
+framework::Scope* GetGlobalScope();
 }  // namespace framework
 }  // namespace paddle
@@ -27,7 +27,7 @@ class FeedOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null.");
     int col = ctx->Attrs().Get<int>("col");
     framework::Variable* g_feed_variable =
-        framework::GetScope()->FindVar("feed_value");
+        framework::GetGlobalScope()->FindVar("feed_value");
     const FeedInputs& tensors = g_feed_variable->Get<FeedInputs>();
...
@@ -19,17 +19,15 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-using Tensor = framework::Tensor;
 template <typename T>
 class FeedKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     typedef std::vector<framework::Tensor> FeedInputs;
-    Tensor* out = ctx.Output<Tensor>("Out");
+    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
     out->mutable_data<T>(ctx.GetPlace());
     framework::Variable* g_feed_variable =
-        framework::GetScope()->FindVar("feed_value");
+        framework::GetGlobalScope()->FindVar("feed_value");
     int col = ctx.template Attr<int>("col");
     const FeedInputs& tensors = g_feed_variable->Get<FeedInputs>();
     out->CopyFrom<T>(tensors[col], ctx.GetPlace());
...
@@ -27,7 +27,7 @@ class FetchOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null.");
     int col = ctx->Attrs().Get<int>("col");
     framework::Variable* g_fetch_variable =
-        framework::GetScope()->FindVar("fetch_value");
+        framework::GetGlobalScope()->FindVar("fetch_value");
     FetchOutputs* tensors = g_fetch_variable->GetMutable<FetchOutputs>();
     if (tensors->size() < static_cast<size_t>(col + 1)) {
...
@@ -19,17 +19,15 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-using Tensor = framework::Tensor;
 template <typename T>
 class FetchKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     typedef std::vector<framework::Tensor> FetchOutputs;
-    const Tensor* input = ctx.Input<Tensor>("Input");
+    const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
     int col = ctx.template Attr<int>("col");
     framework::Variable* g_fetch_variable =
-        framework::GetScope()->FindVar("fetch_value");
+        framework::GetGlobalScope()->FindVar("fetch_value");
     FetchOutputs* tensors = g_fetch_variable->GetMutable<FetchOutputs>();
     (*tensors)[col].mutable_data<T>(platform::CPUPlace());
     (*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace());
...