#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
17
from op_test import OpTest
18 19


20
class TestDropoutOp(OpTest):
    """Test case for the "dropout" op with dropout_prob=0.0, is_test=False.

    The input is a random (32, 64) float32 array.  The expected 'Out' is
    that same array and the expected 'Mask' is an all-ones float32 array of
    the same shape.
    """

    def setUp(self):
        shape = (32, 64)
        x = np.random.random(shape).astype("float32")

        self.op_type = "dropout"
        self.inputs = {'X': x}
        self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
        # 'Out' is the very same array object as the input 'X'.
        self.outputs = {'Out': x, 'Mask': np.ones(shape).astype('float32')}

    def test_check_output(self):
        # Delegates to the OpTest helper (presumably compares the op's
        # results against self.outputs -- confirm in op_test).
        self.check_output()

    def test_check_grad_normal(self):
        # Gradient check for 'Out' with respect to 'X', tolerance 0.05.
        self.check_grad(['X'], 'Out', max_relative_error=0.05)
35 36


37
class TestDropoutOp2(TestDropoutOp):
    """Test case for the "dropout" op with dropout_prob=1.0, is_test=False.

    The input is a random (32, 64) float32 array.  Both the expected 'Out'
    and the expected 'Mask' are all-zero float32 arrays of the same shape.
    """

    def setUp(self):
        shape = (32, 64)

        self.op_type = "dropout"
        self.inputs = {'X': np.random.random(shape).astype("float32")}
        self.attrs = {'dropout_prob': 1.0, 'fix_seed': True, 'is_test': False}
        # Each expected output gets its own freshly allocated zero array.
        self.outputs = {
            name: np.zeros(shape).astype('float32')
            for name in ('Out', 'Mask')
        }
46 47


48
class TestDropoutOp3(TestDropoutOp):
    """Test case for the "dropout" op on a 3-D input, dropout_prob=0.0.

    The input is a random (32, 64, 2) float32 array with is_test=False.
    The expected 'Out' is the input itself and the expected 'Mask' is an
    all-ones float32 array of the same shape.
    """

    def setUp(self):
        shape = (32, 64, 2)
        x = np.random.random(shape).astype("float32")

        self.op_type = "dropout"
        self.inputs = {'X': x}
        self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
        # 'Out' is the very same array object as the input 'X'.
        self.outputs = {'Out': x, 'Mask': np.ones(shape).astype('float32')}
57 58


59 60 61 62
class TestDropoutOp4(OpTest):
    """Test case for the "dropout" op with is_test=True, dropout_prob=0.35.

    The input is a random (32, 64) float32 array.  The expected 'Out' is
    the input multiplied by (1.0 - dropout_prob); no 'Mask' output is
    declared.
    """

    def setUp(self):
        self.op_type = "dropout"
        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
        self.attrs = {'dropout_prob': 0.35, 'fix_seed': True, 'is_test': True}

        # Factor applied to the input to form the expected output.
        keep_ratio = 1.0 - self.attrs['dropout_prob']
        self.outputs = {'Out': self.inputs['X'] * keep_ratio}

    def test_check_output(self):
        # Delegates to the OpTest helper (presumably compares the op's
        # results against self.outputs -- confirm in op_test).
        self.check_output()


class TestDropoutOp5(OpTest):
    """Test case for the "dropout" op with is_test=True, dropout_prob=0.75.

    The input is a random (32, 64, 3) float32 array and no 'fix_seed'
    attribute is set.  The expected 'Out' is the input multiplied by
    (1.0 - dropout_prob); no 'Mask' output is declared.
    """

    def setUp(self):
        self.op_type = "dropout"
        self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
        self.attrs = {'dropout_prob': 0.75, 'is_test': True}

        # Factor applied to the input to form the expected output.
        keep_ratio = 1.0 - self.attrs['dropout_prob']
        self.outputs = {'Out': self.inputs['X'] * keep_ratio}

    def test_check_output(self):
        # Delegates to the OpTest helper (presumably compares the op's
        # results against self.outputs -- confirm in op_test).
        self.check_output()


85 86
# Run this module's test cases when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()