import unittest

import numpy as np

from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta


def stable_softmax(x):
    """Compute the softmax of vector x in a numerically stable way."""
    shiftx = x - np.max(x)
    exps = np.exp(shiftx)
    return exps / np.sum(exps)


class TestSoftmaxOp(unittest.TestCase):
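    # OpTestMeta is a Python 2 metaclass from Paddle's test utilities
    # (hence __metaclass__ below); it is expected to generate the
    # forward-pass check that runs the op named by self.type on
    # self.inputs and compares the result with self.outputs.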
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "softmax"
        self.inputs = {"Logits": np.random.random((10, 10)).astype("float32")}
        self.outputs = {
            "Out": np.apply_along_axis(stable_softmax, 1, self.inputs["Logits"])
        }


class TestSoftmaxGradOp(GradientChecker):
    def setUp(self):
        self.op = create_op("softmax")
        self.inputs = {
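            # Sampling in [0.1, 1) instead of [0, 1) keeps logits away from
            # the boundary, a common trick to keep the finite-difference
            # gradient estimate well conditioned.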
            "Logits": np.random.uniform(0.1, 1, [10, 10]).astype("float32")
        }

    def test_softmax_grad(self):
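        # check_grad compares the op's analytic gradient of "Out" with
        # respect to "Logits" against a numeric (finite-difference)
        # estimate, failing the test if they diverge beyond the
        # checker's tolerance.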
        self.check_grad(self.op, self.inputs, ["Logits"], "Out")


if __name__ == "__main__":
    unittest.main()