# test_compare_reduce_op.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import op_test
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard


def create_test_not_equal_class(op_type, typename, callback):
    """Register an OpTest case whose inputs share a shape but differ in value.

    Args:
        op_type: Operator name stored in ``self.op_type`` and used as the
            prefix of the generated class name.
        typename: NumPy dtype name that the random inputs are cast to.
        callback: Reference implementation; ``callback(x, y)`` yields the
            expected ``Out`` value.
    """

    class Cls(op_test.OpTest):
        def setUp(self):
            x = np.random.random(size=(10, 7)).astype(typename)
            y = np.random.random(size=(10, 7)).astype(typename)
            # Samples from [0, 1) all truncate to 0 for integer dtypes, which
            # would leave x and y identical. Force one mismatching element so
            # the inputs are unequal for every dtype.
            y[0, 0] = x[0, 0] + 1
            z = callback(x, y)
            self.inputs = {'X': x, 'Y': y}
            self.outputs = {'Out': z}
            self.op_type = op_type

        def test_output(self):
            self.check_output()

    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'not_equal_all')
    Cls.__name__ = cls_name
    # Publish the class under its generated name at module scope so the
    # unittest loader can collect it.
    globals()[cls_name] = Cls


def create_test_not_shape_equal_class(op_type, typename, callback):
    """Register an OpTest case whose two inputs have different shapes.

    ``X`` is a (10, 7) array and ``Y`` is a 1-D array of length 10.

    Args:
        op_type: Operator name stored in ``self.op_type`` and used as the
            prefix of the generated class name.
        typename: NumPy dtype name that the random inputs are cast to.
        callback: Reference implementation; ``callback(x, y)`` yields the
            expected ``Out`` value.
    """

    class Cls(op_test.OpTest):
        def setUp(self):
            x = np.random.random(size=(10, 7)).astype(typename)
            # Explicit 1-tuple: ``(10)`` is just the integer 10, which NumPy
            # treats the same way, but ``(10,)`` states the 1-D intent.
            y = np.random.random(size=(10,)).astype(typename)
            z = callback(x, y)
            self.inputs = {'X': x, 'Y': y}
            self.outputs = {'Out': z}
            self.op_type = op_type

        def test_output(self):
            self.check_output()

    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'not_shape_equal_all')
    Cls.__name__ = cls_name
    # Publish the class under its generated name at module scope so the
    # unittest loader can collect it.
    globals()[cls_name] = Cls


def create_test_equal_class(op_type, typename, callback):
    """Register an OpTest case whose two inputs are the same (10, 7) array.

    Args:
        op_type: Operator name stored in ``self.op_type`` and used as the
            prefix of the generated class name.
        typename: NumPy dtype name that the random input is cast to.
        callback: Reference implementation; ``callback(x, y)`` yields the
            expected ``Out`` value.
    """

    class Cls(op_test.OpTest):
        def setUp(self):
            # x and y are bound to the very same array, so they are equal
            # by construction.
            x = y = np.random.random(size=(10, 7)).astype(typename)
            z = callback(x, y)
            self.inputs = {'X': x, 'Y': y}
            self.outputs = {'Out': z}
            self.op_type = op_type

        def test_output(self):
            self.check_output()

    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal_all')
    Cls.__name__ = cls_name
    # Publish the class under its generated name at module scope so the
    # unittest loader can collect it.
    globals()[cls_name] = Cls


def create_test_dim1_class(op_type, typename, callback):
    """Register an OpTest case whose two inputs are the same 1-element array.

    Args:
        op_type: Operator name stored in ``self.op_type`` and used as the
            prefix of the generated class name.
        typename: NumPy dtype name that the random input is cast to.
        callback: Reference implementation; ``callback(x, y)`` yields the
            expected ``Out`` value.
    """

    class Cls(op_test.OpTest):
        def setUp(self):
            # x and y are bound to the very same array, so they are equal
            # by construction.
            x = y = np.random.random(size=(1,)).astype(typename)
            z = callback(x, y)
            self.inputs = {'X': x, 'Y': y}
            self.outputs = {'Out': z}
            self.op_type = op_type

        def test_output(self):
            self.check_output()

    # Use a suffix distinct from create_test_equal_class ('equal_all').
    # Sharing that suffix made both factories write to the same globals()
    # key, so whichever ran last silently replaced the other's test class.
    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'dim1_equal_all')
    Cls.__name__ = cls_name
    # Publish the class under its generated name at module scope so the
    # unittest loader can collect it.
    globals()[cls_name] = Cls


def np_equal(_x, _y):
    """Return whole-array equality of ``_x`` and ``_y`` as a 0-d bool array."""
    is_same = np.array_equal(_x, _y)
    return np.array(is_same)

# Generate the equal_all test classes for each dtype under test. A tuple is
# used (rather than a set) so the registration order is deterministic.
# NOTE(review): create_test_not_shape_equal_class is defined above but is not
# invoked here -- confirm whether that omission is intentional.
for _type_name in ('float32', 'float64', 'int32', 'int64'):
    create_test_not_equal_class('equal_all', _type_name, np_equal)
    create_test_equal_class('equal_all', _type_name, np_equal)
    create_test_dim1_class('equal_all', _type_name, np_equal)


class TestEqualReduceAPI(unittest.TestCase):
    """API-level check for ``paddle.equal_all``."""

    def test_name(self):
        """The ``name`` argument should appear in the output variable's name."""
        x = fluid.layers.assign(np.array([3, 4], dtype="int32"))
        y = fluid.layers.assign(np.array([3, 4], dtype="int32"))
        out = paddle.equal_all(x, y, name='equal_res')
        # assertIn rather than a bare assert: it is still enforced under
        # ``python -O`` and reports both operands when it fails.
        self.assertIn('equal_res', out.name)


# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()