未验证 提交 393db4a9 编写于 作者: R ronnywang 提交者: GitHub

[CustomDeivce] fix grad accumulation (#56052)

上级 6bd7f860
...@@ -46,7 +46,7 @@ static void CopyOrAddTensor(paddle::Tensor* tensor, ...@@ -46,7 +46,7 @@ static void CopyOrAddTensor(paddle::Tensor* tensor,
if (LIKELY(t.is_dense_tensor())) { if (LIKELY(t.is_dense_tensor())) {
if (LIKELY(tensor->is_dense_tensor())) { if (LIKELY(tensor->is_dense_tensor())) {
if (t.is_custom_device()) { if (t.is_custom_device()) {
*tensor = add_ad_func(t, *tensor); add__ad_func(*tensor, t);
} else { } else {
paddle::imperative::TensorAdd<paddle::Tensor>(t, tensor); paddle::imperative::TensorAdd<paddle::Tensor>(t, tensor);
} }
...@@ -71,7 +71,7 @@ static void CopyOrAddTensor(paddle::Tensor* tensor, ...@@ -71,7 +71,7 @@ static void CopyOrAddTensor(paddle::Tensor* tensor,
paddle::Tensor tensor_values(std::make_shared<phi::DenseTensor>( paddle::Tensor tensor_values(std::make_shared<phi::DenseTensor>(
tensor_sparse->non_zero_elements())); tensor_sparse->non_zero_elements()));
if (t.is_custom_device()) { if (t.is_custom_device()) {
tensor_values = add_ad_func(t_values, tensor_values); add__ad_func(tensor_values, t_values);
} else { } else {
paddle::imperative::TensorAdd<paddle::Tensor>(t_values, paddle::imperative::TensorAdd<paddle::Tensor>(t_values,
&tensor_values); &tensor_values);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册