test_logical_xor_op.py 6.1 KB
Newer Older
6
6clc 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2023 CINN Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15

6
6clc 已提交
16
import numpy as np
17 18
from cinn.common import *
from cinn.frontend import *
6
6clc 已提交
19 20
from op_test import OpTest, OpTestTool
from op_test_helper import TestCaseHelper
21

6
6clc 已提交
22 23 24
import paddle


25 26 27
@OpTestTool.skip_if(
    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestLogicalXorOp(OpTest):
    """Check CINN's ``logical_xor`` against ``paddle.logical_xor``.

    Parameters are read from ``self.case``:

    - ``x_shape`` / ``y_shape``: input shapes.
    - ``x_dtype`` / ``y_dtype``: input dtypes.
    - ``axis``: forwarded to the CINN builder and mirrored on the Paddle side.
    - ``max_relative_error``: optional; defaults to ``1e-5``.

    ``self.case`` is presumably filled in by the ``TestCaseHelper``
    subclasses that name this class as their ``cls``. Confirm against
    ``op_test_helper``.
    """

    def setUp(self):
        print(f"\nRunning {self.__class__.__name__}: {self.case}")
        self.prepare_inputs()

    def prepare_inputs(self):
        """Draw random ``x``/``y`` arrays using ``low=-10`` and ``high=100``."""
        self.x_np = self.random(
            shape=self.case["x_shape"],
            dtype=self.case["x_dtype"],
            low=-10,
            high=100,
        )
        self.y_np = self.random(
            shape=self.case["y_shape"],
            dtype=self.case["y_dtype"],
            low=-10,
            high=100,
        )

    def build_paddle_program(self, target):
        """Compute the reference result with ``paddle.logical_xor``.

        This mirrors the ``axis`` attribute passed to the CINN builder.
        ``y`` is unsqueezed so that its dimensions occupy positions
        ``[axis, axis + rank(y))`` of ``x``'s rank. A negative ``axis`` is
        treated as ``rank(x) - rank(y)``, i.e. trailing alignment.
        """
        x = paddle.to_tensor(self.x_np, stop_gradient=False)
        y = paddle.to_tensor(self.y_np, stop_gradient=False)

        def get_unsqueeze_axis(x_rank, y_rank, axis):
            # Return the result dimensions that y does not cover: everything
            # before `axis` and everything after `axis + y_rank`.
            self.assertGreaterEqual(
                x_rank,
                y_rank,
                "The rank of x should be greater or equal to that of y.",
            )
            axis = axis if axis >= 0 else x_rank - y_rank
            return list(range(0, axis)) + list(range(axis + y_rank, x_rank))

        unsqueeze_axis = get_unsqueeze_axis(
            len(x.shape), len(y.shape), self.case["axis"]
        )
        # paddle.unsqueeze is only called when there is something to insert.
        y_t = (
            paddle.unsqueeze(y, axis=unsqueeze_axis)
            if len(unsqueeze_axis) > 0
            else y
        )
        out = paddle.logical_xor(x, y_t)

        self.paddle_outputs = [out]

    def build_cinn_program(self, target):
        """Build and run the CINN program applying ``logical_xor`` to x and y."""
        # The program is named after the op under test.
        builder = NetBuilder("logical_xor")
        x = builder.create_input(
            self.nptype2cinntype(self.case["x_dtype"]),
            self.case["x_shape"],
            "x",
        )
        y = builder.create_input(
            self.nptype2cinntype(self.case["y_dtype"]),
            self.case["y_shape"],
            "y",
        )
        out = builder.logical_xor(x, y, axis=self.case["axis"])

        prog = builder.build()
        self.cinn_outputs = self.get_cinn_output(
            prog, target, [x, y], [self.x_np, self.y_np], [out]
        )

    def test_check_results(self):
        """Compare CINN and Paddle outputs within the configured tolerance."""
        max_relative_error = self.case.get("max_relative_error", 1e-5)
        self.check_outputs_and_grads(max_relative_error=max_relative_error)


class TestLogicalXorCase1(TestCaseHelper):
    """One same-shape 2-D input pair swept across bool, int and float dtypes."""

    def init_attrs(self):
        self.class_name = "TestLogicalXorCase1"
        self.cls = TestLogicalXorOp
        self.inputs = [{"x_shape": [512, 256], "y_shape": [512, 256]}]
        # x and y always share the same dtype in this case.
        dtype_names = (
            "bool",
            "int8",
            "int16",
            "int32",
            "int64",
            "float32",
            "float64",
        )
        self.dtypes = [
            {"x_dtype": name, "y_dtype": name} for name in dtype_names
        ]
        self.attrs = [{"axis": -1}]


class TestLogicalXorCase2(TestCaseHelper):
    """Bool inputs of identical shape, ranks 1 through 5."""

    def init_attrs(self):
        self.class_name = "TestLogicalXorCase2"
        self.cls = TestLogicalXorOp
        # x and y share each shape. Separate list copies are built for x and y
        # so the two entries never alias each other.
        same_shapes = (
            [1],
            [1024],
            [512, 256],
            [128, 64, 32],
            [128, 2048, 32],
            [16, 8, 4, 2],
            [1, 1, 1, 1],
            [16, 8, 4, 2, 1],
        )
        self.inputs = [
            {"x_shape": list(shape), "y_shape": list(shape)}
            for shape in same_shapes
        ]
        self.dtypes = [{"x_dtype": "bool", "y_dtype": "bool"}]
        self.attrs = [{"axis": -1}]


class TestLogicalXorCaseWithBroadcast1(TestCaseHelper):
    """Broadcast a size-1 ``y`` over a 1-D ``x`` for bool, int and float dtypes."""

    def init_attrs(self):
        self.class_name = "TestLogicalXorCaseWithBroadcast1"
        self.cls = TestLogicalXorOp
        self.inputs = [{"x_shape": [56], "y_shape": [1]}]
        # x and y always share the same dtype in this case.
        dtype_names = (
            "bool",
            "int8",
            "int16",
            "int32",
            "int64",
            "float32",
            "float64",
        )
        self.dtypes = [
            {"x_dtype": name, "y_dtype": name} for name in dtype_names
        ]
        self.attrs = [{"axis": -1}]


class TestLogicalXorCaseWithBroadcast2(TestCaseHelper):
    """Bool inputs whose shapes differ by size-1 dimensions on x or y."""

    def init_attrs(self):
        self.class_name = "TestLogicalXorCaseWithBroadcast2"
        self.cls = TestLogicalXorOp
        # Each entry is an (x_shape, y_shape) pair.
        shape_pairs = (
            ([56], [1]),
            ([1024], [1]),
            ([512, 256], [512, 1]),
            ([128, 64, 32], [128, 64, 1]),
            ([16, 1, 1, 2], [16, 8, 4, 2]),
            ([16, 1, 1, 2, 1], [16, 8, 4, 2, 1]),
        )
        self.inputs = [
            {"x_shape": x_shape, "y_shape": y_shape}
            for x_shape, y_shape in shape_pairs
        ]
        self.dtypes = [{"x_dtype": "bool", "y_dtype": "bool"}]
        self.attrs = [{"axis": -1}]


if __name__ == "__main__":
    # Run every case helper in this file, in declaration order.
    for helper_cls in (
        TestLogicalXorCase1,
        TestLogicalXorCase2,
        TestLogicalXorCaseWithBroadcast1,
        TestLogicalXorCaseWithBroadcast2,
    ):
        helper_cls().run()