From a858326ac310f68e22c08e4c1ddbeacc3e5006a0 Mon Sep 17 00:00:00 2001
From: baoachun <962571062@qq.com>
Date: Fri, 24 Dec 2021 10:41:11 +0800
Subject: [PATCH] add conv+hard_sigmoid and conv+hard_swish fuse pass ut
 (#37553)

* add conv+hard_sigmoid fuse pass ut

* update conv_elementwise_add_mkldnn_fuse_pass ut

* update conv_hard_sigmoid_mkldnn_fuse_pass ut

* update conv+hard_sigmoid and conv+hard_swish fuse pass ut

* update ut

* update ut
---
 .../conv_activation_mkldnn_fuse_pass.cc       |   2 +-
 .../unittests/ir/inference/CMakeLists.txt     |   2 +
 ...test_mkldnn_conv_hard_sigmoid_fuse_pass.py | 119 +++++++++++++++++
 .../test_mkldnn_conv_hard_swish_fuse_pass.py  | 121 ++++++++++++++++++
 4 files changed, 243 insertions(+), 1 deletion(-)
 create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py
 create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py

diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
index 8255a40a2c..0d0151fb73 100755
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -157,7 +157,7 @@ ConvActivationFusePass::ConvActivationFusePass() {
       // IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute
       .AddAttr("data_format")
       .IsOptional()
-      .IsStringIn({"NHWC", "NCHW", "AnyLayout"})
+      .IsStringIn({"NCHW", "AnyLayout"})
       .End();
 
   AddOpCompat(OpCompat("relu"))
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
index 055cb8ff91..e69328f5fc 100755
--- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
@@ -98,6 +98,8 @@ if (WITH_MKLDNN)
     set_tests_properties(test_conv_act_mkldnn_fuse_pass PROPERTIES TIMEOUT 120)
     set_tests_properties(test_conv_transpose_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT 250)
     set_tests_properties(test_conv_transpose_bn_fuse_pass PROPERTIES TIMEOUT 300)
+    set_tests_properties(test_mkldnn_conv_hard_sigmoid_fuse_pass PROPERTIES TIMEOUT 300)
+    set_tests_properties(test_mkldnn_conv_hard_swish_fuse_pass PROPERTIES TIMEOUT 300)
     set_tests_properties(test_mkldnn_batch_norm_act_fuse_pass PROPERTIES TIMEOUT 100)
     set_tests_properties(test_mkldnn_conv_transpose_bias_fuse_pass PROPERTIES TIMEOUT 100)
     set_tests_properties(test_conv_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT 300)
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py
new file mode 100644
index 0000000000..a0c4e18393
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from auto_scan_test import PassAutoScanTest, SkipReasons
+from program_config import TensorConfig, ProgramConfig
+import numpy as np
+import paddle.inference as paddle_infer
+from functools import partial
+from typing import Optional, List, Callable, Dict, Any, Set
+import unittest
+
+import hypothesis
+from hypothesis import given, settings, seed, example, assume
+import hypothesis.strategies as st
+
+
+class TestConvHardSigmoidMkldnnFusePass(PassAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        attrs = [
+            program_config.ops[i].attrs
+            for i in range(len(program_config.ops))
+        ]
+        # If the NHWC problem has been fixed, this check
+        # should be removed.
+        if attrs[0]['data_format'] == "NHWC":
+            return False
+
+        return True
+
+    def sample_program_config(self, draw):
+        data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
+        dilations = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
+        padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
+        groups = draw(st.sampled_from([1, 2, 4]))
+        paddings = draw(st.sampled_from([[0, 3], [1, 2, 3, 4]]))
+        strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
+        slope = draw(st.floats(min_value=0, max_value=10))
+        offset = draw(st.floats(min_value=0, max_value=10))
+        batch_size = draw(st.integers(min_value=1, max_value=4))
+
+        def generate_input():
+            if data_format == "NCHW":
+                return np.random.random(
+                    [batch_size, 48, 64, 64]).astype(np.float32)
+            else:
+                return np.random.random(
+                    [batch_size, 64, 64, 48]).astype(np.float32)
+
+        def generate_weight():
+            return np.random.random(
+                [16, int(48 / groups), 3, 3]).astype(np.float32)
+
+        ops_config = [{
+            "op_type": "conv2d",
+            "op_inputs": {
+                "Input": ["input_data"],
+                "Filter": ["input_weight"]
+            },
+            "op_outputs": {
+                "Output": ["conv_output"]
+            },
+            "op_attrs": {
+                "data_format": data_format,
+                "dilations": dilations,
+                "padding_algorithm": padding_algorithm,
+                "groups": groups,
+                "paddings": paddings,
+                "strides": strides
+            }
+        }, {
+            "op_type": "hard_sigmoid",
+            "op_inputs": {
+                "X": ["conv_output"]
+            },
+            "op_outputs": {
+                "Out": ["sigmoid_output"]
+            },
+            "op_attrs": {
+                "slope": slope,
+                "offset": offset
+            },
+        }]
+
+        ops = self.generate_op_config(ops_config)
+
+        program_config = ProgramConfig(
+            ops=ops,
+            weights={
+                "input_weight": TensorConfig(data_gen=partial(generate_weight))
+            },
+            inputs={
+                "input_data": TensorConfig(data_gen=partial(generate_input)),
+            },
+            outputs=["sigmoid_output"])
+
+        return program_config
+
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_mkldnn=True)
+        yield config, ["conv2d"], (1e-5, 1e-5)
+
+    def test(self):
+        self.run_and_statis(
+            quant=False, passes=["conv_hard_sigmoid_mkldnn_fuse_pass"])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py
new file mode 100644
index 0000000000..17bfb625fd
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from auto_scan_test import PassAutoScanTest, SkipReasons
+from program_config import TensorConfig, ProgramConfig
+import numpy as np
+import paddle.inference as paddle_infer
+from functools import partial
+from typing import Optional, List, Callable, Dict, Any, Set
+import unittest
+
+import hypothesis
+from hypothesis import given, settings, seed, example, assume
+import hypothesis.strategies as st
+
+
+class TestConvHardSwishMkldnnFusePass(PassAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        attrs = [
+            program_config.ops[i].attrs
+            for i in range(len(program_config.ops))
+        ]
+        # If the NHWC problem has been fixed, this check
+        # should be removed.
+        if attrs[0]['data_format'] == "NHWC":
+            return False
+
+        return True
+
+    def sample_program_config(self, draw):
+        data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
+        dilations = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
+        padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
+        groups = draw(st.sampled_from([1, 2, 4]))
+        paddings = draw(st.sampled_from([[0, 3], [1, 2, 3, 4]]))
+        strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
+        threshold = draw(st.sampled_from([6.0]))
+        scale = draw(st.sampled_from([6.0]))
+        offset = draw(st.sampled_from([3.0]))
+        batch_size = draw(st.integers(min_value=1, max_value=4))
+
+        def generate_input():
+            if data_format == "NCHW":
+                return np.random.random(
+                    [batch_size, 48, 64, 64]).astype(np.float32)
+            else:
+                return np.random.random(
+                    [batch_size, 64, 64, 48]).astype(np.float32)
+
+        def generate_weight():
+            return np.random.random(
+                [16, int(48 / groups), 3, 3]).astype(np.float32)
+
+        ops_config = [{
+            "op_type": "conv2d",
+            "op_inputs": {
+                "Input": ["input_data"],
+                "Filter": ["input_weight"]
+            },
+            "op_outputs": {
+                "Output": ["conv_output"]
+            },
+            "op_attrs": {
+                "data_format": data_format,
+                "dilations": dilations,
+                "padding_algorithm": padding_algorithm,
+                "groups": groups,
+                "paddings": paddings,
+                "strides": strides
+            }
+        }, {
+            "op_type": "hard_swish",
+            "op_inputs": {
+                "X": ["conv_output"]
+            },
+            "op_outputs": {
+                "Out": ["swish_output"]
+            },
+            "op_attrs": {
+                "threshold": threshold,
+                "scale": scale,
+                "offset": offset
+            },
+        }]
+
+        ops = self.generate_op_config(ops_config)
+
+        program_config = ProgramConfig(
+            ops=ops,
+            weights={
+                "input_weight": TensorConfig(data_gen=partial(generate_weight))
+            },
+            inputs={
+                "input_data": TensorConfig(data_gen=partial(generate_input)),
+            },
+            outputs=["swish_output"])
+
+        return program_config
+
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_mkldnn=True)
+        yield config, ["conv2d"], (1e-5, 1e-5)
+
+    def test(self):
+        self.run_and_statis(
+            quant=False, passes=["conv_hard_swish_mkldnn_fuse_pass"])
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
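
Editor's note on the activation math exercised by the two new tests (not part of the patch): both tests expect conv2d followed by hard_sigmoid or hard_swish to collapse into a single oneDNN conv2d op, so the fused kernel must reproduce hard_sigmoid(x) = clip(slope * x + offset, 0, 1) and hard_swish(x) = x * min(max(0, x + offset), threshold) / scale. The minimal NumPy sketch below only illustrates those formulas; the helper names hard_sigmoid_ref and hard_swish_ref are made up for illustration and are not Paddle APIs, and the default threshold/scale/offset values mirror the ones drawn in the hard_swish test.

import numpy as np


def hard_sigmoid_ref(x, slope, offset):
    # clip(slope * x + offset, 0, 1), matching the slope/offset attrs
    # sampled in test_mkldnn_conv_hard_sigmoid_fuse_pass.py
    return np.clip(slope * x + offset, 0.0, 1.0)


def hard_swish_ref(x, threshold=6.0, scale=6.0, offset=3.0):
    # x * min(max(0, x + offset), threshold) / scale, matching the attrs
    # sampled in test_mkldnn_conv_hard_swish_fuse_pass.py
    return x * np.minimum(np.maximum(x + offset, 0.0), threshold) / scale


if __name__ == "__main__":
    x = np.linspace(-8.0, 8.0, 9).astype(np.float32)
    print(hard_sigmoid_ref(x, slope=0.2, offset=0.5))
    print(hard_swish_ref(x))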