未验证 提交 ac9c7e4d 编写于 作者: 张春乔 提交者: GitHub

fix var->IsType<phi::DenseTensor>() repeat judgment conditions (#49624)

上级 b0ece266
...@@ -63,7 +63,7 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToNpu(Argument *argument) { ...@@ -63,7 +63,7 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToNpu(Argument *argument) {
var, var,
platform::errors::PreconditionNotMet("The var should not be nullptr")); platform::errors::PreconditionNotMet("The var should not be nullptr"));
if (var->IsType<phi::DenseTensor>() || var->IsType<phi::DenseTensor>()) { if (var->IsType<phi::DenseTensor>()) {
auto *t = var->GetMutable<phi::DenseTensor>(); auto *t = var->GetMutable<phi::DenseTensor>();
platform::CPUPlace cpu_place; platform::CPUPlace cpu_place;
...@@ -139,7 +139,7 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) { ...@@ -139,7 +139,7 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
PADDLE_ENFORCE_NOT_NULL(var, PADDLE_ENFORCE_NOT_NULL(var,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The var should not be nullptr")); "The var should not be nullptr"));
if (var->IsType<phi::DenseTensor>() || var->IsType<phi::DenseTensor>()) { if (var->IsType<phi::DenseTensor>()) {
auto *t = var->GetMutable<phi::DenseTensor>(); auto *t = var->GetMutable<phi::DenseTensor>();
auto var_data_type = var_node->Var()->GetDataType(); auto var_data_type = var_node->Var()->GetDataType();
VLOG(5) << "var_name is " << var_name << ", data type is " VLOG(5) << "var_name is " << var_name << ", data type is "
...@@ -197,7 +197,7 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToCustomDevice( ...@@ -197,7 +197,7 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToCustomDevice(
var, var,
platform::errors::PreconditionNotMet("The var should not be nullptr")); platform::errors::PreconditionNotMet("The var should not be nullptr"));
if (var->IsType<phi::DenseTensor>() || var->IsType<phi::DenseTensor>()) { if (var->IsType<phi::DenseTensor>()) {
auto *t = var->GetMutable<phi::DenseTensor>(); auto *t = var->GetMutable<phi::DenseTensor>();
platform::CPUPlace cpu_place; platform::CPUPlace cpu_place;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册