# test_dropout_op.py: unit tests for the "dropout" operator.
import unittest
import numpy as np
from op_test import OpTest


class TestDropoutOp(OpTest):
    """Dropout case: (32, 64) float32 input, dropout_prob 0.0, is_test False.

    With a drop probability of 0.0 the expected 'Out' is the input array
    itself and the expected 'Mask' is all ones.
    """

    def setUp(self):
        """Declare the op type, input, attributes and expected outputs."""
        self.op_type = "dropout"
        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
        self.attrs = {'dropout_prob': 0.0, 'is_test': False}
        self.outputs = {
            # Same array object as the input: nothing is expected to change.
            'Out': self.inputs['X'],
            'Mask': np.ones((32, 64)).astype('float32')
        }

    def test_check_output(self):
        """Delegate output verification to OpTest.check_output."""
        self.check_output()

    def test_check_grad_normal(self):
        """Run OpTest.check_grad for input 'X' against output 'Out'."""
        self.check_grad(['X'], 'Out', max_relative_error=0.05)


class TestDropoutOp2(TestDropoutOp):
    """Dropout case: (32, 64) float32 input, dropout_prob 1.0, is_test False.

    With a drop probability of 1.0 both the expected 'Out' and the expected
    'Mask' are all zeros. The test methods are inherited from TestDropoutOp.
    """

    def setUp(self):
        """Declare the op type, input, attributes and expected outputs."""
        self.op_type = "dropout"
        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
        self.attrs = {'dropout_prob': 1.0, 'is_test': False}
        self.outputs = {
            'Out': np.zeros((32, 64)).astype('float32'),
            'Mask': np.zeros((32, 64)).astype('float32')
        }


class TestDropoutOp3(TestDropoutOp):
    """Dropout case: (32, 64, 2) float32 input, dropout_prob 0.0.

    Same expectations as TestDropoutOp ('Out' is the input, 'Mask' is all
    ones) but with a 3-D input. The test methods are inherited from
    TestDropoutOp.
    """

    def setUp(self):
        """Declare the op type, input, attributes and expected outputs."""
        self.op_type = "dropout"
        self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
        self.attrs = {'dropout_prob': 0.0, 'is_test': False}
        self.outputs = {
            # Same array object as the input: nothing is expected to change.
            'Out': self.inputs['X'],
            'Mask': np.ones((32, 64, 2)).astype('float32')
        }


class TestDropoutOp4(OpTest):
    """Dropout case with is_test True: (32, 64) input, dropout_prob 0.35.

    Only 'Out' is declared as an expected output (no 'Mask'), and only the
    output check is run (no gradient check).
    """

    def setUp(self):
        """Declare the op type, input, attributes and expected output."""
        self.op_type = "dropout"
        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
        self.attrs = {'dropout_prob': 0.35, 'is_test': True}
        # NOTE(review): the expectation scales X by dropout_prob itself rather
        # than by (1 - dropout_prob); confirm this matches the operator's
        # is_test behaviour before changing either side.
        self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}

    def test_check_output(self):
        """Delegate output verification to OpTest.check_output."""
        self.check_output()


class TestDropoutOp5(OpTest):
    """Dropout case with is_test True: (32, 64, 3) input, dropout_prob 0.75.

    Only 'Out' is declared as an expected output (no 'Mask'), and only the
    output check is run (no gradient check).
    """

    def setUp(self):
        """Declare the op type, input, attributes and expected output."""
        self.op_type = "dropout"
        self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
        self.attrs = {'dropout_prob': 0.75, 'is_test': True}
        # NOTE(review): the expectation scales X by dropout_prob itself rather
        # than by (1 - dropout_prob); confirm this matches the operator's
        # is_test behaviour before changing either side.
        self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}

    def test_check_output(self):
        """Delegate output verification to OpTest.check_output."""
        self.check_output()


# Script entry point: run the test cases in this module through unittest.
if __name__ == '__main__':
    unittest.main()