# test_sparse_softmax_op.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import scipy.sparse as sp

import paddle
from paddle.fluid.framework import _test_eager_guard

# Fixed seed so the randomly generated sparse inputs are reproducible
# across runs.
np.random.seed(2022)

class TestCsrSoftmax(unittest.TestCase):
27
    def test_softmax2d(self):
28
        with _test_eager_guard():
29 30
            mask = np.random.rand(16, 128) < 0.5
            np_x = np.random.rand(16, 128) * mask
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
            np_csr = sp.csr_matrix(np_x)

            row_number = np_csr.shape[0]
            np_out = np.array([])
            for i in range(row_number):
                start = np_csr.indptr[i]
                end = np_csr.indptr[i + 1]
                if start == end:
                    continue
                x = np_csr.data[start:end]
                x_max = np.max(x, keepdims=True)
                x_exp = np.exp(x - x_max)
                x_exp_sum = np.sum(x_exp, keepdims=True)
                np_out = np.concatenate([np_out, x_exp / x_exp_sum])

            csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr()
47
            m = paddle.sparse.nn.Softmax()
48
            out = m(csr)
49 50 51 52 53 54
            np.testing.assert_allclose(
                out.crows().numpy(), np_csr.indptr, rtol=1e-05
            )
            np.testing.assert_allclose(
                out.cols().numpy(), np_csr.indices, rtol=1e-05
            )
55
            np.testing.assert_allclose(out.values().numpy(), np_out, rtol=1e-05)
56 57 58

            # dx = (dout - sum(dout * out)) * out, dout=rand_x
            out.backward(csr.detach())
59
            dx = np.array([])
60 61 62 63 64 65 66 67
            for i in range(row_number):
                start = np_csr.indptr[i]
                end = np_csr.indptr[i + 1]
                if start == end:
                    continue
                out = np_out[start:end]
                dout = np_csr.data[start:end]
                sum = np.sum(dout * out, keepdims=True)
68
                dx = np.concatenate([dx, (dout - sum) * out])
69

70 71 72 73 74 75 76 77 78
            np.testing.assert_allclose(
                csr.grad.crows().numpy(), np_csr.indptr, rtol=1e-05
            )
            np.testing.assert_allclose(
                csr.grad.cols().numpy(), np_csr.indices, rtol=1e-05
            )
            np.testing.assert_allclose(
                csr.grad.values().numpy(), dx, rtol=1e-05
            )
79

80 81 82 83 84 85 86 87 88 89 90
    def test_softmax3d(self):
        with _test_eager_guard():
            batchNum = 16
            mask = np.random.rand(batchNum, 16, 128) < 0.5
            np_x = np.random.rand(batchNum, 16, 128) * mask

            np_out_list = []
            np_out = np.array([])
            for i in range(batchNum):
                np_csr = sp.csr_matrix(np_x[i, :, :])
                row_number = np_csr.shape[0]
91 92 93
                for j in range(
                    row_number,
                ):
94 95 96 97 98 99 100 101 102 103 104 105
                    start = np_csr.indptr[j]
                    end = np_csr.indptr[j + 1]
                    if start == end:
                        continue
                    x = np_csr.data[start:end]
                    x_max = np.max(x, keepdims=True)
                    x_exp = np.exp(x - x_max)
                    x_exp_sum = np.sum(x_exp, keepdims=True)
                    np_out_list.append(x_exp / x_exp_sum)
                    np_out = np.concatenate([np_out, x_exp / x_exp_sum])

            csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr()
106
            m = paddle.sparse.nn.Softmax()
107
            out = m(csr)
108
            np.testing.assert_allclose(out.values().numpy(), np_out, rtol=1e-05)
109 110 111 112 113 114 115 116 117 118 119 120 121 122

            # dx = (dout - sum(dout * out)) * out, dout=rand_x
            out.backward(csr.detach())
            dx = np.array([])
            batch_offset = 0
            for i in range(batchNum):
                np_csr = sp.csr_matrix(np_x[i, :, :])
                row_number = np_csr.shape[0]
                for j in range(row_number):
                    start = np_csr.indptr[j]
                    end = np_csr.indptr[j + 1]
                    if start == end:
                        continue
                    dout = np_csr.data[start:end]
123
                    out = np_out[batch_offset + start : batch_offset + end]
124 125 126 127 128
                    sum = np.sum(dout * out, keepdims=True)
                    dx = np.concatenate([dx, (dout - sum) * out])

                batch_offset += np_csr.nnz

129 130 131
            np.testing.assert_allclose(
                csr.grad.values().numpy(), dx, rtol=1e-05
            )
132

133 134 135

if __name__ == "__main__":
    unittest.main()