#!/usr/bin/env python3

# Copyright (c) 2021 CINN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

from cinn import framework
from test_utils import SingleOpTester

import paddle
from paddle import static

# Tolerate more than one Intel OpenMP runtime being loaded in this process.
# NOTE(review): this is set after the cinn/paddle imports above; if the
# duplicate-library check fires at library load time it must move above
# them - confirm.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def matmul_util(inputs_data, input_shape, trans_a, trans_b, alpha):
    """Compute a Paddle reference result for ``alpha * matmul(x, y)``.

    Builds a fresh static-graph program, runs it once on the CPU and returns
    the fetched result.

    Args:
        inputs_data: Pair ``[input_x, input_y]`` fed to the ``x`` and ``y``
            placeholders (presumably numpy arrays matching ``input_shape``
            - TODO confirm against SingleOpTester).
        input_shape: Pair ``[shape_x, shape_y]`` used to declare ``x``/``y``.
        trans_a: Whether ``x`` is transposed before the multiply.
        trans_b: Whether ``y`` is transposed before the multiply.
        alpha: Scalar factor applied to the matmul result.

    Returns:
        The single fetched output of the scaled matmul.
    """
    input_x, input_y = inputs_data

    # Static mode must be on before the graph below is built.
    paddle.enable_static()
    main_program = static.Program()
    startup_program = static.Program()
    with static.program_guard(main_program, startup_program):
        x = static.data(name='x', shape=input_shape[0], dtype='float32')
        y = static.data(name='y', shape=input_shape[1], dtype='float32')
        output = paddle.matmul(
            x, y, transpose_x=trans_a, transpose_y=trans_b
        )
        output = paddle.scale(output, scale=alpha)

    # Run the programs that were built above explicitly, rather than relying
    # on the "default" programs only while inside the guard.
    exe = static.Executor(paddle.CPUPlace())
    exe.run(startup_program)
    (res,) = exe.run(
        main_program,
        feed={'x': input_x, 'y': input_y},
        fetch_list=[output],
    )
    return res


class OpTest_matmul_0(SingleOpTester):
    """CINN ``matmul`` case: x[100, 32] @ y[32, 100], no transposes, alpha 1.0."""

    def init_testcase(self):
        """Set the shapes and matmul attributes exercised by this case."""
        self.input_shape = [[100, 32], [32, 100]]
        # The first entry is the shape of the matmul result. NOTE(review): the
        # second entry is consumed by SingleOpTester/CINN and its meaning is
        # not visible here - confirm.
        self.output_shape = [[100, 100], [100, 100]]
        self.trans_a = False
        self.trans_b = False
        self.alpha = 1.0
        # Hand the CINN op the same parameters that create_target_data
        # passes to the Paddle reference.
        self.attrs = framework.NodeAttr()
        self.attrs.set_attr("trans_a", self.trans_a)
        self.attrs.set_attr("trans_b", self.trans_b)
        self.attrs.set_attr("alpha", self.alpha)

    def create_target_data(self, inputs_data, attrs):
        """Return the Paddle reference result for ``inputs_data``.

        ``attrs`` is unused; the parameters come from ``init_testcase``.
        """
        return matmul_util(
            inputs_data,
            self.input_shape,
            self.trans_a,
            self.trans_b,
            self.alpha,
        )

    def test_op(self):
        """Run the CINN "matmul" op through ``SingleOpTester.to_test_op``."""
        self.init_testcase()
        # NOTE(review): the trailing 0 is interpreted by to_test_op (not
        # shown here) - confirm its meaning.
        self.to_test_op(
            self.input_shape, self.output_shape, "matmul", self.attrs, 0
        )


class OpTest_matmul_1(SingleOpTester):
    """CINN ``matmul`` case: x[100, 32] @ y[100, 32]^T, alpha 2.0."""

    def init_testcase(self):
        """Set the shapes and matmul attributes exercised by this case."""
        self.input_shape = [[100, 32], [100, 32]]
        # The first entry is the shape of the matmul result. NOTE(review): the
        # second entry is consumed by SingleOpTester/CINN and its meaning is
        # not visible here - confirm.
        self.output_shape = [[100, 100], [2, 32, 50]]
        self.trans_a = False
        self.trans_b = True
        self.alpha = 2.0
        # Hand the CINN op the same parameters that create_target_data
        # passes to the Paddle reference.
        self.attrs = framework.NodeAttr()
        self.attrs.set_attr("trans_a", self.trans_a)
        self.attrs.set_attr("trans_b", self.trans_b)
        self.attrs.set_attr("alpha", self.alpha)

    def create_target_data(self, inputs_data, attrs):
        """Return the Paddle reference result for ``inputs_data``.

        ``attrs`` is unused; the parameters come from ``init_testcase``.
        """
        return matmul_util(
            inputs_data,
            self.input_shape,
            self.trans_a,
            self.trans_b,
            self.alpha,
        )

    def test_op(self):
        """Run the CINN "matmul" op through ``SingleOpTester.to_test_op``."""
        self.init_testcase()
        # NOTE(review): the trailing 0 is interpreted by to_test_op (not
        # shown here) - confirm its meaning.
        self.to_test_op(
            self.input_shape, self.output_shape, "matmul", self.attrs, 0
        )


class OpTest_matmul_2(SingleOpTester):
    """CINN ``matmul`` case: batched x[2, 3, 100, 32] @ y[2, 3, 100, 32]^T, alpha 2.0."""

    def init_testcase(self):
        """Set the shapes and matmul attributes exercised by this case."""
        self.input_shape = [[2, 3, 100, 32], [2, 3, 100, 32]]
        # The first entry is the shape of the matmul result. NOTE(review): the
        # second entry is consumed by SingleOpTester/CINN and its meaning is
        # not visible here - confirm.
        self.output_shape = [[2, 3, 100, 100], [2, 3, 2, 100, 16]]
        self.trans_a = False
        self.trans_b = True
        self.alpha = 2.0
        # Hand the CINN op the same parameters that create_target_data
        # passes to the Paddle reference.
        self.attrs = framework.NodeAttr()
        self.attrs.set_attr("trans_a", self.trans_a)
        self.attrs.set_attr("trans_b", self.trans_b)
        self.attrs.set_attr("alpha", self.alpha)

    def create_target_data(self, inputs_data, attrs):
        """Return the Paddle reference result for ``inputs_data``.

        ``attrs`` is unused; the parameters come from ``init_testcase``.
        """
        return matmul_util(
            inputs_data,
            self.input_shape,
            self.trans_a,
            self.trans_b,
            self.alpha,
        )

    def test_op(self):
        """Run the CINN "matmul" op through ``SingleOpTester.to_test_op``."""
        self.init_testcase()
        # NOTE(review): the trailing 0 is interpreted by to_test_op (not
        # shown here) - confirm its meaning.
        self.to_test_op(
            self.input_shape, self.output_shape, "matmul", self.attrs, 0
        )


class OpTest_matmul_3(SingleOpTester):
    """CINN ``matmul`` case: x[32, 100]^T @ y[32, 100], alpha 2.0."""

    def init_testcase(self):
        """Set the shapes and matmul attributes exercised by this case."""
        self.input_shape = [[32, 100], [32, 100]]
        # The first entry is the shape of the matmul result. NOTE(review): the
        # second entry is consumed by SingleOpTester/CINN and its meaning is
        # not visible here - confirm.
        self.output_shape = [[100, 100], [2, 100, 16]]
        self.trans_a = True
        self.trans_b = False
        self.alpha = 2.0
        # Hand the CINN op the same parameters that create_target_data
        # passes to the Paddle reference.
        self.attrs = framework.NodeAttr()
        self.attrs.set_attr("trans_a", self.trans_a)
        self.attrs.set_attr("trans_b", self.trans_b)
        self.attrs.set_attr("alpha", self.alpha)

    def create_target_data(self, inputs_data, attrs):
        """Return the Paddle reference result for ``inputs_data``.

        ``attrs`` is unused; the parameters come from ``init_testcase``.
        """
        return matmul_util(
            inputs_data,
            self.input_shape,
            self.trans_a,
            self.trans_b,
            self.alpha,
        )

    def test_op(self):
        """Run the CINN "matmul" op through ``SingleOpTester.to_test_op``."""
        self.init_testcase()
        # NOTE(review): the trailing 0 is interpreted by to_test_op (not
        # shown here) - confirm its meaning.
        self.to_test_op(
            self.input_shape, self.output_shape, "matmul", self.attrs, 0
        )


class OpTest_matmul_4(SingleOpTester):
    """CINN ``matmul`` case: matrix x[32, 100] @ vector y[100], alpha 2.0."""

    def init_testcase(self):
        """Set the shapes and matmul attributes exercised by this case."""
        self.input_shape = [[32, 100], [100]]
        # The first entry is the shape of the matmul result. NOTE(review): the
        # second entry is consumed by SingleOpTester/CINN and its meaning is
        # not visible here - confirm.
        self.output_shape = [[32], [2, 100, 16]]
        self.trans_a = False
        self.trans_b = False
        self.alpha = 2.0
        # Hand the CINN op the same parameters that create_target_data
        # passes to the Paddle reference.
        self.attrs = framework.NodeAttr()
        self.attrs.set_attr("trans_a", self.trans_a)
        self.attrs.set_attr("trans_b", self.trans_b)
        self.attrs.set_attr("alpha", self.alpha)

    def create_target_data(self, inputs_data, attrs):
        """Return the Paddle reference result for ``inputs_data``.

        ``attrs`` is unused; the parameters come from ``init_testcase``.
        """
        return matmul_util(
            inputs_data,
            self.input_shape,
            self.trans_a,
            self.trans_b,
            self.alpha,
        )

    def test_op(self):
        """Run the CINN "matmul" op through ``SingleOpTester.to_test_op``."""
        self.init_testcase()
        # NOTE(review): the trailing 0 is interpreted by to_test_op (not
        # shown here) - confirm its meaning.
        self.to_test_op(
            self.input_shape, self.output_shape, "matmul", self.attrs, 0
        )


class OpTest_matmul_5(SingleOpTester):
    """CINN ``matmul`` case: vector x[100] @ vector y[100], alpha 2.0."""

    def init_testcase(self):
        """Set the shapes and matmul attributes exercised by this case."""
        self.input_shape = [[100], [100]]
        # The first entry is the expected result shape of the vector-vector
        # product. NOTE(review): the second entry is consumed by
        # SingleOpTester/CINN and its meaning is not visible here - confirm.
        self.output_shape = [[1], [1, 100, 1]]
        self.trans_a = False
        self.trans_b = False
        self.alpha = 2.0
        # Hand the CINN op the same parameters that create_target_data
        # passes to the Paddle reference.
        self.attrs = framework.NodeAttr()
        self.attrs.set_attr("trans_a", self.trans_a)
        self.attrs.set_attr("trans_b", self.trans_b)
        self.attrs.set_attr("alpha", self.alpha)

    def create_target_data(self, inputs_data, attrs):
        """Return the Paddle reference result for ``inputs_data``.

        ``attrs`` is unused; the parameters come from ``init_testcase``.
        """
        return matmul_util(
            inputs_data,
            self.input_shape,
            self.trans_a,
            self.trans_b,
            self.alpha,
        )

    def test_op(self):
        """Run the CINN "matmul" op through ``SingleOpTester.to_test_op``."""
        self.init_testcase()
        # NOTE(review): the trailing 0 is interpreted by to_test_op (not
        # shown here) - confirm its meaning.
        self.to_test_op(
            self.input_shape, self.output_shape, "matmul", self.attrs, 0
        )


class OpTest_matmul_6(SingleOpTester):
    """CINN ``matmul`` case: x[32, 1] @ y[1, 100] (inner dimension 1), alpha 2.0."""

    def init_testcase(self):
        """Set the shapes and matmul attributes exercised by this case."""
        self.input_shape = [[32, 1], [1, 100]]
        # The first entry is the shape of the matmul result. NOTE(review): the
        # second entry is consumed by SingleOpTester/CINN and its meaning is
        # not visible here - confirm.
        self.output_shape = [[32, 100], [2, 1, 50]]
        self.trans_a = False
        self.trans_b = False
        self.alpha = 2.0
        # Hand the CINN op the same parameters that create_target_data
        # passes to the Paddle reference.
        self.attrs = framework.NodeAttr()
        self.attrs.set_attr("trans_a", self.trans_a)
        self.attrs.set_attr("trans_b", self.trans_b)
        self.attrs.set_attr("alpha", self.alpha)

    def create_target_data(self, inputs_data, attrs):
        """Return the Paddle reference result for ``inputs_data``.

        ``attrs`` is unused; the parameters come from ``init_testcase``.
        """
        return matmul_util(
            inputs_data,
            self.input_shape,
            self.trans_a,
            self.trans_b,
            self.alpha,
        )

    def test_op(self):
        """Run the CINN "matmul" op through ``SingleOpTester.to_test_op``."""
        self.init_testcase()
        # NOTE(review): the trailing 0 is interpreted by to_test_op (not
        # shown here) - confirm its meaning.
        self.to_test_op(
            self.input_shape, self.output_shape, "matmul", self.attrs, 0
        )


if __name__ == "__main__":
    # Hand this script's test cases to unittest's command-line runner;
    # "__main__" is the runner's default target module, spelled out here.
    unittest.main(module="__main__")