Unverified commit 97af8516 authored by zhangkaihuo, committed by GitHub

gradient add support SparseCooTensor (#43352)

Parent 5752643b
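In brief, this change teaches the eager-mode gradient accumulation paths (CopyOrAddTensor and GradTensorHolder::add in the diffs below) to handle gradients that arrive as a SparseCooTensor: the non-zero value buffers of the incoming and the buffered sparse tensors are added together, either in place or through the dygraph add function when a higher-order graph is requested. A minimal sketch of the user-visible behavior this enables, using the same paddle.incubate.sparse API as the new unit test at the end of this diff (eager mode assumed; not part of the commit):

import paddle
from paddle.incubate.sparse import nn
from paddle.fluid.framework import _test_eager_guard

# Accumulate the gradient of a SparseCooTensor by running backward twice
# over the same graph; the second pass exercises the new sparse add path.
with _test_eager_guard():  # sparse kernels require eager mode at this point
    x = paddle.randn((3, 3))
    sp_x = x.to_sparse_coo(sparse_dim=2)
    sp_x.stop_gradient = False
    out = nn.functional.relu(sp_x)
    loss = out.values().mean()
    loss.backward(retain_graph=True)
    loss.backward()  # adds onto the existing sparse grad instead of failing
    print(sp_x.grad.to_dense())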
@@ -22,6 +22,7 @@
#include "paddle/fluid/platform/errors.h"
#include "paddle/phi/api/all.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
namespace egr {
@@ -49,6 +50,22 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
        paddle::imperative::SelectedRowsAddTensor(*tensor, t, &new_buffer);
        tensor->set_impl(new_buffer.impl());
      }
    } else if (LIKELY(t.is_sparse_coo_tensor())) {
      // In fact, the gradient of SparseTensor is still a SparseTensor
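      // Note that only the non-zero value buffers are added below; this
      // assumes the two sparse tensors share the same indices (both come
      // from the same forward graph), which is not checked here.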
      if (LIKELY(tensor->is_sparse_coo_tensor())) {
        auto t_sparse =
            std::dynamic_pointer_cast<phi::SparseCooTensor>(t.impl());
        paddle::experimental::Tensor t_values(
            std::make_shared<phi::DenseTensor>(
                t_sparse->non_zero_elements()));
        auto tensor_sparse =
            std::dynamic_pointer_cast<phi::SparseCooTensor>(tensor->impl());
        paddle::experimental::Tensor tensor_values(
            std::make_shared<phi::DenseTensor>(
                tensor_sparse->non_zero_elements()));
        paddle::imperative::TensorAdd<paddle::experimental::Tensor>(
            t_values, &tensor_values);
      }
    } else {
      // TODO(jiabin): Support Other TensorBase later
      // TODO(zhanlve): Replace SelectedRowsAddTensor with
...

@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace egr {
@@ -130,6 +131,25 @@ void GradTensorHolder::add(size_t slot_id, size_t rank,
                                                  &new_buffer);
        buffer_tensor.set_impl(new_buffer.impl());
      }
    } else if (t.is_sparse_coo_tensor()) {
      auto t_sparse = std::dynamic_pointer_cast<phi::SparseCooTensor>(t.impl());
      paddle::experimental::Tensor t_values(
          std::make_shared<phi::DenseTensor>(t_sparse->non_zero_elements()));
      // In fact, the gradient of SparseTensor is still a SparseTensor
      if (buffer_tensor.is_sparse_coo_tensor()) {
        auto buffer_sparse = std::dynamic_pointer_cast<phi::SparseCooTensor>(
            buffer_tensor.impl());
        paddle::experimental::Tensor buffer_values(
            std::make_shared<phi::DenseTensor>(
                buffer_sparse->non_zero_elements()));
        if (create_graph) {
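          // With create_graph, route through the dygraph add function so the
          // addition itself is recorded in the autograd graph and
          // higher-order gradients can flow through the accumulation.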
          buffer_values =
              add_final_state_dygraph_function(t_values, buffer_values);
        } else {
          paddle::imperative::TensorAdd<paddle::experimental::Tensor>(
              t_values, &buffer_values);
        }
      }
    } else {
      // TODO(jiabin): Support Other TensorBase later
      // TODO(zhanlve): Replace SelectedRowsAddTensor with add_dygraph_function
...

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.incubate import sparse
from paddle.incubate.sparse import nn
from paddle.fluid.framework import _test_eager_guard
class TestGradientAdd(unittest.TestCase):

    def sparse(self, sp_x):
        identity = sp_x
        out = nn.functional.relu(sp_x)
        values = out.values() + identity.values()
        out = sparse.sparse_coo_tensor(out.indices(),
                                       values,
                                       shape=out.shape,
                                       stop_gradient=out.stop_gradient)
        return out

    def dense(self, x):
        identity = x
        out = paddle.nn.functional.relu(x)
        out = out + identity
        return out
    def test(self):
        with _test_eager_guard():
            x = paddle.randn((3, 3))
            sparse_x = x.to_sparse_coo(sparse_dim=2)
            x.stop_gradient = False
            sparse_x.stop_gradient = False

            dense_out = self.dense(x)
            loss = dense_out.mean()
            loss.backward(retain_graph=True)

            sparse_out = self.sparse(sparse_x)
            sparse_loss = sparse_out.values().mean()
            sparse_loss.backward(retain_graph=True)

            assert np.allclose(dense_out.numpy(),
                               sparse_out.to_dense().numpy())
            assert np.allclose(loss.numpy(), sparse_loss.numpy())
            assert np.allclose(x.grad.numpy(),
                               sparse_x.grad.to_dense().numpy())
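            # Run backward a second time: the new gradients are accumulated
            # onto the existing grads, exercising the sparse add paths above.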
            loss.backward()
            sparse_loss.backward()
            assert np.allclose(x.grad.numpy(),
                               sparse_x.grad.to_dense().numpy())


if __name__ == "__main__":
    unittest.main()