# test_mkldnn_elt_act_fuse_pass.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle
import paddle.nn.functional as F
from paddle import fluid
from paddle.fluid.core import PassVersionChecker


class ElementwiseActivationOneDNNFusePassTest(InferencePassTest):
    """Base case for ``elementwise_act_onednn_fuse_pass``.

    ``setUp`` builds a static program computing
    ``act(operand(data_A, data_B))`` on two float32 inputs declared with
    shape ``[-1, 3, 100, 100]``, feeds random ``(1, 3, 100, 100)`` arrays
    and turns MKLDNN on. Subclasses override ``set_params`` to choose the
    elementwise op (``self.operand``), the activation (``self.act``, or
    ``None`` for no activation) and optional extra positional arguments
    for the activation (``act_alpha`` / ``act_beta``).
    """

    # Extra positional arguments for ``self.act``; ``None`` means "omit".
    # ``act_beta`` is only ever forwarded together with ``act_alpha``.
    act_alpha = None
    act_beta = None
    pass_name = 'elementwise_act_onednn_fuse_pass'

    def setUp(self):
        self.set_params()
        input_names = ("data_A", "data_B")
        with fluid.program_guard(self.main_program, self.startup_program):
            inputs = [
                paddle.static.data(
                    name=name, shape=[-1, 3, 100, 100], dtype="float32"
                )
                for name in input_names
            ]
            result = self.operand(*inputs)
            if self.act is not None:
                # Forward (alpha, beta), (alpha,) or nothing, depending on
                # which extra arguments the subclass configured.
                if self.act_beta is not None:
                    extra_args = (self.act_alpha, self.act_beta)
                elif self.act_alpha is not None:
                    extra_args = (self.act_alpha,)
                else:
                    extra_args = ()
                result = self.act(result, *extra_args)

        # Random inputs, drawn in the same order the inputs are declared.
        self.feeds = {
            name: np.random.random((1, 3, 100, 100)).astype("float32")
            for name in input_names
        }
        self.fetch_list = [result]
        self.enable_mkldnn = True

    def set_params(self):
        """Default configuration: plain elementwise add, no activation."""
        self.operand, self.act = paddle.add, None

    def test_check_output(self):
        """Run the program with ``use_gpu=False`` and check its output."""
        self.check_output_with_option(False)

    def test_pass_compatible(self):
        """The fuse pass must report itself as version-compatible."""
        is_compatible = PassVersionChecker.IsCompatible(self.pass_name)
        self.assertTrue(is_compatible)


class ElementwiseActivationOneDNNFusePassTest_Add_Relu(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise add followed by relu."""

    def set_params(self):
        self.operand, self.act = paddle.add, F.relu


class ElementwiseActivationOneDNNFusePassTest_Add_Tanh(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise add followed by tanh."""

    def set_params(self):
        self.operand, self.act = paddle.add, paddle.tanh


class ElementwiseActivationOneDNNFusePassTest_Add_LeakyRelu(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise add followed by leaky_relu.

    ``act_alpha`` (0.2) is forwarded by the base ``setUp`` as the
    activation's second positional argument.
    """

    def set_params(self):
        self.operand, self.act = paddle.add, F.leaky_relu
        self.act_alpha = 0.2


class ElementwiseActivationOneDNNFusePassTest_Add_Swish(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise add followed by swish."""

    def set_params(self):
        self.operand, self.act = paddle.add, F.swish


class ElementwiseActivationOneDNNFusePassTest_Add_HardSwish(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise add followed by hardswish."""

    def set_params(self):
        self.operand, self.act = paddle.add, F.hardswish


class ElementwiseActivationOneDNNFusePassTest_Add_SQRT(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise add followed by sqrt."""

    def set_params(self):
        self.operand, self.act = paddle.add, paddle.sqrt


class ElementwiseActivationOneDNNFusePassTest_Add_ABS(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise add followed by abs."""

    def set_params(self):
        self.operand, self.act = paddle.add, paddle.abs


class ElementwiseActivationOneDNNFusePassTest_Add_Clip(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise add followed by clip.

    ``act_alpha`` / ``act_beta`` (0.0 / 10.0) are forwarded by the base
    ``setUp`` as the two extra positional arguments of ``paddle.clip``.
    """

    def set_params(self):
        self.operand, self.act = paddle.add, paddle.clip
        self.act_alpha, self.act_beta = 0.0, 10.0


class ElementwiseActivationOneDNNFusePassTest_Add_Gelu(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise add followed by gelu with default arguments."""

    def set_params(self):
        self.operand, self.act = paddle.add, F.gelu


class ElementwiseActivationOneDNNFusePassTest_Add_Gelu_Tanh(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise add followed by gelu called with ``True`` as its second
    positional argument (``approximate``), forwarded via ``act_alpha``.
    """

    def set_params(self):
        self.operand, self.act = paddle.add, F.gelu
        self.act_alpha = True


class ElementwiseActivationOneDNNFusePassTest_Add_Relu6(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise add followed by relu6."""

    def set_params(self):
        self.operand, self.act = paddle.add, F.relu6


class ElementwiseActivationOneDNNFusePassTest_Add_Sigmoid(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise add followed by sigmoid."""

    def set_params(self):
        self.operand, self.act = paddle.add, F.sigmoid


class ElementwiseActivationOneDNNFusePassTest_Sub_Relu(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise subtract followed by relu."""

    def set_params(self):
        self.operand, self.act = paddle.subtract, F.relu


class ElementwiseActivationOneDNNFusePassTest_Sub_Tanh(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise subtract followed by tanh."""

    def set_params(self):
        self.operand, self.act = paddle.subtract, paddle.tanh


class ElementwiseActivationOneDNNFusePassTest_Sub_LeakyRelu(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise subtract followed by leaky_relu.

    ``act_alpha`` (0.2) is forwarded by the base ``setUp`` as the
    activation's second positional argument.
    """

    def set_params(self):
        self.operand, self.act = paddle.subtract, F.leaky_relu
        self.act_alpha = 0.2


class ElementwiseActivationOneDNNFusePassTest_Sub_Swish(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise subtract followed by swish."""

    def set_params(self):
        self.operand, self.act = paddle.subtract, F.swish


class ElementwiseActivationOneDNNFusePassTest_Sub_HardSwish(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise subtract followed by hardswish."""

    def set_params(self):
        self.operand, self.act = paddle.subtract, F.hardswish


class ElementwiseActivationOneDNNFusePassTest_Sub_ABS(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise subtract followed by abs."""

    def set_params(self):
        self.operand, self.act = paddle.subtract, paddle.abs


class ElementwiseActivationOneDNNFusePassTest_Sub_Clip(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise subtract followed by clip.

    ``act_alpha`` / ``act_beta`` (0.0 / 10.0) are forwarded by the base
    ``setUp`` as the two extra positional arguments of ``paddle.clip``.
    """

    def set_params(self):
        self.operand, self.act = paddle.subtract, paddle.clip
        self.act_alpha, self.act_beta = 0.0, 10.0


class ElementwiseActivationOneDNNFusePassTest_Sub_Gelu(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise subtract followed by gelu with default arguments."""

    def set_params(self):
        self.operand, self.act = paddle.subtract, F.gelu


class ElementwiseActivationOneDNNFusePassTest_Sub_Gelu_Tanh(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise subtract followed by gelu called with ``True`` as its
    second positional argument (``approximate``), forwarded via
    ``act_alpha``.
    """

    def set_params(self):
        self.operand, self.act = paddle.subtract, F.gelu
        self.act_alpha = True


class ElementwiseActivationOneDNNFusePassTest_Sub_Relu6(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise subtract followed by relu6."""

    def set_params(self):
        self.operand, self.act = paddle.subtract, F.relu6


class ElementwiseActivationOneDNNFusePassTest_Sub_Sigmoid(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise subtract followed by sigmoid."""

    def set_params(self):
        self.operand, self.act = paddle.subtract, F.sigmoid


class ElementwiseActivationOneDNNFusePassTest_Mul_Relu(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise multiply followed by relu."""

    def set_params(self):
        self.operand, self.act = paddle.multiply, F.relu


class ElementwiseActivationOneDNNFusePassTest_Mul_Tanh(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise multiply followed by tanh."""

    def set_params(self):
        self.operand, self.act = paddle.multiply, paddle.tanh


class ElementwiseActivationOneDNNFusePassTest_Mul_LeakyRelu(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise multiply followed by leaky_relu.

    ``act_alpha`` (0.2) is forwarded by the base ``setUp`` as the
    activation's second positional argument.
    """

    def set_params(self):
        self.operand, self.act = paddle.multiply, F.leaky_relu
        self.act_alpha = 0.2


class ElementwiseActivationOneDNNFusePassTest_Mul_Swish(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise multiply followed by swish."""

    def set_params(self):
        self.operand, self.act = paddle.multiply, F.swish


class ElementwiseActivationOneDNNFusePassTest_Mul_HardSwish(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise multiply followed by hardswish."""

    def set_params(self):
        self.operand, self.act = paddle.multiply, F.hardswish


class ElementwiseActivationOneDNNFusePassTest_Mul_SQRT(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise multiply followed by sqrt."""

    def set_params(self):
        self.operand, self.act = paddle.multiply, paddle.sqrt


class ElementwiseActivationOneDNNFusePassTest_Mul_ABS(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise multiply followed by abs."""

    def set_params(self):
        self.operand, self.act = paddle.multiply, paddle.abs


class ElementwiseActivationOneDNNFusePassTest_Mul_Clip(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise multiply followed by clip.

    ``act_alpha`` / ``act_beta`` (0.0 / 10.0) are forwarded by the base
    ``setUp`` as the two extra positional arguments of ``paddle.clip``.
    """

    def set_params(self):
        self.operand, self.act = paddle.multiply, paddle.clip
        self.act_alpha, self.act_beta = 0.0, 10.0


class ElementwiseActivationOneDNNFusePassTest_Mul_Gelu(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise multiply followed by gelu with default arguments."""

    def set_params(self):
        self.operand, self.act = paddle.multiply, F.gelu


class ElementwiseActivationOneDNNFusePassTest_Mul_Gelu_Tanh(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise multiply followed by gelu called with ``True`` as its
    second positional argument (``approximate``), forwarded via
    ``act_alpha``.
    """

    def set_params(self):
        self.operand, self.act = paddle.multiply, F.gelu
        self.act_alpha = True


class ElementwiseActivationOneDNNFusePassTest_Mul_Relu6(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise multiply followed by relu6."""

    def set_params(self):
        self.operand, self.act = paddle.multiply, F.relu6


class ElementwiseActivationOneDNNFusePassTest_Mul_Sigmoid(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise multiply followed by sigmoid."""

    def set_params(self):
        self.operand, self.act = paddle.multiply, F.sigmoid


class ElementwiseScaleOneDNNFusePassTest_Add(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise add followed by ``paddle.scale``.

    ``act_alpha`` (0.6) is forwarded by the base ``setUp`` as the second
    positional argument of ``paddle.scale``.
    """

    def set_params(self):
        self.operand, self.act = paddle.add, paddle.scale
        self.act_alpha = 0.6


class ElementwiseScaleOneDNNFusePassTest_Sub(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise subtract followed by ``paddle.scale``.

    ``act_alpha`` (0.6) is forwarded by the base ``setUp`` as the second
    positional argument of ``paddle.scale``.
    """

    def set_params(self):
        self.operand, self.act = paddle.subtract, paddle.scale
        self.act_alpha = 0.6


class ElementwiseScaleOneDNNFusePassTest_Mul(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise multiply followed by ``paddle.scale``.

    ``act_alpha`` (0.6) is forwarded by the base ``setUp`` as the second
    positional argument of ``paddle.scale``.
    """

    def set_params(self):
        self.operand, self.act = paddle.multiply, paddle.scale
        self.act_alpha = 0.6


class ElementwiseScaleOneDNNFusePassTest_Div(
    ElementwiseActivationOneDNNFusePassTest
):
    """Elementwise divide followed by ``paddle.scale``.

    ``act_alpha`` (0.6) is forwarded by the base ``setUp`` as the second
    positional argument of ``paddle.scale``.
    """

    def set_params(self):
        self.operand, self.act = paddle.divide, paddle.scale
        self.act_alpha = 0.6


if __name__ == "__main__":
    # Switch Paddle to static-graph mode before unittest collects and runs
    # the test cases defined in this module.
    paddle.enable_static()
    unittest.main(module="__main__")