# test_dropout_op.py -- unit tests for the "dropout" operator.
import unittest
import numpy as np
from op_test import OpTest


class TestDropoutOp(OpTest):
    """Test the "dropout" op with ``dropout_prob = 0.0`` in training mode.

    The expected ``Out`` is the input itself and the expected ``Mask`` is
    all ones. ``check_output`` and ``check_grad`` are not defined here;
    presumably they are provided by ``OpTest`` -- confirm against op_test.
    """

    def setUp(self):
        """Declare the op type, input, attributes and expected outputs."""
        x = np.random.random((32, 64)).astype("float32")

        self.op_type = "dropout"
        self.inputs = {'X': x}
        self.attrs = {'dropout_prob': 0.0, 'is_training': True}
        self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))}

    def test_check_output(self):
        """Compare the op's outputs against ``self.outputs``."""
        self.check_output()

    def test_check_grad_normal(self):
        """Check the gradient of ``Out`` with respect to ``X``."""
        self.check_grad(['X'], 'Out', max_relative_error=0.05)


class TestDropoutOp2(TestDropoutOp):
    """Test the "dropout" op with ``dropout_prob = 1.0`` in training mode.

    The expected ``Out`` and ``Mask`` are both all zeros. Only ``setUp`` is
    overridden; the test methods are inherited from ``TestDropoutOp``.
    """

    def setUp(self):
        """Declare the op type, input, attributes and expected outputs."""
        x = np.random.random((32, 64)).astype("float32")

        self.op_type = "dropout"
        self.inputs = {'X': x}
        self.attrs = {'dropout_prob': 1.0, 'is_training': True}
        self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))}


class TestDropoutOp3(TestDropoutOp):
    """Test the "dropout" op on a 3-D input with ``dropout_prob = 0.0``.

    Same expectations as ``TestDropoutOp`` (``Out`` equals the input,
    ``Mask`` is all ones) but with an input of shape ``(32, 64, 2)``.
    Only ``setUp`` is overridden; the test methods are inherited.
    """

    def setUp(self):
        """Declare the op type, input, attributes and expected outputs."""
        x = np.random.random((32, 64, 2)).astype("float32")

        self.op_type = "dropout"
        self.inputs = {'X': x}
        self.attrs = {'dropout_prob': 0.0, 'is_training': True}
        self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64, 2))}


class TestDropoutOp4(OpTest):
    """Test the "dropout" op with ``is_training = False`` on a 2-D input.

    Only the ``Out`` output is declared and only the output check is run;
    this class defines no gradient test.
    """

    def setUp(self):
        """Declare the op type, input, attributes and expected output."""
        x = np.random.random((32, 64)).astype("float32")

        self.op_type = "dropout"
        self.inputs = {'X': x}
        self.attrs = {'dropout_prob': 0.35, 'is_training': False}
        # NOTE(review): the expected value scales X by dropout_prob itself,
        # not by (1 - dropout_prob). Confirm this matches the operator's
        # intended non-training behaviour before changing either side.
        self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}

    def test_check_output(self):
        """Compare the op's output against ``self.outputs``."""
        self.check_output()


class TestDropoutOp5(OpTest):
    """Test the "dropout" op with ``is_training = False`` on a 3-D input.

    Uses an input of shape ``(32, 64, 3)`` and ``dropout_prob = 0.75``.
    Only the ``Out`` output is declared and only the output check is run;
    this class defines no gradient test.
    """

    def setUp(self):
        """Declare the op type, input, attributes and expected output."""
        x = np.random.random((32, 64, 3)).astype("float32")

        self.op_type = "dropout"
        self.inputs = {'X': x}
        self.attrs = {'dropout_prob': 0.75, 'is_training': False}
        # NOTE(review): the expected value scales X by dropout_prob itself,
        # not by (1 - dropout_prob). Confirm this matches the operator's
        # intended non-training behaviour before changing either side.
        self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}

    def test_check_output(self):
        """Compare the op's output against ``self.outputs``."""
        self.check_output()


# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()