# test_onednn_conv_bias_fuse_pass.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import hypothesis.strategies as st
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig


class TestConvBiasOneDNNFusePass(PassAutoScanTest):
    """Auto-scan test for ``conv_bias_mkldnn_fuse_pass``.

    Every sampled program is a conv op (a plain ``conv2d``, or a
    ``fused_conv2d`` that already has its own ``Bias`` input) whose output
    feeds an ``elementwise_add`` with a per-output-channel bias.  After the
    pass, the predictor config expects the optimized program's op types to
    be ``['fused_conv2d']``.
    """

    def sample_predictor_configs(self, program_config):
        """Yield the predictor configs each sampled program is run under.

        A single config is produced: CPU inference with oneDNN (MKL-DNN)
        enabled.  It is paired with the op types expected after the pass
        and a tolerance pair -- presumably ``(atol, rtol)`` as consumed by
        ``PassAutoScanTest``; confirm there.
        """
        config = self.create_inference_config(use_gpu=False)
        config.enable_mkldnn()
        yield config, ['fused_conv2d'], (1e-4, 1e-5)

    def is_program_valid(self, prog_config):
        """Return True if the sampled conv attributes form a legal program.

        A program is rejected when:
        * with 'VALID' or 'EXPLICIT' padding, the dilated kernel does not
          fit strictly inside the (padded) input height or width.  The
          check is conservative: a kernel that exactly fills the extent,
          i.e. output size 1, is rejected as well;
        * the input channel count is not ``filter_shape[1] * groups``, or
          the output channel count is not divisible by ``groups``.
        """
        conv_attrs = prog_config.ops[0].attrs
        paddings = conv_attrs['paddings']
        groups = conv_attrs['groups']
        padding_algorithm = conv_attrs['padding_algorithm']
        dilations = conv_attrs['dilations']
        data_format = conv_attrs['data_format']
        filter_shape = prog_config.weights['filter'].shape
        input_shape = prog_config.inputs['input_x'].shape

        height = input_shape[data_format.index('H')]
        width = input_shape[data_format.index('W')]
        # Effective (dilated) kernel extent along each spatial axis.
        kernel_h = dilations[0] * (filter_shape[2] - 1) + 1
        kernel_w = dilations[1] * (filter_shape[3] - 1) + 1

        # 'SAME' is not checked here.  'VALID' adds no padding; for
        # 'EXPLICIT', paddings[0:2] extend the height and paddings[2:4]
        # extend the width.
        pad_h = pad_w = 0
        if padding_algorithm == 'EXPLICIT':
            pad_h = paddings[0] + paddings[1]
            pad_w = paddings[2] + paddings[3]
        if padding_algorithm in ('VALID', 'EXPLICIT') and (
            height + pad_h <= kernel_h or width + pad_w <= kernel_w
        ):
            return False

        # Filter is laid out as [out_channels, in_channels / groups, kH, kW].
        in_channels = input_shape[data_format.index('C')]
        return (
            in_channels == filter_shape[1] * groups
            and filter_shape[0] % groups == 0
        )

    def sample_program_config(self, draw):
        """Draw a random ``conv -> elementwise_add(bias)`` program.

        Args:
            draw: Hypothesis ``draw`` callable used to sample every shape
                and attribute.

        Returns:
            A ``ProgramConfig`` whose single graph output is the add result.
        """
        # 1. Shape of conv2d input X: four dims in [5, 100]; dim 1 is then
        #    narrowed to [5, 10] (the channel dim for NCHW, H for NHWC).
        x_shape = draw(
            st.lists(
                st.integers(min_value=5, max_value=100), min_size=4, max_size=4
            )
        )
        x_shape[1] = draw(st.integers(min_value=5, max_value=10))

        # 2. Data layout of conv2d; it fixes where the channel dim lives.
        data_format = draw(st.sampled_from(['NCHW', 'NHWC']))
        channel_axis = data_format.index('C')

        # 3. Filter shape; its input-channel dim is tied to X's channels.
        f_shape = draw(
            st.lists(
                st.integers(min_value=1, max_value=4), min_size=4, max_size=4
            )
        )
        f_shape[1] = x_shape[channel_axis]

        # 4. Strides.
        strides = draw(
            st.lists(
                st.integers(min_value=1, max_value=4), min_size=2, max_size=2
            )
        )

        # 5. Padding algorithm.
        padding_algorithm = draw(st.sampled_from(['EXPLICIT', 'SAME', 'VALID']))

        # 6. Explicit paddings (only meaningful for 'EXPLICIT').
        padding = draw(
            st.lists(
                st.integers(min_value=1, max_value=4), min_size=4, max_size=4
            )
        )

        # 7. Groups.  NOTE(review): because step 3 sets f_shape[1] to the
        #    full input channel count, is_program_valid rejects every
        #    program with groups > 1.
        groups = draw(st.integers(min_value=1, max_value=3))

        # 8. Dilations.
        dilations = draw(
            st.lists(
                st.integers(min_value=1, max_value=4), min_size=2, max_size=2
            )
        )

        # 9-10. elementwise_add adds one bias value per output channel,
        #       broadcast along the conv output's channel axis.
        bias_shape = [f_shape[0]]
        axis = channel_axis

        # 11. Either a plain conv2d, or a fused_conv2d that already carries
        #     its own Bias (presumably so the pass has to fold the add into
        #     an existing conv bias).
        has_conv_bias = draw(st.booleans())
        inputs = {'Input': ['input_x'], 'Filter': ['filter']}
        weights = {
            'filter': TensorConfig(shape=f_shape),
            'bias': TensorConfig(shape=bias_shape),
        }
        if has_conv_bias:
            conv_type = 'fused_conv2d'
            inputs['Bias'] = ['conv_bias']
            weights['conv_bias'] = TensorConfig(shape=[f_shape[0]])
        else:
            conv_type = 'conv2d'

        conv2d_op = OpConfig(
            conv_type,
            inputs=inputs,
            outputs={'Output': ['conv2d_out']},
            strides=strides,
            padding_algorithm=padding_algorithm,
            paddings=padding,
            groups=groups,
            dilations=dilations,
            data_format=data_format,
            # Only the fused_conv2d variant is marked as a oneDNN op.
            use_mkldnn=has_conv_bias,
        )

        add_op = OpConfig(
            'elementwise_add',
            inputs={'X': ['conv2d_out'], 'Y': ['bias']},
            outputs={'Out': ['add_out']},
            axis=axis,
        )

        ops = [conv2d_op, add_op]

        return ProgramConfig(
            ops=ops,
            weights=weights,
            inputs={'input_x': TensorConfig(shape=x_shape)},
            outputs=ops[-1].outputs['Out'],
        )

    def test(self):
        """Run the auto-scan over ``conv_bias_mkldnn_fuse_pass``."""
        self.run_and_statis(
            quant=False, passes=['conv_bias_mkldnn_fuse_pass'], max_examples=130
        )


if __name__ == '__main__':
    # Run this file's test case(s) through the standard unittest runner.
    unittest.main()