# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid.framework import _test_eager_guard

import numpy as np
import scipy.sparse as sp
import unittest

# Seed NumPy's global RNG so the random inputs drawn by the tests below
# (via np.random.rand) are reproducible from run to run.
np.random.seed(2022)


def _reference_softmax_values(np_csr):
    """Return the reference softmax of a SciPy CSR matrix's stored values.

    Softmax is taken independently over the stored entries of each row
    (entries that are not stored do not take part). Rows without stored
    entries contribute nothing.

    Args:
        np_csr: A ``scipy.sparse.csr_matrix``.

    Returns:
        A 1-D NumPy array laid out in the same order as ``np_csr.data``.
    """
    row_outputs = []
    for start, end in zip(np_csr.indptr[:-1], np_csr.indptr[1:]):
        if start == end:
            continue
        x = np_csr.data[start:end]
        # Subtract the row maximum before exponentiating for stability.
        x_exp = np.exp(x - np.max(x, keepdims=True))
        row_outputs.append(x_exp / np.sum(x_exp, keepdims=True))
    # np.concatenate rejects an empty list, so handle the all-empty case.
    return np.concatenate(row_outputs) if row_outputs else np.array([])


def _reference_softmax_grad(np_csr, softmax_values, dout_values):
    """Return the reference softmax gradient for a SciPy CSR matrix.

    For every non-empty row: ``dx = (dout - sum(dout * out)) * out``.

    Args:
        np_csr: A ``scipy.sparse.csr_matrix``; only its ``indptr`` is used,
            to locate the row boundaries.
        softmax_values: 1-D array of forward outputs, aligned with
            ``np_csr.data``.
        dout_values: 1-D array of upstream gradients, aligned with
            ``np_csr.data``.

    Returns:
        A 1-D NumPy array of gradients in the same order as ``np_csr.data``.
    """
    row_grads = []
    for start, end in zip(np_csr.indptr[:-1], np_csr.indptr[1:]):
        if start == end:
            continue
        row_out = softmax_values[start:end]
        row_dout = dout_values[start:end]
        weighted_sum = np.sum(row_dout * row_out, keepdims=True)
        row_grads.append((row_dout - weighted_sum) * row_out)
    return np.concatenate(row_grads) if row_grads else np.array([])


class TestCsrSoftmax(unittest.TestCase):
    """Compare ``paddle.sparse.nn.Softmax`` on CSR tensors with a NumPy/SciPy
    reference, for both the forward result and the backward gradient."""

    def test_softmax2d(self):
        """Check structure, forward values and gradients for a 2-D input."""
        with _test_eager_guard():
            # Roughly half of the entries are zeroed out by the mask.
            mask = np.random.rand(16, 128) < 0.5
            np_x = np.random.rand(16, 128) * mask
            np_csr = sp.csr_matrix(np_x)
            np_out = _reference_softmax_values(np_csr)

            csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr()
            m = paddle.sparse.nn.Softmax()
            out = m(csr)
            np.testing.assert_allclose(
                out.crows().numpy(), np_csr.indptr, rtol=1e-05
            )
            np.testing.assert_allclose(
                out.cols().numpy(), np_csr.indices, rtol=1e-05
            )
            np.testing.assert_allclose(out.values().numpy(), np_out, rtol=1e-05)

            # dx = (dout - sum(dout * out)) * out, dout=rand_x
            out.backward(csr.detach())
            dx = _reference_softmax_grad(np_csr, np_out, np_csr.data)

            np.testing.assert_allclose(
                csr.grad.crows().numpy(), np_csr.indptr, rtol=1e-05
            )
            np.testing.assert_allclose(
                csr.grad.cols().numpy(), np_csr.indices, rtol=1e-05
            )
            np.testing.assert_allclose(
                csr.grad.values().numpy(), dx, rtol=1e-05
            )

    def test_softmax3d(self):
        """Check forward values and gradients for a batched (3-D) input."""
        with _test_eager_guard():
            batch_num = 16
            mask = np.random.rand(batch_num, 16, 128) < 0.5
            np_x = np.random.rand(batch_num, 16, 128) * mask

            # Build one SciPy CSR matrix per batch entry and reuse it for
            # both the forward and the backward reference.
            batch_csrs = [
                sp.csr_matrix(np_x[i, :, :]) for i in range(batch_num)
            ]
            batch_outs = [
                _reference_softmax_values(np_csr) for np_csr in batch_csrs
            ]
            np_out = np.concatenate(batch_outs)

            csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr()
            m = paddle.sparse.nn.Softmax()
            out = m(csr)
            np.testing.assert_allclose(out.values().numpy(), np_out, rtol=1e-05)

            # dx = (dout - sum(dout * out)) * out, dout=rand_x
            out.backward(csr.detach())
            dx = np.concatenate(
                [
                    _reference_softmax_grad(np_csr, batch_out, np_csr.data)
                    for np_csr, batch_out in zip(batch_csrs, batch_outs)
                ]
            )

            np.testing.assert_allclose(
                csr.grad.values().numpy(), dx, rtol=1e-05
            )

# Run the test cases through unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()