提交 e7a9f6bb 编写于 作者: A Adam 提交者: Tao Luo

[Bugfix] Preserve shape in inpalce operators (#22360)

上级 89b54979
...@@ -1086,7 +1086,9 @@ void OperatorWithKernel::TransferInplaceVarsBack( ...@@ -1086,7 +1086,9 @@ void OperatorWithKernel::TransferInplaceVarsBack(
PADDLE_ENFORCE_NOT_NULL(var, "The var[%s] should not be nullptr.", PADDLE_ENFORCE_NOT_NULL(var, "The var[%s] should not be nullptr.",
var_name); var_name);
auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var); auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto original_dims = original_tensor->dims();
original_tensor->ShareDataWith(*transformed_tensor); original_tensor->ShareDataWith(*transformed_tensor);
original_tensor->Resize(original_dims);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册