未验证 提交 8e7aa296 编写于 作者: C csy0225 提交者: GitHub

Fix inplace op dims not changed (#52416)

上级 273783b3
...@@ -2367,13 +2367,7 @@ void OperatorWithKernel::TransferInplaceVarsBack( ...@@ -2367,13 +2367,7 @@ void OperatorWithKernel::TransferInplaceVarsBack(
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The variable[%s] is nullptr.", var_name)); "The variable[%s] is nullptr.", var_name));
auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var); auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto original_dims = original_tensor->dims();
original_tensor->ShareDataWith(*transformed_tensor); original_tensor->ShareDataWith(*transformed_tensor);
// In order to solve the problem that the output latitude of NPU reshape
// operator is not changed when inplace.
if (type_ != "reshape2" && type_ != "reshape2_grad") {
original_tensor->Resize(original_dims);
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册