test_softmax_op.py 902 字节
Newer Older
Q
qijun 已提交
1
import unittest
Q
Qiao Longfei 已提交
2

Q
qijun 已提交
3
import numpy as np
Q
Qiao Longfei 已提交
4

5
from gradient_checker import GradientChecker, create_op
Q
Qiao Longfei 已提交
6
from op_test_util import OpTestMeta
Q
qijun 已提交
7 8 9 10 11 12 13 14 15 16 17 18 19 20


def stable_softmax(x):
    """Return the softmax of ``x``, computed in a numerically stable way.

    The largest element is subtracted before exponentiating so ``np.exp``
    does not overflow on large inputs; the shift cancels out in the final
    ratio. ``np.max`` and ``np.sum`` are called without an axis, so the
    normalisation runs over every element of ``x`` — pass one vector at a
    time (e.g. via ``np.apply_along_axis``) to get a per-row softmax.
    """
    exponentials = np.exp(x - np.max(x))
    normalizer = np.sum(exponentials)
    return exponentials / normalizer


class TestSoftmaxOp(unittest.TestCase):
    """Forward-output check for the ``softmax`` operator.

    ``setUp`` declares the operator type, its input and the expected
    output. The actual assertions are presumably generated by
    ``OpTestMeta`` from these attributes — TODO confirm against
    ``op_test_util``.
    """

    # NOTE(review): ``__metaclass__`` is only honoured by Python 2. Under
    # Python 3 this is a plain class attribute and the metaclass is never
    # applied. Confirm which interpreter this suite targets.
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "softmax"
        # A 32x100 float32 input with values drawn from [0, 1).
        self.inputs = {'X': np.random.random((32, 100)).astype("float32")}
        # Reference output: softmax applied to each row independently
        # (axis 1), using the numerically stable helper.
        self.outputs = {
            'Y': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
        }
Q
qijun 已提交
25 26


27 28 29 30 31
class SoftmaxGradOpTest(GradientChecker):
    """Gradient check for the ``softmax`` operator."""

    def test_softmax(self):
        op = create_op("softmax")
        # A 10x10 float32 input with values drawn from [0.1, 1).
        inputs = {"X": np.random.uniform(0.1, 1, [10, 10]).astype("float32")}
        # Third argument is a set of input names (presumably the inputs
        # whose gradients are checked — confirm in gradient_checker).
        # Use a set literal: set("X") builds a set of the string's
        # characters and only matches {"X"} because the name is one letter.
        self.check_grad(op, inputs, {"X"}, "Y")
Q
Qiao Longfei 已提交
32 33


Q
qijun 已提交
34 35
if __name__ == '__main__':
    # Run all test cases defined in this module when executed as a script.
    unittest.main(module='__main__')