From ac6969412c07a32e7ab5e6a48f66844f4e7b9d43 Mon Sep 17 00:00:00 2001 From: baoachun <962571062@qq.com> Date: Mon, 20 Dec 2021 21:11:49 +0800 Subject: [PATCH] add mkldnn conv_transpose_bias fuse pass ut (#37508) * add mkldnn conv_transpose_bias fuse pass ut * update conv_transpose_bias_mkldnn_fuse_pass ut * update conv_transpose_bias_mkldnn_fuse_pass ut * update conv_transpose_bias_mkldnn_fuse_pass ut * restrict conv2d data_format in conv_transpose_bias_mkldnn_fuse_pass * update ut timeout setting * update ut --- .../ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc | 16 ++- .../unittests/ir/inference/CMakeLists.txt | 1 + ...st_mkldnn_conv_transpose_bias_fuse_pass.py | 117 ++++++++++++++++++ 3 files changed, 133 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 5a16cda14eb..9f70c829e1f 100755 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -115,7 +115,21 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { .IsStringIn({"EXPLICIT", "SAME", "VALID"}) .End() .AddAttr("data_format") - .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .IsStringIn({"NCHW"}) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({1}) .End(); } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index e2002765b1f..56bc97dea43 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -95,5 +95,6 @@ if (WITH_MKLDNN) set_tests_properties(test_conv_act_mkldnn_fuse_pass PROPERTIES TIMEOUT 120) set_tests_properties(test_conv_transpose_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT 250) set_tests_properties(test_conv_transpose_bn_fuse_pass PROPERTIES TIMEOUT 300) + set_tests_properties(test_mkldnn_conv_transpose_bias_fuse_pass PROPERTIES TIMEOUT 100) endif() endif() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py new file mode 100644 index 00000000000..5df7cb8d8ce --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py @@ -0,0 +1,117 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume +import hypothesis.strategies as st + + +class TestConvTransposeMkldnnFusePass(PassAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + # If the problem has been fixed, the judgment + # needs to be deleted!!! + if attrs[0]['data_format'] == "NHWC": + return False + + return True + + def sample_program_config(self, draw): + data_format = draw(st.sampled_from(["NCHW", "NHWC"])) + dilations = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]])) + padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"])) + groups = draw(st.sampled_from([1, 2, 4, 8])) + paddings = draw(st.sampled_from([[0, 3], [1, 2, 3, 4]])) + strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]])) + axis = draw(st.sampled_from([1])) + batch_size = draw(st.integers(min_value=1, max_value=4)) + + def generate_input(): + if data_format == "NCHW": + return np.random.random( + [batch_size, 16, 64, 64]).astype(np.float32) + else: + return np.random.random( + [batch_size, 64, 64, 16]).astype(np.float32) + + def generate_weight1(): + return np.random.random([16, 16, 3, 3]).astype(np.float32) + + def generate_weight2(): + return np.random.random([16 * groups]).astype(np.float32) + + conv2d_op = OpConfig( + type="conv2d_transpose", + inputs={"Input": ["input_data"], + "Filter": ["conv2d_weight"]}, + outputs={"Output": ["conv_output"]}, + attrs={ + "data_format": data_format, + "dilations": dilations, + "padding_algorithm": padding_algorithm, + "groups": groups, + "paddings": paddings, + "strides": strides, + "output_size": [], + "output_padding": [], + "is_test": True + }) + + elt_op = OpConfig( + type="elementwise_add", + inputs={"X": ["conv_output"], + "Y": ["elementwise_weight"]}, + outputs={"Out": ["elementwise_output"]}, + attrs={'axis': axis}) + + model_net = [conv2d_op, elt_op] + + program_config = ProgramConfig( + ops=model_net, + weights={ + "conv2d_weight": + TensorConfig(data_gen=partial(generate_weight1)), + "elementwise_weight": + TensorConfig(data_gen=partial(generate_weight2)) + }, + inputs={ + "input_data": TensorConfig(data_gen=partial(generate_input)) + }, + outputs=["elementwise_output"]) + + return program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_mkldnn=True) + yield config, ['conv2d_transpose'], (1e-5, 1e-5) + + def test(self): + self.run_and_statis( + quant=False, passes=["conv_transpose_bias_mkldnn_fuse_pass"]) + + +if __name__ == "__main__": + unittest.main() -- GitLab