未验证 提交 393db4a9 编写于 作者: R ronnywang 提交者: GitHub

[CustomDeivce] fix grad accumulation (#56052)

上级 6bd7f860
......@@ -46,7 +46,7 @@ static void CopyOrAddTensor(paddle::Tensor* tensor,
if (LIKELY(t.is_dense_tensor())) {
if (LIKELY(tensor->is_dense_tensor())) {
if (t.is_custom_device()) {
*tensor = add_ad_func(t, *tensor);
add__ad_func(*tensor, t);
} else {
paddle::imperative::TensorAdd<paddle::Tensor>(t, tensor);
}
......@@ -71,7 +71,7 @@ static void CopyOrAddTensor(paddle::Tensor* tensor,
paddle::Tensor tensor_values(std::make_shared<phi::DenseTensor>(
tensor_sparse->non_zero_elements()));
if (t.is_custom_device()) {
tensor_values = add_ad_func(t_values, tensor_values);
add__ad_func(tensor_values, t_values);
} else {
paddle::imperative::TensorAdd<paddle::Tensor>(t_values,
&tensor_values);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册