# test_mkldnn_int8_scale_calculation_pass.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import hypothesis.strategies as st
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig


class TestInt8ScaleCalculationMkldnnPass(PassAutoScanTest):
    """Auto-scan test for ``int8_scale_calculation_mkldnn_pass``.

    Hypothesis draws single-``conv2d`` programs flagged for oneDNN int8
    (``use_mkldnn=True``, ``mkldnn_data_type="int8"``). Each program accepted
    by ``is_program_valid`` is run with the pass appended to the predictor's
    pass builder.
    """

    def sample_predictor_configs(self, program_config):
        """Yield one CPU predictor config with the pass under test appended.

        ``["conv2d"]`` and ``(1e-4, 1e-5)`` are handed to ``PassAutoScanTest``
        (presumably the expected op list after the pass and the (atol, rtol)
        tolerances -- confirm against the base class).
        """
        config = self.create_inference_config(use_gpu=False)
        config.pass_builder().append_pass("int8_scale_calculation_mkldnn_pass")
        yield config, ["conv2d"], (1e-4, 1e-5)

    @staticmethod
    def _window_fits(in_size, pad_total, kernel, dilation, stride):
        """Return True if the padded input is larger than the dilated kernel.

        Evaluates ``(in + pad - eff_k) / stride + 1 > 1`` with
        ``eff_k = dilation * (kernel - 1) + 1``. For a positive ``stride``
        this is equivalent to ``in + pad > eff_k``.
        """
        effective_kernel = dilation * (kernel - 1) + 1
        return (in_size + pad_total - effective_kernel) / stride + 1 > 1

    def is_program_valid(self, prog_config):
        """Return False for sampled programs with inconsistent conv2d attrs.

        A program is rejected when any of the following holds:

        * the padding is VALID or EXPLICIT and the dilated kernel does not fit
          inside the (padded) input along H or W;
        * the input channel count differs from ``filter_shape[1] * groups``;
        * the output channel count ``filter_shape[0]`` is not divisible by
          ``groups``.
        """
        attrs = prog_config.ops[0].attrs
        strides = attrs["strides"]
        dilations = attrs["dilations"]
        groups = attrs["groups"]
        paddings = attrs["paddings"]
        padding_algorithm = attrs["padding_algorithm"]
        filter_shape = prog_config.weights["filter"].shape
        input_shape = prog_config.inputs["input_x"].shape

        # The filter is indexed as [out_c, in_c / groups, k_h, k_w] for both
        # layouts; only the input's axes depend on data_format. In NHWC the
        # spatial axes are 1 and 2 (axis 3 is channels).
        if attrs["data_format"] == "NCHW":
            channels = input_shape[1]
            in_h = input_shape[2]
            in_w = input_shape[3]
        else:
            in_h = input_shape[1]
            in_w = input_shape[2]
            channels = input_shape[3]

        # The extent check is only applied to VALID (no padding) and
        # EXPLICIT padding; SAME is not checked here.
        if padding_algorithm in ("VALID", "EXPLICIT"):
            if padding_algorithm == "VALID":
                pad_h = pad_w = 0
            else:
                # paddings[0:2] pad the H axis, paddings[2:4] pad the W axis.
                pad_h = paddings[0] + paddings[1]
                pad_w = paddings[2] + paddings[3]
            fits = self._window_fits(
                in_h, pad_h, filter_shape[2], dilations[0], strides[0]
            ) and self._window_fits(
                in_w, pad_w, filter_shape[3], dilations[1], strides[1]
            )
            if not fits:
                return False

        if channels != filter_shape[1] * groups:
            return False
        return filter_shape[0] % groups == 0

    def sample_program_config(self, draw):
        """Draw a random single-``conv2d`` ``ProgramConfig``.

        The following are all sampled:

        * input and filter shapes;
        * strides, padding algorithm and padding values;
        * groups, dilations and data format;
        * an optional bias weight.

        ``is_program_valid`` filters out inconsistent combinations.
        """

        def draw_ints(low, high, size):
            # Draw a list of exactly `size` integers in [low, high].
            return draw(
                st.lists(
                    st.integers(min_value=low, max_value=high),
                    min_size=size,
                    max_size=size,
                )
            )

        x_shape = draw_ints(5, 100, 4)
        # Axis 1 is the channel count for NCHW and the height for NHWC.
        x_shape[1] = draw(st.integers(min_value=5, max_value=10))

        data_format = draw(st.sampled_from(["NCHW", "NHWC"]))

        f_shape = draw_ints(1, 4, 4)
        # Match the filter's input-channel dim to the input's channel axis.
        # NOTE(review): because this ignores `groups`, every draw with
        # groups > 1 fails the channel check in is_program_valid.
        f_shape[1] = x_shape[1] if data_format == "NCHW" else x_shape[3]

        strides = draw_ints(1, 4, 2)
        padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
        padding = draw_ints(1, 4, 4)
        groups = draw(st.integers(min_value=1, max_value=3))
        dilations = draw_ints(1, 4, 2)

        inputs = {
            "Input": ["input_x"],
            "Filter": ["filter"],
        }
        weights = {"filter": TensorConfig(shape=f_shape)}
        if draw(st.booleans()):
            # NOTE(review): the bias weight is registered but not listed in
            # conv2d's inputs, so the op does not consume it -- confirm that
            # this is intentional.
            weights["bias"] = TensorConfig(shape=[f_shape[0]])

        conv2d_op = OpConfig(
            "conv2d",
            inputs=inputs,
            outputs={"Output": ["conv2d_out"]},
            strides=strides,
            padding_algorithm=padding_algorithm,
            paddings=padding,
            groups=groups,
            dilations=dilations,
            data_format=data_format,
            use_mkldnn=True,
            mkldnn_data_type="int8",
        )

        return ProgramConfig(
            ops=[conv2d_op],
            weights=weights,
            inputs={"input_x": TensorConfig(shape=x_shape)},
            outputs=["conv2d_out"],
        )

    def test(self):
        """Run up to 100 sampled programs through the pass (no quantization)."""
        self.run_and_statis(
            quant=False,
            max_examples=100,
            passes=["int8_scale_calculation_mkldnn_pass"],
        )


# Allow running this file directly with the standard unittest runner.
if __name__ == "__main__":
    unittest.main()