# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import scipy.sparse as sp

import paddle

np.random.seed(2022)


class TestCsrSoftmax(unittest.TestCase):
    """Check paddle.sparse.nn.Softmax on CSR tensors against a NumPy reference.

    The reference softmax is computed row by row over the *stored* (non-zero)
    entries only, using the standard max-subtraction trick for numerical
    stability.  Both forward values and the backward gradient
    ``dx = (dout - sum(dout * out)) * out`` (with ``dout`` set to the input
    itself) are verified, along with the CSR structure (crows / cols).
    """

    def test_softmax2d(self):
        """Forward and backward softmax on a single ~50%-sparse 2-D matrix."""
        mask = np.random.rand(16, 128) < 0.5
        np_x = np.random.rand(16, 128) * mask
        np_csr = sp.csr_matrix(np_x)

        # Reference forward pass: per-row softmax over stored entries.
        row_number = np_csr.shape[0]
        np_out = np.array([])
        for i in range(row_number):
            start = np_csr.indptr[i]
            end = np_csr.indptr[i + 1]
            if start == end:
                # Empty row: contributes no stored values.
                continue
            x = np_csr.data[start:end]
            x_max = np.max(x, keepdims=True)
            x_exp = np.exp(x - x_max)
            x_exp_sum = np.sum(x_exp, keepdims=True)
            np_out = np.concatenate([np_out, x_exp / x_exp_sum])

        csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr()
        m = paddle.sparse.nn.Softmax()
        out = m(csr)
        # The output must keep the input's sparsity structure exactly.
        np.testing.assert_allclose(
            out.crows().numpy(), np_csr.indptr, rtol=1e-05
        )
        np.testing.assert_allclose(
            out.cols().numpy(), np_csr.indices, rtol=1e-05
        )
        np.testing.assert_allclose(out.values().numpy(), np_out, rtol=1e-05)

        # Backward: dx = (dout - sum(dout * out)) * out, with dout = np_x.
        out.backward(csr.detach())
        dx = np.array([])
        for i in range(row_number):
            start = np_csr.indptr[i]
            end = np_csr.indptr[i + 1]
            if start == end:
                continue
            row_out = np_out[start:end]
            dout = np_csr.data[start:end]
            # Named `dot` rather than `sum` to avoid shadowing the builtin.
            dot = np.sum(dout * row_out, keepdims=True)
            dx = np.concatenate([dx, (dout - dot) * row_out])

        # The gradient must also be CSR with the same structure.
        np.testing.assert_allclose(
            csr.grad.crows().numpy(), np_csr.indptr, rtol=1e-05
        )
        np.testing.assert_allclose(
            csr.grad.cols().numpy(), np_csr.indices, rtol=1e-05
        )
        np.testing.assert_allclose(csr.grad.values().numpy(), dx, rtol=1e-05)

    def test_softmax3d(self):
        """Forward and backward softmax on a batch of sparse 2-D matrices."""
        batchNum = 16
        mask = np.random.rand(batchNum, 16, 128) < 0.5
        np_x = np.random.rand(batchNum, 16, 128) * mask

        # Reference forward pass: per-row softmax, batch by batch, with all
        # batches' stored values concatenated in order.
        np_out = np.array([])
        for i in range(batchNum):
            np_csr = sp.csr_matrix(np_x[i, :, :])
            row_number = np_csr.shape[0]
            for j in range(row_number):
                start = np_csr.indptr[j]
                end = np_csr.indptr[j + 1]
                if start == end:
                    continue
                x = np_csr.data[start:end]
                x_max = np.max(x, keepdims=True)
                x_exp = np.exp(x - x_max)
                x_exp_sum = np.sum(x_exp, keepdims=True)
                np_out = np.concatenate([np_out, x_exp / x_exp_sum])

        csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr()
        m = paddle.sparse.nn.Softmax()
        out = m(csr)
        np.testing.assert_allclose(out.values().numpy(), np_out, rtol=1e-05)

        # Backward: dx = (dout - sum(dout * out)) * out, with dout = np_x.
        out.backward(csr.detach())
        dx = np.array([])
        # np_out is a flat concatenation across batches; batch_offset tracks
        # where each batch's values start within it.
        batch_offset = 0
        for i in range(batchNum):
            np_csr = sp.csr_matrix(np_x[i, :, :])
            row_number = np_csr.shape[0]
            for j in range(row_number):
                start = np_csr.indptr[j]
                end = np_csr.indptr[j + 1]
                if start == end:
                    continue
                dout = np_csr.data[start:end]
                row_out = np_out[batch_offset + start : batch_offset + end]
                dot = np.sum(dout * row_out, keepdims=True)
                dx = np.concatenate([dx, (dout - dot) * row_out])
            batch_offset += np_csr.nnz

        np.testing.assert_allclose(csr.grad.values().numpy(), dx, rtol=1e-05)


# Run the CSR softmax tests when this file is executed directly.
if __name__ == "__main__":
    unittest.main()