#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest


class PReluTest(OpTest):
    """Check the ``prelu`` operator's output and its gradient w.r.t. ``X``.

    The reference output is ``max(X, 0) + min(X, 0) * Alpha``, where ``Alpha``
    is a one-element array broadcast over every element of ``X``.
    """

    def setUp(self):
        self.op_type = "prelu"

        x_np = np.random.normal(size=(10, 10)).astype("float32")
        # Since the zero point of prelu is not differentiable, move every
        # element at least `min_abs` away from zero while keeping its sign.
        # The magnitude is clamped (not the signed value), so negative inputs
        # keep their random spread instead of all collapsing to -min_abs.
        # Exact zeros are pushed to +min_abs.
        min_abs = .005
        x_np = np.where(x_np < 0,
                        np.minimum(x_np, -min_abs),
                        np.maximum(x_np, min_abs)).astype("float32")

        alpha_np = np.array([.1], dtype="float32")
        self.inputs = {'X': x_np, 'Alpha': alpha_np}

        # Reference result: positive part unchanged, negative part scaled by
        # alpha.
        out_np = np.maximum(self.inputs['X'], 0.)
        out_np = out_np + np.minimum(self.inputs['X'],
                                     0.) * self.inputs['Alpha']
        assert out_np is not self.inputs['X']
        self.outputs = {'Out': out_np}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


if __name__ == "__main__":
    # Only when executed as a script: run the test cases in this module.
    unittest.main()