未验证 提交 7ce0f9e1 编写于 作者: Z zhangbo9674 提交者: GitHub

fix bug (#55837)

上级 3f630658
...@@ -206,6 +206,7 @@ PhiKernelInstruction::PhiKernelInstruction( ...@@ -206,6 +206,7 @@ PhiKernelInstruction::PhiKernelInstruction(
yaml_interface->get_op_info_()); yaml_interface->get_op_info_());
VLOG(6) << "finish process yaml_info_parser"; VLOG(6) << "finish process yaml_info_parser";
if (infer_meta_interface_) {
::ir::BuildPhiContext< ::ir::BuildPhiContext<
phi::InferMetaContext, phi::InferMetaContext,
phi::MetaTensor, phi::MetaTensor,
...@@ -218,6 +219,7 @@ PhiKernelInstruction::PhiKernelInstruction( ...@@ -218,6 +219,7 @@ PhiKernelInstruction::PhiKernelInstruction(
local_scope, local_scope,
yaml_info_parser, yaml_info_parser,
&infer_meta_context_); &infer_meta_context_);
}
VLOG(6) << "finish process infer meta context"; VLOG(6) << "finish process infer meta context";
auto kernel_name = auto kernel_name =
...@@ -343,7 +345,9 @@ void PhiKernelInstruction::InitInputsOutputsIds( ...@@ -343,7 +345,9 @@ void PhiKernelInstruction::InitInputsOutputsIds(
} }
void PhiKernelInstruction::Run() { void PhiKernelInstruction::Run() {
if (infer_meta_interface_) {
infer_meta_interface_->infer_meta_(&(infer_meta_context_)); infer_meta_interface_->infer_meta_(&(infer_meta_context_));
}
VLOG(6) << "Run op " << phi_op_name_ << " infer meta."; VLOG(6) << "Run op " << phi_op_name_ << " infer meta.";
(*(phi_kernel_))(&(kernel_context_)); (*(phi_kernel_))(&(kernel_context_));
VLOG(6) << "Run op " << phi_op_name_ << " kernel."; VLOG(6) << "Run op " << phi_op_name_ << " kernel.";
......
...@@ -46,6 +46,49 @@ ...@@ -46,6 +46,49 @@
namespace ir { namespace ir {
// Registers a variable under `name` in every bookkeeping structure the
// interpreter keeps: the Value->name map, the Variable*->name map, the
// name->id map, and the flat variable list. The id assigned is the next
// slot index, so `variable_list` and `var_name_2_id` must grow in lock
// step; that invariant is checked at the end.
void AddNewData(ir::Value value,
                std::string name,
                paddle::framework::Variable* var,
                std::unordered_map<ir::Value, std::string>* value_2_var_name,
                std::unordered_map<const paddle::framework::Variable*,
                                   std::string>* variable_2_var_name,
                std::map<std::string, int>* var_name_2_id,
                std::vector<paddle::framework::Variable*>* variable_list) {
  // Record both directions of the name binding first; these two inserts
  // are independent of the id bookkeeping below.
  variable_2_var_name->emplace(var, name);
  value_2_var_name->emplace(value, name);
  // The id is simply the next free slot in the parallel structures.
  const auto next_id = var_name_2_id->size();
  var_name_2_id->emplace(name, next_id);
  variable_list->push_back(var);
  // NOTE(review): if `name` was already present, emplace above is a no-op
  // while push_back still grows the list — this check catches that.
  PADDLE_ENFORCE_EQ(
      variable_list->size(),
      var_name_2_id->size(),
      paddle::platform::errors::InvalidArgument(
          "The size of variable_list and var_name_2_id map should be equal"));
}
// Rebinds `value` (previously registered as `orig_name`) to `new_name`
// across the interpreter's bookkeeping maps, preserving the variable's
// numeric id in `var_name_2_id`. The underlying Variable* itself is not
// touched; only the name bindings change.
void RenameData(ir::Value value,
                std::string new_name,
                std::string orig_name,
                std::unordered_map<ir::Value, std::string>* value_2_var_name,
                std::unordered_map<const paddle::framework::Variable*,
                                   std::string>* variable_2_var_name,
                std::map<std::string, int>* var_name_2_id) {
  // Renaming to the identical name must be a no-op: without this guard the
  // emplace below fails (key already present) and the erase then silently
  // drops the id entry for this variable.
  if (new_name == orig_name) {
    return;
  }
  (*value_2_var_name)[value] = new_name;
  // Several values may alias one name; update every Variable* still bound
  // to the old name. Iterate by reference — the original iterated by value
  // and copied a (pointer, string) pair per element.
  for (auto& kv : *variable_2_var_name) {
    if (kv.second == orig_name) {
      kv.second = new_name;
    }
  }
  // Move the id over under the new key. std::map supports direct lookup,
  // so there is no need to linearly scan the whole map for one key.
  auto it = var_name_2_id->find(orig_name);
  if (it != var_name_2_id->end()) {
    var_name_2_id->emplace(new_name, it->second);
    var_name_2_id->erase(it);
  }
}
using VariableNameMap = using VariableNameMap =
std::unordered_map<const paddle::framework::Variable*, std::string>; std::unordered_map<const paddle::framework::Variable*, std::string>;
...@@ -80,16 +123,13 @@ paddle::framework::Variable* CreateVar( ...@@ -80,16 +123,13 @@ paddle::framework::Variable* CreateVar(
VLOG(6) << "Create var: " << name << " in scope " << inner_scope; VLOG(6) << "Create var: " << name << " in scope " << inner_scope;
var = inner_scope->Var(name); var = inner_scope->Var(name);
} }
value_2_var_name->emplace(value, name); AddNewData(value,
variable_2_var_name->emplace(var, name); name,
auto id = var_name_2_id->size(); var,
var_name_2_id->emplace(name, id); value_2_var_name,
variable_list->push_back(var); variable_2_var_name,
PADDLE_ENFORCE_EQ( var_name_2_id,
variable_list->size(), variable_list);
var_name_2_id->size(),
paddle::platform::errors::InvalidArgument(
"The size of variable_list and var_name_2_id map should be equal"));
return var; return var;
} }
...@@ -207,16 +247,13 @@ void HandleForSpecialOp( ...@@ -207,16 +247,13 @@ void HandleForSpecialOp(
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"The variable %s shoud exist", name)); "The variable %s shoud exist", name));
value_2_var_name->emplace(value, name); AddNewData(value,
variable_2_var_name->emplace(var, name); name,
auto id = var_name_2_id->size(); var,
var_name_2_id->emplace(name, id); value_2_var_name,
variable_list->push_back(var); variable_2_var_name,
PADDLE_ENFORCE_EQ( var_name_2_id,
variable_list->size(), variable_list);
var_name_2_id->size(),
paddle::platform::errors::InvalidArgument(
"The size of variable_list and var_name_2_id map should be equal"));
} }
if (op_name == "pd.feed_with_place") { if (op_name == "pd.feed_with_place") {
...@@ -225,22 +262,18 @@ void HandleForSpecialOp( ...@@ -225,22 +262,18 @@ void HandleForSpecialOp(
op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString(); op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
auto value = op->result(0); auto value = op->result(0);
value_2_var_name->emplace(value, var_name);
paddle::framework::Variable* var = inner_scope->FindVar(var_name); paddle::framework::Variable* var = inner_scope->FindVar(var_name);
PADDLE_ENFORCE(var, PADDLE_ENFORCE(var,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"The variable %s shoud exist", var_name)); "The variable %s shoud exist", var_name));
variable_2_var_name->emplace(var, var_name); AddNewData(value,
auto id = var_name_2_id->size(); var_name,
var_name_2_id->emplace(var_name, id); var,
variable_list->push_back(var); value_2_var_name,
PADDLE_ENFORCE_EQ( variable_2_var_name,
variable_list->size(), var_name_2_id,
var_name_2_id->size(), variable_list);
paddle::platform::errors::InvalidArgument(
"The size of variable_list and var_name_2_id map should be equal"));
} }
if (op_name == "builtin.combine") { if (op_name == "builtin.combine") {
...@@ -290,7 +323,12 @@ void HandleForSpecialOp( ...@@ -290,7 +323,12 @@ void HandleForSpecialOp(
const_cast<paddle::framework::Scope*>(inner_scope->root()) const_cast<paddle::framework::Scope*>(inner_scope->root())
->Rename(orig_name, param_name); ->Rename(orig_name, param_name);
} }
(*value_2_var_name)[value] = param_name; RenameData(value,
param_name,
orig_name,
value_2_var_name,
variable_2_var_name,
var_name_2_id);
} }
if (op_name == "pd.shadow_output") { if (op_name == "pd.shadow_output") {
...@@ -306,7 +344,12 @@ void HandleForSpecialOp( ...@@ -306,7 +344,12 @@ void HandleForSpecialOp(
const_cast<paddle::framework::Scope*>(inner_scope->root()) const_cast<paddle::framework::Scope*>(inner_scope->root())
->Rename(orig_name, var_name); ->Rename(orig_name, var_name);
} }
(*value_2_var_name)[value] = var_name; RenameData(value,
var_name,
orig_name,
value_2_var_name,
variable_2_var_name,
var_name_2_id);
} }
if (op_name == "builtin.get_parameter") { if (op_name == "builtin.get_parameter") {
...@@ -316,7 +359,14 @@ void HandleForSpecialOp( ...@@ -316,7 +359,14 @@ void HandleForSpecialOp(
.dyn_cast<ir::StrAttribute>() .dyn_cast<ir::StrAttribute>()
.AsString(); .AsString();
auto value = op->result(0); auto value = op->result(0);
value_2_var_name->emplace(value, param_name); paddle::framework::Variable* var = inner_scope->FindVar(param_name);
AddNewData(value,
param_name,
var,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
} }
if (op_name == "builtin.slice") { if (op_name == "builtin.slice") {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册