# test_dropout_op.py - unit tests for the "dropout" operator.
import unittest
import numpy as np
from op_test import OpTest


class TestDropoutOp(OpTest):
    """Dropout with dropout_prob=0.0 on a 2-D float32 input.

    With a zero drop probability the expected output is the input itself and
    the expected mask is all ones. Subclasses override only ``setUp`` and
    reuse the two test methods defined here.
    """

    def setUp(self):
        """Build a (32, 64) float32 input and its expected outputs."""
        self.op_type = "dropout"
        x = np.random.random((32, 64)).astype("float32")
        self.inputs = {'X': x}
        self.attrs = {'dropout_prob': 0.0}
        # Nothing is dropped: Out is X unchanged and every Mask entry is 1.
        # The mask shape is taken from X so the two cannot drift apart.
        self.outputs = {'Out': x, 'Mask': np.ones(x.shape)}

    def test_check_output(self):
        """Check the forward outputs via OpTest.check_output."""
        self.check_output()

    def test_check_grad_normal(self):
        """Check the gradient of 'Out' w.r.t. 'X' (max relative error 0.05)."""
        self.check_grad(['X'], 'Out', max_relative_error=0.05)


class TestDropoutOp2(TestDropoutOp):
    """Dropout with dropout_prob=1.0 on a 2-D float32 input.

    Inherits the output and gradient checks from TestDropoutOp; only the
    fixture built in ``setUp`` differs.
    """

    def setUp(self):
        """Build a (32, 64) float32 input whose expected outputs are zero."""
        self.op_type = "dropout"
        x = np.random.random((32, 64)).astype("float32")
        self.inputs = {'X': x}
        self.attrs = {'dropout_prob': 1.0}
        # Everything is dropped: both Out and Mask are expected to be all
        # zeros, shaped like X.
        self.outputs = {'Out': np.zeros(x.shape), 'Mask': np.zeros(x.shape)}


class TestDropoutOp3(TestDropoutOp):
    """Dropout with dropout_prob=0.0 on a 3-D float32 input.

    Same expectations as the 2-D zero-probability case, exercised on a
    rank-3 input; the test methods are inherited from TestDropoutOp.
    """

    def setUp(self):
        """Build a (32, 64, 2) float32 input and its expected outputs."""
        self.op_type = "dropout"
        x = np.random.random((32, 64, 2)).astype("float32")
        self.inputs = {'X': x}
        self.attrs = {'dropout_prob': 0.0}
        # Nothing is dropped: Out is X unchanged and every Mask entry is 1.
        self.outputs = {'Out': x, 'Mask': np.ones(x.shape)}


if __name__ == '__main__':
    # Executed as a script: let unittest collect and run every TestCase
    # defined in this module (module='__main__' is unittest.main's default).
    unittest.main(module='__main__')