未验证 提交 7ce0f9e1 编写于 作者: Z zhangbo9674 提交者: GitHub

fix bug (#55837)

上级 3f630658
......@@ -206,18 +206,20 @@ PhiKernelInstruction::PhiKernelInstruction(
yaml_interface->get_op_info_());
VLOG(6) << "finish process yaml_info_parser";
::ir::BuildPhiContext<
phi::InferMetaContext,
phi::MetaTensor,
phi::MetaTensor,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
false>(op,
value_2_var_name,
scope,
local_scope,
yaml_info_parser,
&infer_meta_context_);
if (infer_meta_interface_) {
::ir::BuildPhiContext<
phi::InferMetaContext,
phi::MetaTensor,
phi::MetaTensor,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
false>(op,
value_2_var_name,
scope,
local_scope,
yaml_info_parser,
&infer_meta_context_);
}
VLOG(6) << "finish process infer meta context";
auto kernel_name =
......@@ -343,7 +345,9 @@ void PhiKernelInstruction::InitInputsOutputsIds(
}
void PhiKernelInstruction::Run() {
infer_meta_interface_->infer_meta_(&(infer_meta_context_));
if (infer_meta_interface_) {
infer_meta_interface_->infer_meta_(&(infer_meta_context_));
}
VLOG(6) << "Run op " << phi_op_name_ << " infer meta.";
(*(phi_kernel_))(&(kernel_context_));
VLOG(6) << "Run op " << phi_op_name_ << " kernel.";
......
......@@ -46,6 +46,49 @@
namespace ir {
// Registers a freshly created variable in all four bookkeeping structures:
// value->name, variable->name, name->id, and the flat variable list.
// Ids are handed out in insertion order (current size of the id map), and a
// size check enforces that the list and the id map never drift apart.
void AddNewData(ir::Value value,
                std::string name,
                paddle::framework::Variable* var,
                std::unordered_map<ir::Value, std::string>* value_2_var_name,
                std::unordered_map<const paddle::framework::Variable*,
                                   std::string>* variable_2_var_name,
                std::map<std::string, int>* var_name_2_id,
                std::vector<paddle::framework::Variable*>* variable_list) {
  // The next id equals the number of names registered so far.
  const auto next_id = var_name_2_id->size();
  var_name_2_id->emplace(name, next_id);
  value_2_var_name->emplace(value, name);
  variable_2_var_name->emplace(var, name);
  variable_list->push_back(var);
  // If `name` was already present the emplace above was a no-op while the
  // list still grew; this check catches that inconsistency immediately.
  PADDLE_ENFORCE_EQ(
      variable_list->size(),
      var_name_2_id->size(),
      paddle::platform::errors::InvalidArgument(
          "The size of variable_list and var_name_2_id map should be equal"));
}
// Rebinds `value`, previously registered under `orig_name`, to `new_name`
// across the bookkeeping maps while preserving its numeric id.
// Fixes two issues in the previous version: a self-rename
// (new_name == orig_name) no longer erases the id entry, and the name->id
// move is a single O(log n) lookup instead of a full-map scan that copied
// every pair and inserted into the map being iterated.
void RenameData(ir::Value value,
                std::string new_name,
                std::string orig_name,
                std::unordered_map<ir::Value, std::string>* value_2_var_name,
                std::unordered_map<const paddle::framework::Variable*,
                                   std::string>* variable_2_var_name,
                std::map<std::string, int>* var_name_2_id) {
  // Renaming a name to itself must be a no-op; without this guard the
  // erase below would drop the entry from var_name_2_id.
  if (new_name == orig_name) {
    return;
  }
  (*value_2_var_name)[value] = new_name;
  // Update every variable currently labelled `orig_name` in place
  // (mutating mapped values does not invalidate unordered_map iteration).
  for (auto& kv : *variable_2_var_name) {
    if (kv.second == orig_name) {
      kv.second = new_name;
    }
  }
  // Move the id from orig_name to new_name with one lookup.
  auto it = var_name_2_id->find(orig_name);
  if (it != var_name_2_id->end()) {
    var_name_2_id->emplace(new_name, it->second);
    var_name_2_id->erase(it);
  }
}
using VariableNameMap =
std::unordered_map<const paddle::framework::Variable*, std::string>;
......@@ -80,16 +123,13 @@ paddle::framework::Variable* CreateVar(
VLOG(6) << "Create var: " << name << " in scope " << inner_scope;
var = inner_scope->Var(name);
}
value_2_var_name->emplace(value, name);
variable_2_var_name->emplace(var, name);
auto id = var_name_2_id->size();
var_name_2_id->emplace(name, id);
variable_list->push_back(var);
PADDLE_ENFORCE_EQ(
variable_list->size(),
var_name_2_id->size(),
paddle::platform::errors::InvalidArgument(
"The size of variable_list and var_name_2_id map should be equal"));
AddNewData(value,
name,
var,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
return var;
}
......@@ -207,16 +247,13 @@ void HandleForSpecialOp(
paddle::platform::errors::InvalidArgument(
"The variable %s shoud exist", name));
value_2_var_name->emplace(value, name);
variable_2_var_name->emplace(var, name);
auto id = var_name_2_id->size();
var_name_2_id->emplace(name, id);
variable_list->push_back(var);
PADDLE_ENFORCE_EQ(
variable_list->size(),
var_name_2_id->size(),
paddle::platform::errors::InvalidArgument(
"The size of variable_list and var_name_2_id map should be equal"));
AddNewData(value,
name,
var,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
}
if (op_name == "pd.feed_with_place") {
......@@ -225,22 +262,18 @@ void HandleForSpecialOp(
op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
auto value = op->result(0);
value_2_var_name->emplace(value, var_name);
paddle::framework::Variable* var = inner_scope->FindVar(var_name);
PADDLE_ENFORCE(var,
paddle::platform::errors::InvalidArgument(
"The variable %s shoud exist", var_name));
variable_2_var_name->emplace(var, var_name);
auto id = var_name_2_id->size();
var_name_2_id->emplace(var_name, id);
variable_list->push_back(var);
PADDLE_ENFORCE_EQ(
variable_list->size(),
var_name_2_id->size(),
paddle::platform::errors::InvalidArgument(
"The size of variable_list and var_name_2_id map should be equal"));
AddNewData(value,
var_name,
var,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
}
if (op_name == "builtin.combine") {
......@@ -290,7 +323,12 @@ void HandleForSpecialOp(
const_cast<paddle::framework::Scope*>(inner_scope->root())
->Rename(orig_name, param_name);
}
(*value_2_var_name)[value] = param_name;
RenameData(value,
param_name,
orig_name,
value_2_var_name,
variable_2_var_name,
var_name_2_id);
}
if (op_name == "pd.shadow_output") {
......@@ -306,7 +344,12 @@ void HandleForSpecialOp(
const_cast<paddle::framework::Scope*>(inner_scope->root())
->Rename(orig_name, var_name);
}
(*value_2_var_name)[value] = var_name;
RenameData(value,
var_name,
orig_name,
value_2_var_name,
variable_2_var_name,
var_name_2_id);
}
if (op_name == "builtin.get_parameter") {
......@@ -316,7 +359,14 @@ void HandleForSpecialOp(
.dyn_cast<ir::StrAttribute>()
.AsString();
auto value = op->result(0);
value_2_var_name->emplace(value, param_name);
paddle::framework::Variable* var = inner_scope->FindVar(param_name);
AddNewData(value,
param_name,
var,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
}
if (op_name == "builtin.slice") {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册