未验证 提交 cfd6a8fc 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] fix the new variable name in DataTransfer (#37756)

上级 00dfebe8
...@@ -137,7 +137,7 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name, ...@@ -137,7 +137,7 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
// 1. Generate new_var_name and Initialize it // 1. Generate new_var_name and Initialize it
*new_var_name = *new_var_name =
var_name + "_layout_" + std::to_string(var_scope->VarSize() + 1); var_name + "_layout_" + std::to_string(var_scope->VarSize() + 1);
auto* ptr = local_scope->Var(new_var_name); auto* ptr = local_scope->Var(*new_var_name);
auto var_type = var_scope->Var(var_name)->Type(); auto var_type = var_scope->Var(var_name)->Type();
InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type)); InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
...@@ -171,8 +171,8 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name, ...@@ -171,8 +171,8 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
// 1. Generate new_var_name and Initialize it // 1. Generate new_var_name and Initialize it
*new_var_name = *new_var_name =
var_name + "_dtype_" + std::to_string(var_scope->VarSize() + 1); var_name + "_dtype_" + std::to_string(var_scope->VarSize() + 1);
auto* ptr = local_scope->Var(new_var_name); auto* ptr = local_scope->Var(*new_var_name);
var_scope->SetVarDesc(var_name, nullptr);
auto var_type = var_scope->Var(var_name)->Type(); auto var_type = var_scope->Var(var_name)->Type();
InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type)); InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
...@@ -211,7 +211,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name, ...@@ -211,7 +211,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
// 1. Generate new_var_name and Initialize it // 1. Generate new_var_name and Initialize it
*new_var_name = *new_var_name =
var_name + "_device_" + std::to_string(var_scope->VarSize() + 1); var_name + "_device_" + std::to_string(var_scope->VarSize() + 1);
auto* ptr = local_scope->Var(new_var_name); auto* ptr = local_scope->Var(*new_var_name);
auto var_type = var_scope->Var(var_name)->Type(); auto var_type = var_scope->Var(var_name)->Type();
InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type)); InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册