提交 91f5d2b9 编写于 作者: Q qijun

follow comments and create local_scope inside executor run method

上级 e8a678e1
......@@ -56,9 +56,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) {
auto& block = pdesc.blocks(0);
auto& device = device_contexts_[0];
// TODO(tonyyang-svail):
// - runs on a new local scope
// Scope& local_scope = scope->NewScope();
Scope& local_scope = scope->NewScope();
for (auto& var : block.vars()) {
scope->NewVar(var.name());
......@@ -67,7 +65,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) {
for (auto& op_desc : block.ops()) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
std::cout << op->DebugString() << std::endl;
op->Run(*scope, *device);
op->Run(local_scope, *device);
}
// TODO(tonyyang-svail): need to test gpu device
......
......@@ -131,7 +131,7 @@ template <typename T>
void set_feed_variable(const std::vector<std::vector<T>>& inputs) {
typedef std::vector<paddle::framework::Tensor> FeedInputs;
// Tensors in feed value variable will only be in CPUPlace
Variable* g_feed_value = GetScope()->FindVar("feed_value");
Variable* g_feed_value = GetGlobalScope()->FindVar("feed_value");
FeedInputs& feed_inputs = *(g_feed_value->GetMutable<FeedInputs>());
auto size = inputs.size();
feed_inputs.resize(size);
......@@ -146,7 +146,7 @@ template <typename T>
std::vector<std::vector<T>> get_fetch_variable() {
typedef std::vector<paddle::framework::Tensor> FetchOutputs;
// Tensors in fetch value variable will only be in CPUPlace
Variable* g_fetch_value = GetScope()->FindVar("fetch_value");
Variable* g_fetch_value = GetGlobalScope()->FindVar("fetch_value");
FetchOutputs& fetch_outputs = *(g_fetch_value->GetMutable<FetchOutputs>());
auto size = fetch_outputs.size();
......@@ -252,7 +252,7 @@ TEST_F(ExecutorTesterRandom, CPU) {
paddle::memory::Used(cpu_place);
Executor* executor = new Executor(places);
executor->Run(pdesc_, GetScope());
executor->Run(pdesc_, GetGlobalScope());
std::vector<std::vector<float>> result = get_fetch_variable<float>();
for (auto& vec : result) {
for (auto& num : vec) {
......@@ -281,7 +281,7 @@ TEST_F(ExecutorTesterFeed, CPU) {
// need to set feed variable before Executor::Run
std::cout << "start mini-batch " << i << std::endl;
set_feed_variable<float>(inputs_);
executor->Run(pdesc_, GetScope());
executor->Run(pdesc_, GetGlobalScope());
std::vector<std::vector<float>> result = get_fetch_variable<float>();
for (auto& vec : result) {
for (auto& num : vec) {
......@@ -309,7 +309,7 @@ TEST_F(ExecutorTesterRandom, GPU) {
paddle::memory::Used(gpu_place);
Executor* executor = new Executor(places);
executor->Run(pdesc_, GetScope());
executor->Run(pdesc_, GetGlobalScope());
delete executor;
}
......@@ -333,7 +333,7 @@ TEST_F(ExecutorTesterFeed, GPU) {
// need to set feed variable before Executor::Run
std::cout << "start mini-batch " << i << std::endl;
set_feed_variable<float>(inputs_);
executor->Run(pdesc_, GetScope());
executor->Run(pdesc_, GetGlobalScope());
std::vector<std::vector<float>> result = get_fetch_variable<float>();
for (auto& vec : result) {
for (auto& num : vec) {
......
......@@ -66,7 +66,7 @@ void Scope::DropKids() {
std::once_flag feed_variable_flag;
framework::Scope* GetScope() {
framework::Scope* GetGlobalScope() {
static std::unique_ptr<framework::Scope> g_scope{nullptr};
std::call_once(feed_variable_flag, [&]() {
g_scope.reset(new framework::Scope());
......
......@@ -73,7 +73,7 @@ class Scope {
DISABLE_COPY_AND_ASSIGN(Scope);
};
framework::Scope* GetScope();
framework::Scope* GetGlobalScope();
} // namespace framework
} // namespace paddle
......@@ -27,7 +27,7 @@ class FeedOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null.");
int col = ctx->Attrs().Get<int>("col");
framework::Variable* g_feed_variable =
framework::GetScope()->FindVar("feed_value");
framework::GetGlobalScope()->FindVar("feed_value");
const FeedInputs& tensors = g_feed_variable->Get<FeedInputs>();
......
......@@ -19,17 +19,15 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class FeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
typedef std::vector<framework::Tensor> FeedInputs;
Tensor* out = ctx.Output<Tensor>("Out");
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
framework::Variable* g_feed_variable =
framework::GetScope()->FindVar("feed_value");
framework::GetGlobalScope()->FindVar("feed_value");
int col = ctx.template Attr<int>("col");
const FeedInputs& tensors = g_feed_variable->Get<FeedInputs>();
out->CopyFrom<T>(tensors[col], ctx.GetPlace());
......
......@@ -27,7 +27,7 @@ class FetchOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null.");
int col = ctx->Attrs().Get<int>("col");
framework::Variable* g_fetch_variable =
framework::GetScope()->FindVar("fetch_value");
framework::GetGlobalScope()->FindVar("fetch_value");
FetchOutputs* tensors = g_fetch_variable->GetMutable<FetchOutputs>();
if (tensors->size() < static_cast<size_t>(col + 1)) {
......
......@@ -19,17 +19,15 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class FetchKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
typedef std::vector<framework::Tensor> FetchOutputs;
const Tensor* input = ctx.Input<Tensor>("Input");
const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
int col = ctx.template Attr<int>("col");
framework::Variable* g_fetch_variable =
framework::GetScope()->FindVar("fetch_value");
framework::GetGlobalScope()->FindVar("fetch_value");
FetchOutputs* tensors = g_fetch_variable->GetMutable<FetchOutputs>();
(*tensors)[col].mutable_data<T>(platform::CPUPlace());
(*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册