# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from functools import partial

import hypothesis.strategies as st
import numpy as np
from auto_scan_test import MkldnnAutoScanTest
from hypothesis import given
from program_config import OpConfig, ProgramConfig, TensorConfig


class TestMkldnnMishOp(MkldnnAutoScanTest):
    """Auto-scan test for the ``mish`` op with oneDNN (MKLDNN) enabled.

    Hypothesis draws a ``mode`` ('all' / 'channel' / 'element'), a
    ``data_format`` ('NCHW' / 'NHWC') and an input shape of rank 1-4 with
    every dimension in [1, 32].  Each resulting single-op program is run
    through an inference config created with ``use_mkldnn=True``.
    """

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        """Filter out sampled programs that should not be executed.

        Returns False when the input is rank 1 and ``mode == 'channel'``;
        True otherwise.
        """
        # Rank-1 input with 'channel' mode is rejected, presumably because a
        # 1-D tensor has no channel axis -- TODO confirm against the kernel.
        if (
            len(program_config.inputs['input_data'].shape) == 1
            and program_config.ops[0].attrs['mode'] == 'channel'
        ):
            return False
        return True

    def sample_program_configs(self, *args, **kwargs):
        """Yield one single-op ``mish`` program built from ``kwargs``.

        ``kwargs`` must provide ``mode``, ``data_format`` and ``in_shape``
        (presumably forwarded by ``run_test`` from the ``@given`` draw on
        ``test`` -- verify in auto_scan_test).
        """

        def generate_input(*args, **kwargs):
            # Uniform float32 samples in [0, 1) with the drawn shape.
            return np.random.random(kwargs['in_shape']).astype(np.float32)

        # NOTE(review): 'mode' and 'data_format' mirror the PReLU test's
        # attributes; confirm the mish op actually consumes them.
        mish_op = OpConfig(
            type="mish",
            inputs={"X": ["input_data"]},
            outputs={"Out": ["output_data"]},
            attrs={
                "mode": kwargs['mode'],
                "data_format": kwargs['data_format'],
            },
        )

        program_config = ProgramConfig(
            ops=[mish_op],
            weights={},
            inputs={
                "input_data": TensorConfig(
                    data_gen=partial(generate_input, *args, **kwargs)
                ),
            },
            outputs=["output_data"],
        )

        yield program_config

    def sample_predictor_configs(self, program_config):
        """Yield a oneDNN-enabled inference config and its tolerances."""
        config = self.create_inference_config(use_mkldnn=True)
        # Tolerance pair; presumably (atol, rtol) -- confirm in auto_scan_test.
        yield config, (1e-5, 1e-5)

    @given(
        mode=st.sampled_from(['all', 'channel', 'element']),
        data_format=st.sampled_from(['NCHW', 'NHWC']),
        in_shape=st.lists(
            st.integers(min_value=1, max_value=32), min_size=1, max_size=4
        ),
    )
    def test(self, *args, **kwargs):
        """Run the auto-scan over the hypothesis-drawn attributes/shapes."""
        # Keyword placed after *args (flake8-bugbear B026); binding unchanged.
        self.run_test(*args, quant=False, **kwargs)


if __name__ == "__main__":
    # Executed directly (not imported): hand off to unittest's CLI runner
    # for this module; 'module="__main__"' is unittest.main's default.
    unittest.main(module="__main__")