# test_sparse_softmax_op.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid.framework import _test_eager_guard

import numpy as np
import scipy
import scipy.sparse as sp
import unittest
import os
import re
import math

# Seed NumPy's global RNG so the random sparse inputs drawn by the tests are reproducible.
np.random.seed(seed=2022)


def _csr_row_softmax(np_csr):
    """Compute the reference row-wise softmax of a SciPy CSR matrix.

    Only explicitly stored entries participate: each row's softmax is taken
    over that row's stored values (implicit zeros are ignored), and rows with
    no stored values contribute nothing.

    Args:
        np_csr: a ``scipy.sparse.csr_matrix``.

    Returns:
        A 1-D array of softmax values aligned with ``np_csr.data``.
    """
    row_results = []
    for row in range(np_csr.shape[0]):
        start, end = np_csr.indptr[row], np_csr.indptr[row + 1]
        if start == end:
            # Empty row: nothing to normalize (and np.max would raise on it).
            continue
        x = np_csr.data[start:end]
        # Shift by the row max before exponentiating for numerical stability.
        x_exp = np.exp(x - np.max(x))
        row_results.append(x_exp / np.sum(x_exp))
    # Join once at the end instead of re-concatenating after every row.
    return np.concatenate(row_results) if row_results else np.array([])


def _csr_row_softmax_grad(np_csr, out_values, dout_values):
    """Compute the reference softmax input-gradient for a CSR matrix's stored values.

    For every non-empty row: ``dx = (dout - sum(dout * out)) * out``.

    Args:
        np_csr: a ``scipy.sparse.csr_matrix`` whose ``indptr`` delimits rows.
        out_values: softmax outputs, flat and aligned with ``np_csr.data``.
        dout_values: upstream gradients, flat and aligned with ``np_csr.data``.

    Returns:
        A 1-D array of input gradients aligned with ``np_csr.data``.
    """
    row_grads = []
    for row in range(np_csr.shape[0]):
        start, end = np_csr.indptr[row], np_csr.indptr[row + 1]
        if start == end:
            continue
        out = out_values[start:end]
        dout = dout_values[start:end]
        row_grads.append((dout - np.sum(dout * out)) * out)
    return np.concatenate(row_grads) if row_grads else np.array([])


class TestCsrSoftmax(unittest.TestCase):
    """Compare ``paddle.incubate.sparse.nn.Softmax`` on CSR tensors with a SciPy reference."""

    def test_softmax2d(self):
        with _test_eager_guard():
            # Roughly half of the entries are zeroed so the CSR form is sparse.
            mask = np.random.rand(16, 128) < 0.5
            np_x = np.random.rand(16, 128) * mask
            np_csr = sp.csr_matrix(np_x)
            np_out = _csr_row_softmax(np_csr)

            csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr()
            m = paddle.incubate.sparse.nn.Softmax()
            out = m(csr)

            # The output must keep the input's sparsity pattern.
            np.testing.assert_allclose(out.crows().numpy(),
                                       np_csr.indptr,
                                       rtol=1e-05)
            np.testing.assert_allclose(out.cols().numpy(),
                                       np_csr.indices,
                                       rtol=1e-05)
            np.testing.assert_allclose(out.values().numpy(), np_out, rtol=1e-05)

            # Backprop the input itself as the upstream gradient (dout = x),
            # so the reference dout is just the stored values of np_x.
            out.backward(csr.detach())
            dx = _csr_row_softmax_grad(np_csr, np_out, np_csr.data)

            np.testing.assert_allclose(csr.grad.crows().numpy(),
                                       np_csr.indptr,
                                       rtol=1e-05)
            np.testing.assert_allclose(csr.grad.cols().numpy(),
                                       np_csr.indices,
                                       rtol=1e-05)
            np.testing.assert_allclose(csr.grad.values().numpy(),
                                       dx,
                                       rtol=1e-05)

    def test_softmax3d(self):
        with _test_eager_guard():
            batch_num = 16
            mask = np.random.rand(batch_num, 16, 128) < 0.5
            np_x = np.random.rand(batch_num, 16, 128) * mask

            # Reference: each batch slice is treated as an independent 2-D CSR
            # matrix, and the expected values are the per-batch results laid
            # end to end.
            batch_csrs = [sp.csr_matrix(np_x[b]) for b in range(batch_num)]
            np_out = np.concatenate(
                [_csr_row_softmax(batch_csr) for batch_csr in batch_csrs])

            csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr()
            m = paddle.incubate.sparse.nn.Softmax()
            out = m(csr)
            np.testing.assert_allclose(out.values().numpy(), np_out, rtol=1e-05)

            # dout = x, as in the 2-D case.
            out.backward(csr.detach())
            dx_parts = []
            batch_offset = 0
            for batch_csr in batch_csrs:
                # This batch's softmax outputs within the flat reference array.
                batch_out = np_out[batch_offset:batch_offset + batch_csr.nnz]
                dx_parts.append(
                    _csr_row_softmax_grad(batch_csr, batch_out,
                                          batch_csr.data))
                batch_offset += batch_csr.nnz
            dx = np.concatenate(dx_parts)

            np.testing.assert_allclose(csr.grad.values().numpy(),
                                       dx,
                                       rtol=1e-05)


if __name__ == "__main__":
    # Run every test case in this module when the file is executed directly.
    unittest.main()