test_reduce_op.py 3.3 KB
Newer Older
G
guosheng 已提交
1 2
import unittest
import numpy as np
3
from op_test import OpTest
G
guosheng 已提交
4 5


6
class TestSumOp(OpTest):
    """Check reduce_sum on a 3-D float32 input without an explicit 'dim' attr."""

    def setUp(self):
        self.op_type = "reduce_sum"
        x = np.random.random((5, 6, 10)).astype("float32")
        self.inputs = {'X': x}
        # No 'dim' attr is passed and the reference sums over axis 0, so this
        # presumably exercises the op's default reduction dim (confirm
        # against the op definition).
        self.outputs = {'Out': x.sum(axis=0)}

    def test_check_output(self):
        """Forward result must match the NumPy reference."""
        self.check_output()

    def test_check_grad(self):
        """Analytic gradient of Out w.r.t. X must match the numerical one."""
        self.check_grad(['X'], 'Out')
G
guosheng 已提交
17 18


19 20 21 22 23 24
class TestMeanOp(OpTest):
    """Check reduce_mean along dim 1 of a 4-D float32 input."""

    def setUp(self):
        reduce_dim = 1
        data = np.random.random((5, 6, 2, 10)).astype("float32")
        self.op_type = "reduce_mean"
        self.inputs = {'X': data}
        self.attrs = {'dim': reduce_dim}
        # NumPy reference: average over the same axis that is sent to the op.
        self.outputs = {'Out': np.mean(data, axis=reduce_dim)}

    def test_check_output(self):
        """Forward result must match the NumPy reference."""
        self.check_output()

    def test_check_grad(self):
        """Analytic gradient of Out w.r.t. X must match the numerical one."""
        self.check_grad(['X'], 'Out')
G
guosheng 已提交
31 32


33 34
class TestMaxOp(OpTest):
    """Check reduce_max along the last dim (forward output only).

    Max only has a subgradient where elements tie, so the numerical
    gradient check is intentionally omitted to keep CI stable.
    """

    def setUp(self):
        reduce_dim = -1
        data = np.random.random((5, 6, 10)).astype("float32")
        self.op_type = "reduce_max"
        self.inputs = {'X': data}
        self.attrs = {'dim': reduce_dim}
        self.outputs = {'Out': data.max(axis=reduce_dim)}

    def test_check_output(self):
        """Forward result must match the NumPy reference."""
        self.check_output()
G
guosheng 已提交
44 45


46 47
class TestMinOp(OpTest):
    """Check reduce_min along dim 2 (forward output only).

    Min only has a subgradient where elements tie, so the numerical
    gradient check is intentionally omitted to keep CI stable.
    """

    def setUp(self):
        reduce_dim = 2
        data = np.random.random((5, 6, 10)).astype("float32")
        self.op_type = "reduce_min"
        self.inputs = {'X': data}
        self.attrs = {'dim': reduce_dim}
        self.outputs = {'Out': data.min(axis=reduce_dim)}

    def test_check_output(self):
        """Forward result must match the NumPy reference."""
        self.check_output()
G
guosheng 已提交
57 58


59
class TestKeepDimReduce(OpTest):
    """Check reduce_sum along dim -2 with keep_dim=True (reduced axis kept as size 1)."""

    def setUp(self):
        reduce_dim = -2
        keep_dim = True
        data = np.random.random((5, 6, 10)).astype("float32")
        self.op_type = "reduce_sum"
        self.inputs = {'X': data}
        self.attrs = {'dim': reduce_dim, 'keep_dim': keep_dim}
        # Reference uses the same dim/keep_dim values that are sent to the op.
        self.outputs = {
            'Out': np.sum(data, axis=reduce_dim, keepdims=keep_dim)
        }

    def test_check_output(self):
        """Forward result must match the NumPy reference."""
        self.check_output()

    def test_check_grad(self):
        """Analytic gradient of Out w.r.t. X must match the numerical one."""
        self.check_grad(['X'], 'Out')
G
guosheng 已提交
73 74


75
class Test1DReduce(OpTest):
    """Check reduce_sum on a 1-D input of length 20, without a 'dim' attr."""

    def setUp(self):
        vec = np.random.random(20).astype("float32")
        self.op_type = "reduce_sum"
        self.inputs = {'X': vec}
        # Summing the only axis of a 1-D array; presumably relies on the op's
        # default reduction dim (confirm against the op definition).
        self.outputs = {'Out': vec.sum(axis=0)}

    def test_check_output(self):
        """Forward result must match the NumPy reference."""
        self.check_output()

    def test_check_grad(self):
        """Analytic gradient of Out w.r.t. X must match the numerical one."""
        self.check_grad(['X'], 'Out')
G
guosheng 已提交
86 87


G
guosheng 已提交
88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
class TestNorm(OpTest):
    """Check the ``norm`` op against a NumPy p-norm reference along one dim.

    The op's outputs include the intermediate stages (``AbsOut``, ``PowOut``,
    ``SumOut``) as well as the final ``Out``, so the reference supplies each.
    """

    def setUp(self):
        # Use x away from 0 to avoid errors of the numerical gradient when
        # the gradient is near 0.
        x = np.random.random((5, 6, 10)).astype("float32") + 0.2
        p = 2
        dim = 1
        keep_dim = False
        # Reference p-norm along `dim`: (sum(|x| ** p)) ** (1 / p).
        abs_out = np.absolute(x)
        # Raise |x| (not x) to the p-th power, as the p-norm requires. With
        # the positive shift above the two coincide, but this keeps the
        # reference correct for negative inputs or odd p.
        # NOTE(review): presumably the op computes PowOut from AbsOut too;
        # confirm against the kernel.
        pow_out = np.power(abs_out, p)
        sum_out = np.sum(pow_out, axis=dim, keepdims=keep_dim)
        out = np.power(sum_out, 1. / p)
        self.op_type = "norm"
        self.inputs = {'X': x}
        self.attrs = {"p": p, "dim": dim, "keep_dim": keep_dim}
        self.outputs = {
            "AbsOut": abs_out,
            "PowOut": pow_out,
            "SumOut": sum_out,
            "Out": out
        }

    def test_check_output(self):
        """All four outputs must match the NumPy reference."""
        self.check_output()

    def test_check_grad(self):
        """Gradient of Out w.r.t. X, with a 1% relative-error tolerance."""
        self.check_grad(['X'], 'Out', max_relative_error=0.01)


G
guosheng 已提交
116 117
if __name__ == '__main__':
    # Run every test case defined in this module when executed directly.
    unittest.main()