未验证 提交 f00a06d8 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Refine BuildScope in phi_kernel_util (#55423)

* add code

* fix bug

* refine code

* refine code

* fix bug
上级 7f6d222f
......@@ -185,8 +185,12 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
if (!is_build_) {
LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
::ir::BuildScope(
*ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
::ir::BuildScope(*ir_program_->block(),
InnerScope(),
&value_2_var_name_,
&variable_2_var_name_,
&var_name_2_id_,
&variable_list_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
interpreter::BuildOpFuncList(place_,
......@@ -194,7 +198,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
&op_func_nodes,
scope_,
local_scope_,
value_2_var_name_map_,
value_2_var_name_,
execution_config_);
// SetFeedVarsInplaceSkip(feed_names);
// convert vec func_list to graph
......@@ -237,8 +241,12 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
SetDeviceId(place_);
if (!is_build_) {
LOG_FIRST_N(INFO, 1) << "New Executor is BetaRunning.";
::ir::BuildScope(
*ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
::ir::BuildScope(*ir_program_->block(),
InnerScope(),
&value_2_var_name_,
&variable_2_var_name_,
&var_name_2_id_,
&variable_list_);
BuildInstruction();
for (size_t instr_id = 0; instr_id < vec_instruction_base_.size();
++instr_id) {
......@@ -1526,13 +1534,8 @@ void NewIRInterpreter::BuildInstruction() {
++it) {
VLOG(0) << "Build Instruction for op: " << op_idx;
if ((*it)->dialect()->name() == "pd_kernel") {
vec_instruction_base_.emplace_back(
std::make_unique<PhiKernelInstruction>(op_idx++,
place_,
(*it),
scope_,
local_scope_,
value_2_var_name_map_));
vec_instruction_base_.emplace_back(std::make_unique<PhiKernelInstruction>(
op_idx++, place_, (*it), scope_, local_scope_, value_2_var_name_));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Now only support pd_kernel dialect."));
......
......@@ -192,7 +192,11 @@ class NewIRInterpreter : public InterpreterBaseImpl {
std::vector<std::unique_ptr<InstructionBase>> vec_instruction_base_;
std::unordered_map<::ir::Value, std::string> value_2_var_name_map_;
std::unordered_map<::ir::Value, std::string> value_2_var_name_;
std::unordered_map<const paddle::framework::Variable*, std::string>
variable_2_var_name_;
std::map<std::string, int> var_name_2_id_;
std::vector<Variable*> variable_list_;
};
} // namespace framework
......
......@@ -55,8 +55,18 @@ class PhiKernelAdaptor {
void run_kernel_prog(ir::Program* program) {
auto block = program->block();
std::unordered_map<ir::Value, std::string> name_map;
BuildScope(*block, scope_, nullptr, &name_map);
std::unordered_map<ir::Value, std::string> value_2_var_name;
std::unordered_map<const paddle::framework::Variable*, std::string>
variable_2_var_name;
std::map<std::string, int> var_name_2_id;
std::vector<paddle::framework::Variable*> variable_list;
BuildScope(*block,
scope_,
&value_2_var_name,
&variable_2_var_name,
&var_name_2_id,
&variable_list);
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
......@@ -88,7 +98,8 @@ class PhiKernelAdaptor {
phi::MetaTensor,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
false>((*it), name_map, scope_, nullptr, op_yaml_info_parser, &ctx);
false>(
(*it), value_2_var_name, scope_, nullptr, op_yaml_info_parser, &ctx);
infer_meta_impl->infer_meta_(&ctx);
......@@ -108,12 +119,16 @@ class PhiKernelAdaptor {
phi::TensorBase*,
paddle::small_vector<const phi::TensorBase*>,
paddle::small_vector<phi::TensorBase*>,
true>(
(*it), name_map, scope_, nullptr, op_yaml_info_parser, &kernel_ctx);
true>((*it),
value_2_var_name,
scope_,
nullptr,
op_yaml_info_parser,
&kernel_ctx);
kernel_fn(&kernel_ctx);
auto out_value = (*it)->result(0);
out_name = name_map[out_value];
out_name = value_2_var_name[out_value];
}
}
......
......@@ -46,10 +46,15 @@ namespace ir {
using VariableNameMap =
std::unordered_map<const paddle::framework::Variable*, std::string>;
paddle::framework::Variable* CreateVar(ir::Value value,
const std::string& name,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope) {
paddle::framework::Variable* CreateVar(
ir::Value value,
paddle::framework::Scope* inner_scope,
bool force_persisable,
std::unordered_map<ir::Value, std::string>* value_2_var_name,
std::unordered_map<const paddle::framework::Variable*, std::string>*
variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list) {
Operation* def_op = value.GetDefiningOp();
bool is_persisable = false;
if (def_op->attributes().count("is_persisable")) {
......@@ -58,27 +63,41 @@ paddle::framework::Variable* CreateVar(ir::Value value,
.dyn_cast<ir::BoolAttribute>()
.data();
}
if (is_persisable) {
VLOG(6) << "Create var: " << name << " in scope " << scope->root();
return const_cast<paddle::framework::Scope*>(scope->root())->Var(name);
paddle::framework::Variable* var = nullptr;
std::string name = "inner_var_" + std::to_string(variable_2_var_name->size());
if (force_persisable || is_persisable) {
VLOG(6) << "Create var: " << name << " in scope " << inner_scope->root();
var = const_cast<paddle::framework::Scope*>(inner_scope->root())->Var(name);
} else {
VLOG(6) << "Create var: " << name << " in scope " << local_scope;
return local_scope->Var(name);
VLOG(6) << "Create var: " << name << " in scope " << inner_scope;
var = inner_scope->Var(name);
}
value_2_var_name->emplace(value, name);
variable_2_var_name->emplace(var, name);
auto id = var_name_2_id->size();
var_name_2_id->emplace(name, id);
variable_list->push_back(var);
PADDLE_ENFORCE_EQ(
variable_list->size(),
var_name_2_id->size(),
paddle::platform::errors::InvalidArgument(
"The size of variable_list and var_name_2_id map should be equal"));
return var;
}
void CheckInputVars(
ir::Operation* op,
const std::string& op_name,
const std::unordered_map<ir::Value, std::string>& name_map) {
const std::unordered_map<ir::Value, std::string>& value_2_var_name) {
size_t input_num = op->num_operands();
if (input_num > 0) {
for (size_t i = 0; i < input_num; ++i) {
auto value = op->operand(i);
if (value) {
PADDLE_ENFORCE_NE(
name_map.find(value),
name_map.end(),
value_2_var_name.find(value),
value_2_var_name.end(),
phi::errors::PreconditionNotMet(
"input should in name map, [%d] 'th input of [%s] op",
i,
......@@ -89,20 +108,25 @@ void CheckInputVars(
}
void BuildValue(ir::Value value,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
VariableNameMap* variable_name_map,
int& count) { // NOLINT
auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
std::string name;
if (name_map->find(value) != name_map->end()) {
name = name_map->at(value);
paddle::framework::Scope* inner_scope,
std::unordered_map<ir::Value, std::string>* value_2_var_name,
std::unordered_map<const paddle::framework::Variable*,
std::string>* variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list) {
paddle::framework::Variable* var = nullptr;
if (value_2_var_name->find(value) != value_2_var_name->end()) {
var = inner_scope->FindVar(value_2_var_name->at(value));
} else {
name = "inner_var_" + std::to_string(count++);
name_map->emplace(value, name);
var = CreateVar(value,
inner_scope,
false,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
}
auto var = CreateVar(value, name, scope, inner_local_scope);
// Only support DenseTensor or Vector<DenseTensor>
if (!value.type()) {
var->GetMutable<phi::DenseTensor>();
......@@ -120,11 +144,15 @@ void BuildValue(ir::Value value,
paddle::platform::errors::Fatal(
"Element of VectorType output only support "
"DenseTensorType"));
std::string name_i = "inner_var_" + std::to_string(count++);
auto var_i = CreateVar(value, name_i, scope, inner_local_scope);
auto var_i = CreateVar(value,
inner_scope,
false,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
var_i->GetMutable<phi::DenseTensor>();
tensor_array->emplace_back(var_i);
variable_name_map->emplace(var_i, name_i);
}
} else {
PADDLE_THROW(phi::errors::PreconditionNotMet(
......@@ -132,24 +160,25 @@ void BuildValue(ir::Value value,
}
}
void HandleForSpecialOp(ir::Operation* op,
const VariableNameMap& variable_name_map,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
int& count) { // NOLINT
void HandleForSpecialOp(
ir::Operation* op,
paddle::framework::Scope* inner_scope,
std::unordered_map<ir::Value, std::string>* value_2_var_name,
std::unordered_map<const paddle::framework::Variable*, std::string>*
variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list) {
std::string op_name = op->name();
if (op->attributes().count("op_name")) {
op_name =
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
}
size_t input_num = op->num_operands();
if (op_name == "pd.fetch") {
// fetch is a very special op, with no output
VLOG(6) << "Handle for pd.fetch:";
auto var = scope->Var("fetch");
VLOG(6) << "Create var: fetch in scope " << scope;
auto var = const_cast<paddle::framework::Scope*>(inner_scope->root())
->Var("fetch");
VLOG(6) << "Create var: fetch in scope " << inner_scope->root();
auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
int index =
op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
......@@ -157,16 +186,20 @@ void HandleForSpecialOp(ir::Operation* op,
}
if (op_name == "pd.feed") {
VLOG(6) << "Handle for pd.feed:";
auto value = op->result(0);
std::string name = "inner_var_" + std::to_string(count++);
name_map->emplace(value, name);
auto var = CreateVar(value, name, scope, local_scope);
auto var = CreateVar(value,
inner_scope,
false,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
// TODO(phlrain): need to update here, support StringTensor
auto out_tensor = var->GetMutable<phi::DenseTensor>();
auto feed_var = scope->Var("feed");
VLOG(6) << "Create var: feed in scope " << scope;
auto feed_var =
const_cast<paddle::framework::Scope*>(inner_scope->root())->Var("feed");
VLOG(6) << "Create var: feed in scope " << inner_scope->root();
int index =
op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
auto feed_list = feed_var->Get<paddle::framework::FeedList>();
......@@ -176,30 +209,33 @@ void HandleForSpecialOp(ir::Operation* op,
}
if (op_name == "builtin.combine") {
VLOG(6) << "Handle for builtin.combine:";
auto out_value = op->result(0);
std::string name;
if (name_map->find(out_value) != name_map->end()) {
name = name_map->at(out_value);
paddle::framework::Variable* var = nullptr;
if (value_2_var_name->find(out_value) != value_2_var_name->end()) {
var = inner_scope->FindVar(value_2_var_name->at(out_value));
} else {
name = "inner_var_" + std::to_string(count++);
name_map->emplace(out_value, name);
var = CreateVar(out_value,
inner_scope,
false,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
}
auto var = CreateVar(out_value, name, scope, local_scope);
auto tensor_array = var->GetMutable<paddle::framework::VariableRefArray>();
// clear tensor array
tensor_array->clear();
size_t input_num = op->num_operands();
for (size_t i = 0; i < input_num; ++i) {
auto value = op->operand(i);
PADDLE_ENFORCE_EQ(
name_map->count(value),
value_2_var_name->count(value),
true,
phi::errors::PreconditionNotMet("can not found input of combine op"));
tensor_array->emplace_back(
CreateVar(value, name_map->at(value), scope, local_scope));
inner_scope->FindVar(value_2_var_name->at(value)));
}
}
......@@ -210,14 +246,15 @@ void HandleForSpecialOp(ir::Operation* op,
.dyn_cast<ir::StrAttribute>()
.data();
auto in_ptr = op->operand(0);
auto value = op->operand(0);
// change opreand name to param_name
auto orig_name = value_2_var_name->at(value);
auto orig_name = name_map->at(in_ptr);
if (scope->FindVar(param_name) == nullptr) {
scope->Rename(orig_name, param_name);
if (inner_scope->root()->FindVar(param_name) == nullptr) {
const_cast<paddle::framework::Scope*>(inner_scope->root())
->Rename(orig_name, param_name);
}
(*name_map)[in_ptr] = param_name;
(*value_2_var_name)[value] = param_name;
}
if (op_name == "builtin.get_parameter") {
......@@ -226,44 +263,44 @@ void HandleForSpecialOp(ir::Operation* op,
.at("parameter_name")
.dyn_cast<ir::StrAttribute>()
.data();
auto out_ptr = op->result(0);
name_map->emplace(out_ptr, param_name);
auto value = op->result(0);
value_2_var_name->emplace(value, param_name);
}
if (op_name == "builtin.slice") {
VLOG(6) << "Handle for builtin.slice";
auto out_value = op->result(0);
auto in_value = op->operand(0);
PADDLE_ENFORCE_EQ(name_map->count(in_value),
PADDLE_ENFORCE_EQ(value_2_var_name->count(in_value),
true,
phi::errors::PreconditionNotMet(
"input of buildin slice not in name map"));
int index =
op->attributes().at("index").dyn_cast<ir::Int32Attribute>().data();
auto in_var = scope->FindVar(name_map->at(in_value));
auto in_var = inner_scope->FindVar(value_2_var_name->at(in_value));
auto variable_array = in_var->Get<paddle::framework::VariableRefArray>();
PADDLE_ENFORCE_EQ(
variable_name_map.count(variable_array[index]),
variable_2_var_name->count(variable_array[index]),
true,
phi::errors::PreconditionNotMet("[%d] the variable in build slice "
"input MUST in variable name map",
index));
std::string var_name = variable_name_map.at(variable_array[index]);
name_map->emplace(out_value, var_name);
std::string var_name = variable_2_var_name->at(variable_array[index]);
value_2_var_name->emplace(out_value, var_name);
}
}
void HandleForInplaceOp(ir::Operation* op,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
int& count) { // NOLINT
void HandleForInplaceOp(
ir::Operation* op,
paddle::framework::Scope* inner_scope,
std::unordered_map<ir::Value, std::string>* value_2_var_name,
std::unordered_map<const paddle::framework::Variable*, std::string>*
variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list) {
if (op->num_results() < 1) return;
ir::IrContext* ctx = ir::IrContext::Instance();
std::string op_name = op->name();
......@@ -271,12 +308,12 @@ void HandleForInplaceOp(ir::Operation* op,
op_name =
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
}
VLOG(4) << "HandleForInplaceOp: " << op_name;
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
paddle::dialect::OpYamlInfoParser yaml_parser(
op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()
->get_op_info_());
VariableNameMap variable_name_map;
for (size_t i = 0; i < op->num_results(); ++i) {
ir::Value value = op->result(i);
std::string value_name = yaml_parser.OutputNames()[i];
......@@ -284,35 +321,36 @@ void HandleForInplaceOp(ir::Operation* op,
std::string inplace_name = yaml_parser.InplaceName(value_name);
ir::Value inplace_value =
op->operand(yaml_parser.InputName2Id().at(inplace_name));
std::string var_name = name_map->at(inplace_value);
std::string var_name = value_2_var_name->at(inplace_value);
VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
<< " (var: " << var_name << ")";
name_map->emplace(value, var_name);
value_2_var_name->emplace(value, var_name);
} else {
BuildValue(
value, scope, local_scope, name_map, &variable_name_map, count);
BuildValue(value,
inner_scope,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
}
}
}
// NOTE(zhiqiu): the persistable is created in inner_scope's root, and other is
// created in inner_scope.
void BuildScope(const ir::Block& block,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map) {
VLOG(4) << "***** [before build] scope: ******\n"
paddle::framework::Scope* inner_scope,
std::unordered_map<ir::Value, std::string>* value_2_var_name,
std::unordered_map<const paddle::framework::Variable*,
std::string>* variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list) {
VLOG(4) << "***** [before build] scope"
<< "(" << inner_scope << ") ******\n"
<< paddle::framework::GenScopeTreeDebugInfo(
const_cast<paddle::framework::Scope*>(scope->root()));
// NOTE(zhiqiu): if use local_scope (local_scope != nullptr), the persistable
// is created in scope , and other is created in local_scope.
auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
VLOG(6) << "Build: scope [" << scope << "] inner_local_scope ["
<< inner_local_scope << "]";
std::unordered_map<const paddle::framework::Variable*, std::string>
variable_name_map;
// int count = name_map->size();
int count = name_map->size();
const_cast<paddle::framework::Scope*>(inner_scope->root()));
// int count = value_2_var_name->size();
for (auto it = block.begin(); it != block.end(); ++it) {
ir::Operation* op = *it;
......@@ -321,19 +359,21 @@ void BuildScope(const ir::Block& block,
op_name =
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
}
VLOG(4) << "BuildScope for :" << op_name;
VLOG(4) << "build op:" << op_name;
if (op_name == "pd.feed" || op_name == "pd.fetch" ||
op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter" || op_name == "builtin.slice") {
VLOG(6) << "HandleForSpecialOp: " << op_name;
HandleForSpecialOp(
op, variable_name_map, scope, inner_local_scope, name_map, count);
HandleForSpecialOp(op,
inner_scope,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
continue;
}
CheckInputVars(op, op_name, *name_map);
CheckInputVars(op, op_name, *value_2_var_name);
if (op->num_results() < 1) continue;
if (op->attributes().count("is_inplace") != 0 &&
......@@ -341,23 +381,29 @@ void BuildScope(const ir::Block& block,
.at("is_inplace")
.dyn_cast<ir::BoolAttribute>()
.data()) {
HandleForInplaceOp(op, scope, inner_local_scope, name_map, count);
HandleForInplaceOp(op,
inner_scope,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
continue;
} else {
for (size_t i = 0; i < op->num_results(); ++i) {
BuildValue(op->result(i),
scope,
local_scope,
name_map,
&variable_name_map,
count);
inner_scope,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
}
}
}
VLOG(4) << "***** [after build] scope: ******\n"
VLOG(4) << "***** [after build] scope"
<< "(" << inner_scope << ") ******\n"
<< paddle::framework::GenScopeTreeDebugInfo(
const_cast<paddle::framework::Scope*>(scope->root()));
const_cast<paddle::framework::Scope*>(inner_scope->root()));
}
} // namespace ir
......@@ -41,36 +41,13 @@
#include "glog/logging.h"
namespace ir {
paddle::framework::Variable* CreateVar(ir::Value value,
const std::string& name,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope);
void BuildValue(ir::Value value,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
int& count); // NOLINT
void HandleForSpecialOp(ir::Operation* op,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
int& count); // NOLINT
void HandleForInplaceOp(ir::Operation* op,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
int& count); // NOLINT
void CheckInputVars(ir::Operation* op,
const std::unordered_map<ir::Value, std::string>& name_map);
void BuildScope(const ir::Block& block,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map);
paddle::framework::Scope* inner_scope,
std::unordered_map<ir::Value, std::string>* value_2_var_name,
std::unordered_map<const paddle::framework::Variable*,
std::string>* variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list);
template <typename Context,
typename InType,
......@@ -322,6 +299,7 @@ void BuildPhiContext(
for (size_t i = 0; i < op->num_results(); ++i) {
ir::Value out_ptr = op->result(i);
auto name = name_map.at(out_ptr);
VLOG(6) << "ctx->EmplaceBackOutput: " << name;
auto out_type = out_ptr.type();
if (!out_type) {
phi::DenseTensor* ptr = nullptr;
......@@ -329,14 +307,14 @@ void BuildPhiContext(
ctx->EmplaceBackOutput(out_ptr);
} else if (out_type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
&(inner_scope->Var(name)->Get<phi::DenseTensor>()))));
&(inner_scope->FindVar(name)->Get<phi::DenseTensor>()))));
} else if (out_type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::SelectedRows*>(
&(inner_scope->Var(name)->Get<phi::SelectedRows>()))));
&(inner_scope->FindVar(name)->Get<phi::SelectedRows>()))));
} else if (out_type.isa<ir::VectorType>()) {
OutListType outputs;
auto& variable_array =
scope->Var(name)->Get<paddle::framework::VariableRefArray>();
scope->FindVar(name)->Get<paddle::framework::VariableRefArray>();
for (size_t i = 0; i < variable_array.size(); ++i) {
outputs.emplace_back(OutType(const_cast<phi::DenseTensor*>(
&(variable_array[i]->Get<phi::DenseTensor>()))));
......@@ -360,6 +338,7 @@ void BuildPhiContext(
}
}
}
VLOG(6) << "Done build phi context";
}
} // namespace ir
......@@ -89,5 +89,42 @@ TEST(StandaloneExecutor, run) {
EXPECT_EQ(res3, true);
}
// Verifies that an inplace op (sqrt_) run through the new-IR interpreter
// reuses its input's variable: full([2,2], 4.0) followed by inplace sqrt
// should leave 2.0 in the ORIGINAL output variable ("inner_var_0"), and the
// scope should contain exactly one inner (non-persistable) variable —
// i.e. no extra variable was created for the inplace result.
TEST(StandaloneExecutor, run_inplace_sqrt) {
ir::IrContext* ctx = ir::IrContext::Instance();
ir::Program program((ctx));
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Builder builder = ir::Builder(ctx, program.block());
// Build: x = full({2,2}, 4.0, float32, cpu); sqrt_(x)  (trailing underscore
// marks the inplace variant of sqrt in the pd dialect).
paddle::dialect::FullOp full = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 4.0, phi::DataType::FLOAT32, phi::CPUPlace());
builder.Build<paddle::dialect::Sqrt_Op>(full->result(0));
// Lower the pd-dialect program to the kernel dialect before execution.
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
auto place = platform::CPUPlace();
Scope scope;
InterpreterCore test_core(place, std::move(kernel_program), &scope);
test_core.BetaRun({});
// "inner_var_0" is the auto-generated name of the first non-persistable
// variable created by BuildScope; look it up in the local scope when the
// interpreter uses one, otherwise in the outer scope.
auto out_tensor = test_core.local_scope() == nullptr
? scope.FindVar("inner_var_0")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar("inner_var_0")
->Get<phi::DenseTensor>();
// sqrt(4.0) == 2.0 for every element; simple_cmp presumably does an
// approximate float comparison (defined elsewhere in this file).
bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
bool res2 = simple_cmp(out_tensor.data<float>()[2], 2.0);
bool res3 = simple_cmp(out_tensor.data<float>()[3], 2.0);
// One child (local) scope, holding exactly one variable: the inplace op
// must NOT have allocated a second inner variable for its output.
EXPECT_EQ(scope.kids().size(), 1u);
EXPECT_EQ(scope.kids().front()->Size(), 1u);
EXPECT_EQ(res0, true);
EXPECT_EQ(res1, true);
EXPECT_EQ(res2, true);
EXPECT_EQ(res3, true);
}
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册