Unverified commit 85831c32, authored by zhangbo9674, committed by GitHub

[IR] New IR access InterpreterCore: add local scope logic (#55112)

* add local scope

* refine code

* refine code

* refine code

* support local scope for BuildFuncList

* fix bug

* add log

* fix bug

* polish code

* fix bug
Parent 902de74c
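
The core of this change is a placement rule for variables created while building the scope: values marked persistable are created in the outermost ancestor scope, everything else in the (new) local scope. The following is a minimal, self-contained sketch of that rule using a toy scope type; it is not the Paddle implementation, only an illustration of the logic added by the new CreateVar helper in the diff below.

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Toy stand-in for paddle::framework::Scope (illustration only).
struct ToyScope {
  explicit ToyScope(ToyScope* parent = nullptr) : parent_(parent) {}
  ToyScope* parent() const { return parent_; }
  ToyScope& NewScope() {
    kids_.emplace_back(new ToyScope(this));
    return *kids_.back();
  }
  // Create-or-get a named variable slot in this scope.
  int* Var(const std::string& name) { return &vars_[name]; }

 private:
  ToyScope* parent_;
  std::map<std::string, int> vars_;
  std::vector<std::unique_ptr<ToyScope>> kids_;
};

// Mirrors the placement rule of the CreateVar helper added in this diff:
// persistable values go to the outermost ancestor scope, all others to the
// local scope that was passed in.
int* CreateVarLike(bool is_persistable, const std::string& name,
                   ToyScope* scope, ToyScope* local_scope) {
  if (is_persistable) {
    ToyScope* ancestor = scope;
    while (ancestor->parent()) ancestor = ancestor->parent();
    return ancestor->Var(name);
  }
  return local_scope->Var(name);
}

int main() {
  ToyScope root;
  ToyScope& local = root.NewScope();
  CreateVarLike(true, "param_w", &local, &local);      // lands in root
  CreateVarLike(false, "inner_var_0", &root, &local);  // lands in local
  std::cout << "scope placement sketch done\n";
}
```
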
......@@ -940,6 +940,7 @@ void BuildOpFuncList(
::ir::Block* block,
std::vector<OpFuncNode>* vec_func_list,
framework::Scope* scope,
framework::Scope* local_scope,
const std::unordered_map<::ir::Value, std::string>& value_2_name_map,
const ExecutionConfig& execution_config) {
vec_func_list->reserve(block->size());
......@@ -979,6 +980,7 @@ void BuildOpFuncList(
false>((*it),
value_2_name_map,
scope,
local_scope,
op_yaml_info_parser,
&(op_func_node.infer_meta_context_));
......@@ -1004,6 +1006,7 @@ void BuildOpFuncList(
true>((*it),
value_2_name_map,
scope,
local_scope,
op_yaml_info_parser,
&(op_func_node.kernel_context_),
&(op_func_node.input_index),
......
......@@ -98,6 +98,7 @@ void BuildOpFuncList(
::ir::Block* block,
std::vector<OpFuncNode>* vec_func_list,
framework::Scope* scope,
framework::Scope* local_scope,
const std::unordered_map<::ir::Value, std::string>& value_2_name_map,
const ExecutionConfig& execution_config);
......
......@@ -86,6 +86,8 @@ class InterpreterBaseImpl {
virtual void reset_scope(Scope* new_scope) = 0;
virtual const Scope* local_scope() const = 0;
virtual const platform::Place& GetPlace() const = 0;
virtual void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) = 0;
......
......@@ -101,6 +101,10 @@ void InterpreterCore::reset_scope(Scope* new_scope) {
impl_->reset_scope(new_scope);
}
const Scope* InterpreterCore::local_scope() const {
return impl_->local_scope();
}
const platform::Place& InterpreterCore::GetPlace() const {
return impl_->GetPlace();
}
......
......@@ -65,6 +65,8 @@ class InterpreterCore {
void reset_scope(Scope* new_scope);
const Scope* local_scope() const;
const platform::Place& GetPlace() const;
void SetOutputHooks(const std::vector<HookFunc>& hookfuncs);
......
......@@ -49,12 +49,14 @@ NewIRInterpreter::NewIRInterpreter(const platform::Place& place,
stream_analyzer_(place),
execution_config_(execution_config),
var_scope_(scope),
scope_(scope),
ir_program_(std::move(ir_prog)) {
VLOG(4) << "NewIRInterpreter(): " << this << " on " << place_;
static_build_ = FLAGS_new_executor_static_build &&
!FLAGS_new_executor_use_cuda_graph &&
!execution_config.used_for_control_flow_op;
// &&interpreter::BlockCanBeStaticBuilt(block);
static_build_ = true;
exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
......@@ -62,21 +64,19 @@ NewIRInterpreter::NewIRInterpreter(const platform::Place& place,
if (!FLAGS_new_executor_use_local_scope) {
execution_config_.create_local_scope = false;
}
execution_config_.AnalyzeThreadPoolConfig(place,
ir_program_->block()->size());
execution_config_.Log(/*log_level=*/8);
if (execution_config_.create_local_scope) {
auto local_scope = &var_scope_.GetMutableScope()->NewScope();
auto local_scope = &scope_->NewScope();
local_scope_ = local_scope;
VLOG(6) << "new ir interpretercore scope: " << scope_ << "\t"
<< "; local scope: " << local_scope_;
}
// force use outer scope for now
local_scope_ = scope;
static_build_ = true;
// TODO(zhangbo): delete var_scope
var_scope_.SetLocalScope(local_scope_);
execution_config_.AnalyzeThreadPoolConfig(place,
ir_program_->block()->size());
execution_config_.Log(/*log_level=*/8);
instruction_scheduling_priority_less = [this](size_t lhs, size_t rhs) {
SchedulingPriority lhs_scheduling_priority =
vec_instruction_[lhs].GetSchedulingPriority();
......@@ -185,12 +185,13 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
if (!is_build_) {
LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
::ir::BuildScope(
ir_program_->block(), local_scope_, &value_2_var_name_map_);
ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
interpreter::BuildOpFuncList(place_,
ir_program_->block(),
&op_func_nodes,
scope_,
local_scope_,
value_2_var_name_map_,
execution_config_);
......@@ -212,8 +213,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
}
// return Fetch Tensors
Scope* inner_scope =
HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
Scope* inner_scope = HasLocalScope() ? local_scope_ : scope_;
auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
if (fetch_var && need_fetch) {
auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
......@@ -287,6 +287,8 @@ void NewIRInterpreter::reset_scope(Scope* new_scope) {
}
}
const Scope* NewIRInterpreter::local_scope() const { return local_scope_; }
void NewIRInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) {
async_work_queue_ = reinterpret_cast<NewIRInterpreter*>(src)->GetWorkQueue();
VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << src
......@@ -321,8 +323,7 @@ std::shared_ptr<interpreter::AsyncWorkQueue> NewIRInterpreter::GetWorkQueue() {
}
void NewIRInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) {
Scope* inner_scope =
HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
Scope* inner_scope = HasLocalScope() ? local_scope_ : scope_;
VariableValueMap ins_map;
for (auto& var_name_item : instr_node->Inputs()) {
std::vector<Variable*> input_vars;
......@@ -349,8 +350,7 @@ void NewIRInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) {
if (instr_node->OpBase()->Type() == "cinn_launch" ||
instr_node->OpBase()->Type() == "cinn_instruction_run") { // OP use scope
// in kernel
Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
: var_scope_.GetMutableScope();
Scope* local_scope = HasLocalScope() ? local_scope_ : scope_;
instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
} else {
instr_node->ResetContext(ins_map, outs_map);
......@@ -380,8 +380,7 @@ void NewIRInterpreter::BuildInplace() {
}
}
Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
: var_scope_.GetMutableScope();
Scope* local_scope = HasLocalScope() ? local_scope_ : scope_;
std::vector<std::vector<size_t>> input_var2op(var_scope_.VarSize());
for (Instruction& instr : vec_instruction_) {
for (auto& item : instr.Inputs()) {
......@@ -799,8 +798,7 @@ void NewIRInterpreter::BuildSkipShareLoDInfo() {
void NewIRInterpreter::RunOperator(const Instruction& instr_node) {
auto* op = instr_node.OpBase();
auto place = instr_node.DeviceContext().GetPlace();
Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
: var_scope_.GetMutableScope();
Scope* local_scope = HasLocalScope() ? local_scope_ : scope_;
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);
auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
......@@ -1047,7 +1045,7 @@ void NewIRInterpreter::ExecuteInstructionList(
if (cancel) {
break;
}
VLOG(0) << "deps:\n" << GetDepsString();
VLOG(6) << "deps:\n" << GetDepsString();
times++;
}
return times;
......
......@@ -60,6 +60,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void reset_scope(Scope* new_scope) override;
const Scope* local_scope() const override;
const platform::Place& GetPlace() const override { return place_; }
void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) override {
......@@ -143,6 +145,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
ExecutionConfig execution_config_;
VariableScope var_scope_;
Scope* scope_{nullptr};
Scope* local_scope_{nullptr}; // not owned
EventsWaiter main_thread_blocker_;
......
......@@ -275,6 +275,7 @@ void ProgramInterpreter::reset_scope(Scope* new_scope) {
}
}
const Scope* ProgramInterpreter::local_scope() const { return local_scope_; }
void ProgramInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) {
async_work_queue_ =
reinterpret_cast<ProgramInterpreter*>(src)->GetWorkQueue();
......
......@@ -62,6 +62,8 @@ class ProgramInterpreter : public InterpreterBaseImpl {
void reset_scope(Scope* new_scope) override;
const Scope* local_scope() const override;
const platform::Place& GetPlace() const override { return place_; }
void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) override {
......
......@@ -56,7 +56,7 @@ class PhiKernelAdaptor {
void run_kernel_prog(ir::Program* program) {
auto block = program->block();
std::unordered_map<ir::Value, std::string> name_map;
BuildScope(block, scope_, &name_map);
BuildScope(block, scope_, nullptr, &name_map);
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
......@@ -87,7 +87,7 @@ class PhiKernelAdaptor {
phi::MetaTensor,
phi::MetaTensor,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
false>((*it), name_map, scope_, op_yaml_info_parser, &ctx);
false>((*it), name_map, scope_, nullptr, op_yaml_info_parser, &ctx);
infer_meta_impl->infer_meta_(&ctx);
......@@ -107,7 +107,7 @@ class PhiKernelAdaptor {
phi::TensorBase*,
paddle::small_vector<const phi::TensorBase*>,
true>(
(*it), name_map, scope_, op_yaml_info_parser, &kernel_ctx);
(*it), name_map, scope_, nullptr, op_yaml_info_parser, &kernel_ctx);
kernel_fn(&kernel_ctx);
auto out_value = (*it)->result(0);
......
......@@ -43,114 +43,160 @@
namespace ir {
void BuildScope(ir::Block* block,
paddle::framework::Scope* scope,
std::unordered_map<ir::Value, std::string>* name_map) {
std::unordered_map<ir::Value, int> map_test;
int count = name_map->size();
for (auto it = block->begin(); it != block->end(); ++it) {
size_t input_num = (*it)->num_operands();
auto attr_map = (*it)->attributes();
std::string op_name = (*it)->name();
if (attr_map.count("op_name")) {
op_name = attr_map.at("op_name").dyn_cast<ir::StrAttribute>().data();
}
if (op_name == "pd.fetch") {
// fetch is a very special op, with no output
for (size_t i = 0; i < input_num; ++i) {
auto var = scope->Var("fetch");
auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
int index =
(*it)->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
fetch_list->resize(index + 1);
}
continue;
paddle::framework::Variable* CreateVar(ir::Value value,
std::string name,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope) {
Operation* def_op = value.GetDefiningOp();
bool is_persisable = false;
if (def_op->attributes().count("is_persisable")) {
is_persisable = def_op->attributes()
.at("is_persisable")
.dyn_cast<ir::BoolAttribute>()
.data();
}
if (is_persisable) {
const paddle::framework::Scope* ancestor_scope = scope;
while (ancestor_scope->parent()) {
ancestor_scope = ancestor_scope->parent();
}
VLOG(6) << "Create var: " << name << " in scope " << ancestor_scope;
return const_cast<paddle::framework::Scope*>(ancestor_scope)->Var(name);
} else {
VLOG(6) << "Create var: " << name << " in scope " << local_scope;
return local_scope->Var(name);
}
}
if (op_name == "builtin.set_parameter") {
auto param_name = (*it)
->attributes()
.at("parameter_name")
.dyn_cast<ir::StrAttribute>()
.data();
auto in_ptr = (*it)->operand(0);
// change operand name to param_name
auto orig_name = name_map->at(in_ptr);
(*name_map)[in_ptr] = param_name;
scope->Rename(orig_name, param_name);
continue;
void HandleForSpecialOp(ir::Operation* op,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
int& count) { // NOLINT
std::string op_name = op->name();
if (op->attributes().count("op_name")) {
op_name =
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
}
size_t input_num = op->num_operands();
if (op_name == "pd.fetch") {
// fetch is a very special op, with no output
VLOG(6) << "Handle for pd.fetch:";
for (size_t i = 0; i < input_num; ++i) {
auto var = scope->Var("fetch");
VLOG(6) << "Create var: fetch in scope " << scope;
auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
int index =
op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
fetch_list->resize(index + 1);
}
}
if (op_name == "builtin.get_parameter") {
auto param_name = (*it)
->attributes()
.at("parameter_name")
.dyn_cast<ir::StrAttribute>()
.data();
auto out_ptr = (*it)->result(0);
if (op_name == "pd.feed") {
VLOG(6) << "Handle for pd.feed:";
auto ptr = op->result(0);
std::string name = "inner_var_" + std::to_string(count++);
name_map->emplace(ptr, name);
auto var = CreateVar(ptr, name, scope, local_scope);
// TODO(phlrain): need to update here, support StringTensor
auto out_tensor = var->GetMutable<phi::DenseTensor>();
auto feed_var = scope->Var("feed");
VLOG(6) << "Create var: feed in scope " << scope;
int index =
op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
auto feed_list = feed_var->Get<paddle::framework::FeedList>();
auto& in_tensor = (PADDLE_GET(phi::DenseTensor, feed_list.at(index)));
out_tensor->ShareDataWith(in_tensor);
}
name_map->emplace(out_ptr, param_name);
continue;
if (op_name == "builtin.combine") {
VLOG(6) << "Handle for builtin.combine:";
auto out_value = op->result(0);
std::string name;
if (name_map->find(out_value) != name_map->end()) {
name = name_map->at(out_value);
} else {
name = "inner_var_" + std::to_string(count++);
name_map->emplace(out_value, name);
}
if (op_name == "pd.feed") {
auto ptr = (*it)->result(0);
std::string name = "inner_var_" + std::to_string(count++);
name_map->emplace(ptr, name);
auto var = scope->Var(name);
// TODO(phlrain): need to update here, support StringTensor
auto out_tensor = var->GetMutable<phi::DenseTensor>();
auto feed_var = scope->Var("feed");
int index =
(*it)->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
auto feed_list = feed_var->Get<paddle::framework::FeedList>();
auto& in_tensor = (PADDLE_GET(phi::DenseTensor, feed_list.at(index)));
auto var = CreateVar(out_value, name, scope, local_scope);
auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
out_tensor->ShareDataWith(in_tensor);
for (size_t i = 0; i < input_num; ++i) {
auto ptr = op->operand(i);
continue;
PADDLE_ENFORCE_EQ(
name_map->count(ptr),
true,
phi::errors::PreconditionNotMet("can not found input of combine op"));
tensor_array->emplace_back(
&(CreateVar(ptr, name_map->at(ptr), scope, local_scope)
->Get<phi::DenseTensor>()));
}
}
if (op_name == "builtin.combine") {
auto out_value = (*it)->result(0);
if (op_name == "builtin.set_parameter") {
VLOG(6) << "Handle for builtin.set_parameter:";
auto param_name = op->attributes()
.at("parameter_name")
.dyn_cast<ir::StrAttribute>()
.data();
VLOG(5) << "process builtin combine";
std::string name;
if (name_map->find(out_value) != name_map->end()) {
name = name_map->at(out_value);
} else {
name = "inner_var_" + std::to_string(count++);
name_map->emplace(out_value, name);
}
auto in_ptr = op->operand(0);
// change operand name to param_name
auto var = scope->Var(name);
auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
auto orig_name = name_map->at(in_ptr);
(*name_map)[in_ptr] = param_name;
scope->Rename(orig_name, param_name);
}
for (size_t i = 0; i < input_num; ++i) {
auto ptr = (*it)->operand(i);
if (op_name == "builtin.get_parameter") {
VLOG(6) << "Handle for builtin.get_parameter:";
auto param_name = op->attributes()
.at("parameter_name")
.dyn_cast<ir::StrAttribute>()
.data();
auto out_ptr = op->result(0);
name_map->emplace(out_ptr, param_name);
}
}
void BuildScope(ir::Block* block,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map) {
// NOTE(zhiqiu): if a local_scope is used (local_scope != nullptr), persistable
// variables are created in scope, and the others are created in local_scope.
auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
VLOG(6) << "Build: scope [" << scope << "] inner_local_scope ["
<< inner_local_scope << "]";
PADDLE_ENFORCE_EQ(name_map->count(ptr),
true,
phi::errors::PreconditionNotMet(
"can not found input of combine op"));
// int count = name_map->size();
int count = name_map->size();
for (auto it = block->begin(); it != block->end(); ++it) {
ir::Operation* op = *it;
tensor_array->emplace_back(
&(scope->Var(name_map->at(ptr))->Get<phi::DenseTensor>()));
}
auto attr_map = op->attributes();
std::string op_name = op->name();
if (attr_map.count("op_name")) {
op_name = attr_map.at("op_name").dyn_cast<ir::StrAttribute>().data();
}
if (op_name == "pd.feed" || op_name == "pd.fetch" ||
op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter") {
VLOG(6) << "HandleForSpecialOp: " << op_name;
HandleForSpecialOp(op, scope, inner_local_scope, name_map, count);
continue;
}
// TODO(zhangbo): support builtin.slice
size_t input_num = op->num_operands();
if (input_num > 0) {
for (size_t i = 0; i < input_num; ++i) {
auto ptr = (*it)->operand(i);
auto ptr = op->operand(i);
if (ptr) {
PADDLE_ENFORCE_NE(
name_map->find(ptr),
......@@ -163,11 +209,10 @@ void BuildScope(ir::Block* block,
}
}
int out_num = (*it)->num_results();
int out_num = op->num_results();
if (out_num > 0) {
for (int i = 0; i < out_num; ++i) {
ir::Value ptr = (*it)->result(i);
ir::Value ptr = op->result(i);
std::string name;
if (name_map->find(ptr) != name_map->end()) {
name = name_map->at(ptr);
......@@ -175,7 +220,7 @@ void BuildScope(ir::Block* block,
name = "inner_var_" + std::to_string(count++);
name_map->emplace(ptr, name);
}
auto var = scope->Var(name);
auto var = CreateVar(ptr, name, scope, inner_local_scope);
// Only support DenseTensor or Vector<DenseTensor>
if (!ptr.type()) {
var->GetMutable<phi::DenseTensor>();
......@@ -195,7 +240,7 @@ void BuildScope(ir::Block* block,
"Element of VectorType output only support "
"DenseTensorType"));
std::string name_i = "inner_var_" + std::to_string(count++);
auto var_i = scope->Var(name_i);
auto var_i = CreateVar(ptr, name_i, scope, inner_local_scope);
tensor_array->emplace_back(var_i->GetMutable<phi::DenseTensor>());
}
} else {
......
......@@ -25,6 +25,7 @@
#include "paddle/ir/core/utils.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
......@@ -39,9 +40,20 @@
#include "glog/logging.h"
namespace ir {
paddle::framework::Variable* CreateVar(ir::Value value,
std::string name,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope);
void HandleForSpecialOp(ir::Operation* op,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
int& count); // NOLINT
void BuildScope(ir::Block* block,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map);
template <typename Context,
......@@ -53,10 +65,15 @@ void BuildPhiContext(
ir::Operation* op,
const std::unordered_map<ir::Value, std::string>& name_map,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
const paddle::dialect::OpYamlInfoParser& op_yaml_info,
Context* ctx,
std::map<std::string, std::vector<int>>* input_map = nullptr,
std::map<std::string, std::vector<int>>* output_map = nullptr) {
paddle::framework::Scope* inner_scope =
local_scope != nullptr ? local_scope : scope;
VLOG(6) << "BuildPhiContext in scope[" << scope << "] inner_scope["
<< inner_scope << "]";
// inputs include input and mutable attributes
auto attr_map = op->attributes();
......@@ -80,11 +97,10 @@ void BuildPhiContext(
auto in_var_name = name_map.at(ptr);
VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
PADDLE_ENFORCE_NOT_NULL(scope->FindLocalVar(in_var_name),
PADDLE_ENFORCE_NOT_NULL(inner_scope->FindLocalVar(in_var_name),
phi::errors::PreconditionNotMet(
"can not find var[%s] in scope", in_var_name));
auto var = scope->Var(in_var_name);
auto var = inner_scope->FindVar(in_var_name);
if (var->IsType<phi::DenseTensor>()) {
const phi::TensorBase* tensor_in = &(var->Get<phi::DenseTensor>());
ctx->EmplaceBackInput(InType(tensor_in));
......@@ -123,12 +139,12 @@ void BuildPhiContext(
auto& tensor_attr_type = op_yaml_info.TensorAttrTypeName(t);
VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t" << in_var_name;
if (tensor_attr_type == "paddle::dialect::IntArrayAttribute") {
phi::Attribute r1 =
phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
phi::Attribute r1 = phi::TensorRef(
&(inner_scope->FindVar(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else if (tensor_attr_type == "paddle::dialect::ScalarAttribute") {
phi::Attribute r1 =
phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
phi::Attribute r1 = phi::TensorRef(
&(inner_scope->FindVar(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
} else {
......@@ -239,7 +255,7 @@ void BuildPhiContext(
(op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data() ==
"pd.fetch")) {
// process fetch op
auto fetch_var = scope->Var("fetch");
auto fetch_var = inner_scope->FindVar("fetch");
auto* fetch_list = fetch_var->GetMutable<paddle::framework::FetchList>();
int index =
op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
......@@ -251,7 +267,7 @@ void BuildPhiContext(
auto name = name_map.at(out_ptr);
if (out_ptr.type()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
&(scope->Var(name)->Get<phi::DenseTensor>()))));
&(inner_scope->FindVar(name)->Get<phi::DenseTensor>()))));
} else {
phi::DenseTensor* ptr = nullptr;
OutType out_ptr(ptr);
......
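
A recurring change in the hunk above is replacing `scope->Var(name)` with `inner_scope->FindVar(name)` once a local scope may be in play. The difference matters: in Paddle, Scope::Var creates (or returns) a variable in the scope it is called on, while FindVar also searches parent scopes, so a kernel running against the local scope can still reach persistable variables created in the root scope. Below is a minimal sketch of that lookup behaviour, using a toy scope type rather than paddle::framework::Scope.

```cpp
#include <cassert>
#include <map>
#include <string>

// Toy scope with parent-chain lookup (illustration only; not Paddle's Scope).
struct ToyScope {
  explicit ToyScope(ToyScope* parent = nullptr) : parent_(parent) {}

  // Var: create-or-get in *this* scope only.
  int* Var(const std::string& name) { return &vars_[name]; }

  // FindVar: search this scope, then walk up through parents.
  int* FindVar(const std::string& name) {
    auto it = vars_.find(name);
    if (it != vars_.end()) return &it->second;
    return parent_ ? parent_->FindVar(name) : nullptr;
  }

 private:
  ToyScope* parent_;
  std::map<std::string, int> vars_;
};

int main() {
  ToyScope root;
  ToyScope local(&root);

  *root.Var("param_w") = 42;  // persistable var created in the root scope

  // FindVar from the local scope still reaches the root-scope variable...
  assert(*local.FindVar("param_w") == 42);
  // ...whereas Var would silently create a new, empty "param_w" in `local`.
  assert(*local.Var("param_w") == 0);
  return 0;
}
```
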
......@@ -73,7 +73,11 @@ TEST(StandaloneExecutor, run) {
test_core.Run({});
auto out_tensor = scope.Var("inner_var_2")->Get<phi::DenseTensor>();
auto out_tensor = test_core.local_scope() == nullptr
? scope.FindVar("inner_var_2")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar("inner_var_2")
->Get<phi::DenseTensor>();
bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
......@@ -142,7 +146,11 @@ TEST(StandaloneExecutor, run_2) {
test_core.Run({});
auto out_tensor = scope.Var("inner_var_10")->Get<phi::DenseTensor>();
auto out_tensor = test_core.local_scope() == nullptr
? scope.FindVar("inner_var_10")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar("inner_var_10")
->Get<phi::DenseTensor>();
bool res0 = simple_cmp(out_tensor.data<float>()[0], 1.80721);
bool res1 = simple_cmp(out_tensor.data<float>()[1], 1.70047);
......@@ -213,7 +221,11 @@ TEST(StandaloneExecutor, data_transfer) {
test_core.Run({});
auto out_tensor = scope.Var("inner_var_9")->Get<phi::DenseTensor>();
auto out_tensor = test_core.local_scope() == nullptr
? scope.FindVar("inner_var_9")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar("inner_var_9")
->Get<phi::DenseTensor>();
auto& pool = phi::DeviceContextPool::Instance();
phi::DenseTensor out;
......