# test_sparse_softmax_op.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid.framework import _test_eager_guard

import numpy as np
import scipy.sparse as sp
import unittest

# Fix the global NumPy seed so the randomly generated sparse inputs are
# reproducible from run to run.
np.random.seed(seed=2022)


def _csr_softmax_ref(np_csr):
    """Return the expected row-wise softmax of a scipy CSR matrix's values.

    Softmax is taken over the explicitly stored entries of each row only;
    rows with no stored entries are skipped. The result is a 1-D float array
    aligned element-for-element with ``np_csr.data``.
    """
    rows = []
    for start, end in zip(np_csr.indptr[:-1], np_csr.indptr[1:]):
        if start == end:
            continue
        x = np_csr.data[start:end]
        # Subtract the row max before exponentiating for numerical stability.
        x_exp = np.exp(x - np.max(x))
        rows.append(x_exp / np.sum(x_exp))
    # Concatenate once instead of growing an array inside the loop.
    return np.concatenate(rows) if rows else np.empty(0)


def _csr_softmax_grad_ref(np_csr, out_values, dout_values):
    """Return the expected input gradient of the row-wise CSR softmax.

    For each non-empty row: ``dx = (dout - sum(dout * out)) * out``.
    ``out_values`` and ``dout_values`` must be aligned with ``np_csr.data``.
    """
    rows = []
    for start, end in zip(np_csr.indptr[:-1], np_csr.indptr[1:]):
        if start == end:
            continue
        out = out_values[start:end]
        dout = dout_values[start:end]
        rows.append((dout - np.sum(dout * out)) * out)
    return np.concatenate(rows) if rows else np.empty(0)


class TestCsrSoftmax(unittest.TestCase):
    """Check paddle.sparse.nn.Softmax on CSR tensors against a NumPy reference."""

    def test_softmax2d(self):
        """2-D CSR input: output layout/values and gradient layout/values."""
        with _test_eager_guard():
            mask = np.random.rand(16, 128) < 0.5
            np_x = np.random.rand(16, 128) * mask
            np_csr = sp.csr_matrix(np_x)
            np_out = _csr_softmax_ref(np_csr)

            csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr()
            m = paddle.sparse.nn.Softmax()
            out = m(csr)
            # The sparsity pattern must be unchanged; indices compare exactly.
            np.testing.assert_array_equal(out.crows().numpy(), np_csr.indptr)
            np.testing.assert_array_equal(out.cols().numpy(), np_csr.indices)
            np.testing.assert_allclose(out.values().numpy(), np_out, rtol=1e-05)

            # Backprop with dout = x (the detached input), then compare with
            # dx = (dout - sum(dout * out)) * out computed row by row.
            out.backward(csr.detach())
            dx = _csr_softmax_grad_ref(np_csr, np_out, np_csr.data)
            np.testing.assert_array_equal(csr.grad.crows().numpy(),
                                          np_csr.indptr)
            np.testing.assert_array_equal(csr.grad.cols().numpy(),
                                          np_csr.indices)
            np.testing.assert_allclose(csr.grad.values().numpy(),
                                       dx,
                                       rtol=1e-05)

    def test_softmax3d(self):
        """Batched (3-D) CSR input: output values and gradient values."""
        with _test_eager_guard():
            batch_num = 16
            mask = np.random.rand(batch_num, 16, 128) < 0.5
            np_x = np.random.rand(batch_num, 16, 128) * mask

            # Reference is computed per batch; the expected values are the
            # per-batch results concatenated in batch order.
            batch_csrs = [sp.csr_matrix(np_x[i]) for i in range(batch_num)]
            batch_outs = [_csr_softmax_ref(c) for c in batch_csrs]
            np_out = np.concatenate(batch_outs)

            csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr()
            m = paddle.sparse.nn.Softmax()
            out = m(csr)
            np.testing.assert_allclose(out.values().numpy(), np_out, rtol=1e-05)

            # Backprop with dout = x; each batch pairs its own CSR data with
            # its own slice of the forward output.
            out.backward(csr.detach())
            dx = np.concatenate([
                _csr_softmax_grad_ref(batch_csr, batch_out, batch_csr.data)
                for batch_csr, batch_out in zip(batch_csrs, batch_outs)
            ])
            np.testing.assert_allclose(csr.grad.values().numpy(),
                                       dx,
                                       rtol=1e-05)


if __name__ == "__main__":
    # Run this module's tests when the file is executed as a script.
    unittest.main(module="__main__")