From 96597a85cdb115fbd419a6901f4ebb155267a052 Mon Sep 17 00:00:00 2001 From: yeliang2258 <30516196+yeliang2258@users.noreply.github.com> Date: Thu, 16 Dec 2021 15:19:16 +0800 Subject: [PATCH] Add tests for PaddleInference Pass (#37676) * add test for conv_elementwise_add2_act_fuse_pass and conv_elementwise_add_act_fuse_pass * Add conv_eltwiseadd_bn_fuse_pass test and fix test_conv_elementwise_addX_act_fuse_pass * add tests for conv_act_mkldnn_fuse_pass * add test for conv_bias_mkldnn_fuse_pass * update code * add conv_act_mkldnn_fuse_pass for relu, relu6, swish, leaky_relu * update test * update * update bug * update * update pattern_detector * fix test_conv_eltwiseadd_bn_fuse_pass * add diff display notest;test=windows_ci_inference * fix * remove test_conv_act_mkldnn_fuse_pass.py * ifix --- paddle/fluid/framework/ir/CMakeLists.txt | 2 +- .../framework/ir/graph_pattern_detector.cc | 1 + .../conv_activation_mkldnn_fuse_pass.cc | 3 + .../ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc | 1 + .../unittests/ir/inference/CMakeLists.txt | 4 +- .../unittests/ir/inference/auto_scan_test.py | 4 +- .../test_conv_bias_mkldnn_fuse_pass.py | 198 +++++++++++++ ...est_conv_elementwise_add2_act_fuse_pass.py | 255 ++++++++++++++--- ...test_conv_elementwise_add_act_fuse_pass.py | 211 +++++++++++--- .../test_conv_eltwiseadd_bn_fuse_pass.py | 265 ++++++++++++++++++ tools/parallel_UT_rule.py | 4 +- tools/static_mode_white_list.py | 0 12 files changed, 865 insertions(+), 83 deletions(-) mode change 100644 => 100755 paddle/fluid/framework/ir/CMakeLists.txt mode change 100644 => 100755 paddle/fluid/framework/ir/graph_pattern_detector.cc mode change 100644 => 100755 paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc mode change 100644 => 100755 paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc mode change 100644 => 100755 python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt mode change 100644 => 100755 python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py create mode 100755 python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py mode change 100644 => 100755 python/paddle/fluid/tests/unittests/ir/inference/test_conv_elementwise_add2_act_fuse_pass.py mode change 100644 => 100755 python/paddle/fluid/tests/unittests/ir/inference/test_conv_elementwise_add_act_fuse_pass.py create mode 100755 python/paddle/fluid/tests/unittests/ir/inference/test_conv_eltwiseadd_bn_fuse_pass.py mode change 100644 => 100755 tools/parallel_UT_rule.py mode change 100644 => 100755 tools/static_mode_white_list.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt old mode 100644 new mode 100755 index da7ab44c21c..a34c2e9aa87 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -173,7 +173,7 @@ if(NOT WIN32) endif() if (WITH_MKLDNN) cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass) - cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor) + cc_test(test_conv_bias_mkldnn_fuse_pass_cc SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor) cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass) cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass) cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass pass_test_util) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc old mode 100644 new mode 100755 index b7cba781007..6a5bca7bde4 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -807,6 +807,7 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input, // Bias eltwise_y_in_var = pattern->NewNode(eltwise_y_in_repr()) ->assert_is_op_input("elementwise_add", "Y") + ->assert_is_persistable_var() ->AsInput(); eltwise_out_var = pattern->NewNode(eltwise_out_repr()) ->AsIntermediate() diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc old mode 100644 new mode 100755 index cfd40435387..8255a40a2c0 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc @@ -201,6 +201,9 @@ Conv2DSwishFusePass::Conv2DSwishFusePass() { .End() .AddOutput("Out") .IsTensor() + .End() + .AddAttr("beta") + .IsType() .End(); } Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() { diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc old mode 100644 new mode 100755 index aae1da5f0a3..5a16cda14eb --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -239,6 +239,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { auto input_names = conv->Op()->InputNames(); bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") != input_names.end(); + if (has_bias && conv->Op()->Input("Bias").size() > 0) { auto conv_bias_names = conv->Op()->Input("Bias"); // add eltwise bias to existing conv bias diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt old mode 100644 new mode 100755 index 6428ca1e4ac..25fc1f32603 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -80,7 +80,9 @@ if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU) set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 240) set_tests_properties(test_simplify_with_basic_ops_pass_autoscan PROPERTIES TIMEOUT 60) set_tests_properties(test_adaptive_pool2d_convert_global_pass_autoscan PROPERTIES TIMEOUT 60) - set_tests_properties(test_conv_eltwiseadd_affine_channel_fuse_pass PROPERTIES TIMEOUT 100) + set_tests_properties(test_conv_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT 120) + set_tests_properties(test_conv_elementwise_add2_act_fuse_pass PROPERTIES TIMEOUT 120) + set_tests_properties(test_conv_elementwise_add_act_fuse_pass PROPERTIES TIMEOUT 120) endif() if (WITH_MKLDNN) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py old mode 100644 new mode 100755 index c05ad30da27..fa09ef19977 --- a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py @@ -144,10 +144,12 @@ class AutoScanTest(unittest.TestCase): baseline[key].shape == arr.shape, "The output shapes are not equal, the baseline shape is " + str(baseline[key].shape) + ', but got ' + str(arr.shape)) + diff = abs(baseline[key] - arr) self.assertTrue( np.allclose( baseline[key], arr, atol=atol, rtol=rtol), - "Output has diff. ") + "Output has diff, Maximum absolute error: {}".format( + np.amax(diff))) @abc.abstractmethod def run_test(self, quant=False): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py new file mode 100755 index 00000000000..40fd9a418b9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py @@ -0,0 +1,198 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume, reproduce_failure +import hypothesis.strategies as st + + +class TestConvBiasMkldnnFusePass(PassAutoScanTest): + """ + x_var f_var(persistable) + \ / + conv2d + | + conv2d_var bias_var(persistable) + \ / + elementwise_add + | + elementwise_add_var + """ + + def sample_predictor_configs(self, program_config): + # MKLDNN + config = self.create_inference_config(use_gpu=False) + config.enable_mkldnn() + yield config, ["conv2d"], (1e-4, 1e-5) + + def is_program_valid(self, prog_config): + paddings = prog_config.ops[0].attrs["paddings"] + strides = prog_config.ops[0].attrs["strides"] + groups = prog_config.ops[0].attrs["groups"] + padding_algorithm = prog_config.ops[0].attrs["padding_algorithm"] + dilations = prog_config.ops[0].attrs["dilations"] + data_format = prog_config.ops[0].attrs["data_format"] + filter_shape = prog_config.weights["filter"].shape + input_shape = prog_config.inputs["input_x"].shape + if data_format != "NCHW": + return False + if padding_algorithm == "VALID": + if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \ + ((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1: + return False + if padding_algorithm == "EXPLICIT": + if ((input_shape[2] + paddings[0] + paddings[1] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \ + ((input_shape[3] + paddings[2] + paddings[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1: + return False + if data_format == "NCHW": + if input_shape[1] != filter_shape[1] * groups: + return False + if filter_shape[0] % groups != 0: + return False + else: + if input_shape[3] != filter_shape[1] * groups: + return False + if filter_shape[0] % groups != 0: + return False + return True + + def sample_program_config(self, draw): + # 1. Generate shape of input:X of conv2d + x_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=100), min_size=4, max_size=4)) + x_shape[1] = draw(st.integers(min_value=1, max_value=10)) + + # 2. Generate legal attr:data_format of conv2d + data_format = draw(st.sampled_from(["NCHW", "NHWC"])) + + # 3. Generate legal shape of input:Y of conv2d + f_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=7), min_size=4, max_size=4)) + if data_format == "NCHW": + f_shape[1] = x_shape[1] + else: + f_shape[1] = x_shape[3] + + # 4. Generate legal attr:strides of conv2d + strides = draw( + st.lists( + st.integers( + min_value=1, max_value=5), min_size=2, max_size=2)) + + # 5. Generate legal attr:padding_algorithm of conv2d + padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"])) + + # 6. Generate legal attr:padding of conv2d + padding = draw( + st.lists( + st.integers( + min_value=1, max_value=5), min_size=4, max_size=4)) + + # 7. Generate legal attr:groups of conv2d + groups = draw(st.integers(min_value=1, max_value=3)) + + # 8. Generate legal attr:dilations of conv2d + dilations = draw( + st.lists( + st.integers( + min_value=1, max_value=5), min_size=2, max_size=2)) + + # 9. Generate legal shape of input:bias of elementwise_add + bias_shape = [f_shape[0]] + + # 10. Generate legal shape of attr:axis of elementwise_add + axis = 1 + if data_format == "NCHW": + axis = 1 + else: + axis = 3 + + # 11. Generate legal shape of input:bias of conv2d + conv_bias_shape = [] + inputs = dict() + weights = dict() + use_mkldnn = None + if draw(st.booleans()): + conv_bias_shape = [f_shape[0]] + inputs = { + "Input": ["input_x"], + "Filter": ["filter"], + "Bias": ["conv_bias"], + } + weights = { + "filter": TensorConfig(shape=f_shape), + "bias": TensorConfig(shape=bias_shape), + "conv_bias": TensorConfig(shape=conv_bias_shape) + } + use_mkldnn = True + else: + inputs = { + "Input": ["input_x"], + "Filter": ["filter"], + } + weights = { + "filter": TensorConfig(shape=f_shape), + "bias": TensorConfig(shape=bias_shape) + } + use_mkldnn = False + + conv2d_op = OpConfig( + "conv2d", + inputs=inputs, + outputs={"Output": ["conv2d_out"]}, + strides=strides, + padding_algorithm=padding_algorithm, + paddings=padding, + groups=groups, + dilations=dilations, + data_format=data_format, + use_mkldnn=use_mkldnn) + + add_op = OpConfig( + "elementwise_add", + inputs={"X": ["conv2d_out"], + "Y": ["bias"]}, + outputs={"Out": ["add_out"]}, + axis=axis) + + ops = [conv2d_op, add_op] + + program_config = ProgramConfig( + ops=ops, + weights=weights, + inputs={"input_x": TensorConfig(shape=x_shape)}, + outputs=ops[-1].outputs["Out"]) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=350, + passes=["conv_bias_mkldnn_fuse_pass"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_elementwise_add2_act_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_elementwise_add2_act_fuse_pass.py old mode 100644 new mode 100755 index 6907b6a7eb5..9dd41bd1c39 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_elementwise_add2_act_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_elementwise_add2_act_fuse_pass.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,44 +12,223 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - -import unittest +from auto_scan_test import PassAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig, OpConfig import numpy as np -from inference_pass_test import InferencePassTest -import paddle.fluid as fluid -import paddle.fluid.core as core -from paddle.fluid.core import PassVersionChecker -from paddle.fluid.core import AnalysisConfig -"""Test for fusion of conv, elementwise_add and 2 act.""" - - -class ConvElementwiseAdd2ActFusePassTest(InferencePassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data( - name="data", shape=[-1, 3, 100, 100], dtype="float32") - add_y2 = fluid.data( - name="add_y2", shape=[1, 3, 98, 98], dtype="float32") - conv_out = fluid.layers.conv2d( - input=data, num_filters=3, filter_size=3, bias_attr=None) - add1_out = fluid.layers.elementwise_add( - add_y2, conv_out, act="relu") - - self.feeds = { - "data": np.random.random((1, 3, 100, 100)).astype("float32"), - "add_y2": np.random.random((1, 3, 98, 98)).astype("float32") - } - self.fetch_list = [add1_out] - self.enable_mkldnn = False - - def test_check_output(self): - if core.is_compiled_with_cuda(): - use_gpu = True - self.check_output_with_option(use_gpu) - self.assertTrue( - PassVersionChecker.IsCompatible( - 'conv_elementwise_add2_act_fuse_pass')) +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume, reproduce_failure +import hypothesis.strategies as st + + +class TestConvElementwiseAdd2ActPass(PassAutoScanTest): + """ + x_var f_var(persistable) + \ / + conv2d + | + conv2d_var y_var(persistable) + \ / + elementwise_add + | + x1_var elementwise_add_out_var + \ / + elementwise_add + | + act + | + act_var + """ + + def sample_predictor_configs(self, program_config): + # for gpu + config = self.create_inference_config(use_gpu=True) + yield config, ["conv2d_fusion"], (1e-4, 1e-5) + + def is_program_valid(self, prog_config): + paddings = prog_config.ops[0].attrs["paddings"] + strides = prog_config.ops[0].attrs["strides"] + groups = prog_config.ops[0].attrs["groups"] + padding_algorithm = prog_config.ops[0].attrs["padding_algorithm"] + dilations = prog_config.ops[0].attrs["dilations"] + data_format = prog_config.ops[0].attrs["data_format"] + filter_shape = prog_config.weights["filter"].shape + input_shape = prog_config.inputs["input_x"].shape + if data_format != "NCHW": + return False + if padding_algorithm == "VALID": + if int(((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1)) <= 0 or \ + int(((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1)) <= 0: + return False + if padding_algorithm == "EXPLICIT": + if int(((input_shape[2] + paddings[0] + paddings[1] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1)) <= 0 or \ + int(((input_shape[3] + paddings[2] + paddings[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1)) <= 0: + return False + if padding_algorithm == "SAME": + if int((input_shape[2] + strides[0] - 1) / strides[0]) <= 0 or int( + (input_shape[3] + strides[1] - 1) / strides[1]) <= 0: + return False + if data_format == "NCHW": + if input_shape[1] != filter_shape[1] * groups: + return False + if filter_shape[0] % groups != 0: + return False + else: + if input_shape[3] != filter_shape[1] * groups: + return False + if filter_shape[0] % groups != 0: + return False + + return True + + def sample_program_config(self, draw): + + is_not_valid = True + program_config = None + while is_not_valid: + # 1. Generate shape of input:X of conv2d + x_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=100), + min_size=4, + max_size=4)) + x_shape[1] = draw(st.integers(min_value=1, max_value=10)) + + # 2. Generate legal attr:data_format of conv2d + data_format = draw(st.sampled_from(["NCHW", "NHWC"])) + + # 3. Generate legal shape of input:Y of conv2d + f_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=7), + min_size=4, + max_size=4)) + if data_format == "NCHW": + f_shape[1] = x_shape[1] + else: + f_shape[1] = x_shape[3] + + # 4. Generate legal attr:strides of conv2d + strides = draw( + st.lists( + st.integers( + min_value=1, max_value=5), + min_size=2, + max_size=2)) + + # 5. Generate legal attr:padding_algorithm of conv2d + padding_algorithm = draw( + st.sampled_from(["EXPLICIT", "SAME", "VALID"])) + + # 6. Generate legal attr:padding of conv2d + padding = draw( + st.lists( + st.integers( + min_value=1, max_value=5), + min_size=4, + max_size=4)) + + # 7. Generate legal attr:groups of conv2d + groups = draw(st.integers(min_value=1, max_value=3)) + + # 8. Generate legal attr:dilations of conv2d + dilations = draw( + st.lists( + st.integers( + min_value=1, max_value=5), + min_size=2, + max_size=2)) + + # 9. Generate legal elemntwise_add: X of conv2d + bias_2_dict = dict() + bias_2_dict[1] = [x_shape[0], f_shape[0], \ + int(((x_shape[2] + padding[0] + padding[1] - (dilations[0] * (f_shape[2] - 1) + 1)) / strides[0] + 1)), \ + int(((x_shape[3] + padding[2] + padding[3] - (dilations[1] * (f_shape[3] - 1) + 1)) / strides[1] + 1))] + + bias_2_dict[2] = [x_shape[0], f_shape[0], \ + int((x_shape[2] + strides[0] - 1) / strides[0]), \ + int((x_shape[3] + strides[1] - 1) / strides[1])] + + bias_2_dict[3] = [x_shape[0], f_shape[0], \ + int(((x_shape[2] - (dilations[0] * (f_shape[2] - 1) + 1)) / strides[0] + 1)), \ + int(((x_shape[3] - (dilations[1] * (f_shape[3] - 1) + 1)) / strides[1] + 1))] + bias_index = 1 + if padding_algorithm == "SAME": + bias_index = 2 + if padding_algorithm == "VALID": + bias_index = 3 + bias_2_shape = bias_2_dict[bias_index] + + if np.sum(np.array(bias_2_shape) <= 0) == 0: + is_not_valid = False + else: + continue + + # 10. Generate legal shape of input:bias of elementwise_add + bias_shape = [f_shape[0]] + + # 11. Generate legal attr:axis of elementwise_add_1 + axis_1 = 1 + + # 12. Generate legal attr:axis of elementwise_add_2 + axis_2 = -1 + + conv2d_op = OpConfig( + "conv2d", + inputs={"Input": ["input_x"], + "Filter": ["filter"]}, + outputs={"Output": ["conv2d_out"]}, + strides=strides, + padding_algorithm=padding_algorithm, + paddings=padding, + groups=groups, + dilations=dilations, + data_format=data_format) + add_1_op = OpConfig( + "elementwise_add", + inputs={"X": ["conv2d_out"], + "Y": ["bias_1"]}, + outputs={"Out": ["add_1_out"]}, + axis=axis_1) + + add_2_op = OpConfig( + "elementwise_add", + inputs={"X": ["bias_2"], + "Y": ["add_1_out"]}, + outputs={"Out": ["add_out"]}, + axis=axis_2) + + relu_op = OpConfig( + "relu", + inputs={"X": ["add_out"]}, + outputs={"Out": ["relu_out"]}) + + ops = [conv2d_op, add_1_op, add_2_op, relu_op] + + program_config = ProgramConfig( + ops=ops, + weights={ + "filter": TensorConfig(shape=f_shape), + "bias_1": TensorConfig(shape=bias_shape), + }, + inputs={ + "input_x": TensorConfig(shape=x_shape), + "bias_2": TensorConfig(shape=bias_2_shape) + }, + outputs=ops[-1].outputs["Out"], ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=300, + passes=["conv_elementwise_add2_act_fuse_pass"]) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_elementwise_add_act_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_elementwise_add_act_fuse_pass.py old mode 100644 new mode 100755 index 6ff60aa6deb..0d93ae9a7d2 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_elementwise_add_act_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_elementwise_add_act_fuse_pass.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,46 +12,177 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - -import unittest +from auto_scan_test import PassAutoScanTest, IgnoreReasons +from program_config import TensorConfig, ProgramConfig, OpConfig import numpy as np -from inference_pass_test import InferencePassTest -import paddle.fluid as fluid -import paddle.fluid.core as core -from paddle.fluid.core import PassVersionChecker -from paddle.fluid.core import AnalysisConfig -"""Test for fusion of conv, elementwise_add and act.""" - - -class ConvElementwiseAddActFusePassTest(InferencePassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data( - name="data", shape=[-1, 3, 100, 100], dtype="float32") - param_attr = fluid.ParamAttr( - initializer=fluid.initializer.Xavier(uniform=False), - learning_rate=0.001) - conv_out = fluid.layers.conv2d( - input=data, - num_filters=3, - filter_size=3, - bias_attr=param_attr, - act="relu") - - self.feeds = { - "data": np.random.random((1, 3, 100, 100)).astype("float32") - } - self.fetch_list = [conv_out] - self.enable_mkldnn = False - - def test_check_output(self): - if core.is_compiled_with_cuda(): - use_gpu = True - self.check_output_with_option(use_gpu) - self.assertTrue( - PassVersionChecker.IsCompatible( - 'conv_elementwise_add_act_fuse_pass')) +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume, reproduce_failure +import hypothesis.strategies as st + + +class TestConvElementwiseAddActPass(PassAutoScanTest): + """ + x_var f_var(persistable) + \ / + conv2d + | + conv2d_var y_var(persistable) + \ / + elementwise_add + | + elementwise_add_var + | + act + | + act_var + """ + + def sample_predictor_configs(self, program_config): + # for gpu + config = self.create_inference_config(use_gpu=True) + yield config, ["conv2d_fusion"], (1e-4, 1e-5) + + def is_program_valid(self, prog_config): + paddings = prog_config.ops[0].attrs["paddings"] + strides = prog_config.ops[0].attrs["strides"] + groups = prog_config.ops[0].attrs["groups"] + padding_algorithm = prog_config.ops[0].attrs["padding_algorithm"] + dilations = prog_config.ops[0].attrs["dilations"] + data_format = prog_config.ops[0].attrs["data_format"] + filter_shape = prog_config.weights["filter"].shape + input_shape = prog_config.inputs["input_x"].shape + if data_format != "NCHW": + return False + if padding_algorithm == "VALID": + if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \ + ((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1: + return False + if padding_algorithm == "EXPLICIT": + if ((input_shape[2] + paddings[0] + paddings[1] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \ + ((input_shape[3] + paddings[2] + paddings[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1: + return False + if data_format == "NCHW": + if input_shape[1] != filter_shape[1] * groups: + return False + if filter_shape[0] % groups != 0: + return False + else: + if input_shape[3] != filter_shape[1] * groups: + return False + if filter_shape[0] % groups != 0: + return False + return True + + def sample_program_config(self, draw): + # 1. Generate shape of input:X of conv2d + x_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=100), min_size=4, max_size=4)) + x_shape[1] = draw(st.integers(min_value=1, max_value=10)) + + # 2. Generate legal attr:data_format of conv2d + data_format = draw(st.sampled_from(["NCHW", "NHWC"])) + + # 3. Generate legal shape of input:Y of conv2d + f_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=7), min_size=4, max_size=4)) + if data_format == "NCHW": + f_shape[1] = x_shape[1] + else: + f_shape[1] = x_shape[3] + + # 4. Generate legal attr:strides of conv2d + strides = draw( + st.lists( + st.integers( + min_value=1, max_value=5), min_size=2, max_size=2)) + + # 5. Generate legal attr:padding_algorithm of conv2d + padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"])) + + # 6. Generate legal attr:padding of conv2d + padding = draw( + st.lists( + st.integers( + min_value=1, max_value=5), min_size=4, max_size=4)) + + # 7. Generate legal attr:groups of conv2d + groups = draw(st.integers(min_value=1, max_value=3)) + + # 8. Generate legal attr:dilations of conv2d + dilations = draw( + st.lists( + st.integers( + min_value=1, max_value=5), min_size=2, max_size=2)) + + # 9. Generate legal input:ResidualData of conv2d + res_shape = [] + if draw(st.booleans()): + res_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=100), + min_size=4, + max_size=4)) + + # 10. Generate legal shape of input:bias of elementwise_add + bias_shape = [f_shape[0]] + + # 11. Generate legal attr:axis of elementwise_add + axis = 1 + + conv2d_op = OpConfig( + "conv2d", + inputs={ + "Input": ["input_x"], + "Filter": ["filter"], + "ResidualData": ["residualdata"] + }, + outputs={"Output": ["conv2d_out"]}, + strides=strides, + padding_algorithm=padding_algorithm, + paddings=padding, + groups=groups, + dilations=dilations, + data_format=data_format) + add_op = OpConfig( + "elementwise_add", + inputs={"X": ["conv2d_out"], + "Y": ["bias"]}, + outputs={"Out": ["add_out"]}, + axis=axis) + + relu_op = OpConfig( + "relu", inputs={"X": ["add_out"]}, outputs={"Out": ["relu_out"]}) + + ops = [conv2d_op, add_op, relu_op] + + program_config = ProgramConfig( + ops=ops, + weights={ + "filter": TensorConfig(shape=f_shape), + "bias": TensorConfig(shape=bias_shape), + }, + inputs={ + "input_x": TensorConfig(shape=x_shape), + "residualdata": TensorConfig(shape=res_shape) + }, + outputs=ops[-1].outputs["Out"], ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=400, + passes=["conv_elementwise_add_act_fuse_pass"]) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_eltwiseadd_bn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_eltwiseadd_bn_fuse_pass.py new file mode 100755 index 00000000000..c8319a5f3d7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_eltwiseadd_bn_fuse_pass.py @@ -0,0 +1,265 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume, reproduce_failure +import hypothesis.strategies as st + + +class TestConvEltwiseaddBnFusePass(PassAutoScanTest): + """ + x_var f_var(persistable) + \ / + conv2d + | + conv2d_var bias_var(persistable) + \ / + elementwise_add + | + elementwise_add_var Scale(persistable) Bias(persistable) Mean(persistable) Variance(persistable) + | + batch_norm + | + Y MeanOut VarianceOut SavedMeanSavedVariance + """ + + def sample_predictor_configs(self, program_config): + # cpu + config = self.create_inference_config(use_gpu=False) + yield config, ["conv2d", "elementwise_add"], (1e-4, 1e-5) + + # MKLDNN + config = self.create_inference_config(use_gpu=False) + config.enable_mkldnn() + yield config, ["conv2d", "elementwise_add"], (1e-4, 1e-5) + + # for gpu + config = self.create_inference_config(use_gpu=True) + yield config, ["conv2d", "elementwise_add"], (1e-4, 1e-5) + + def is_program_valid(self, prog_config): + paddings = prog_config.ops[0].attrs["paddings"] + strides = prog_config.ops[0].attrs["strides"] + groups = prog_config.ops[0].attrs["groups"] + padding_algorithm = prog_config.ops[0].attrs["padding_algorithm"] + dilations = prog_config.ops[0].attrs["dilations"] + data_format = prog_config.ops[0].attrs["data_format"] + filter_shape = prog_config.weights["filter"].shape + input_shape = prog_config.inputs["input_x"].shape + if data_format != "NCHW": + return False + if padding_algorithm == "VALID": + if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \ + ((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1: + return False + if padding_algorithm == "EXPLICIT": + if ((input_shape[2] + paddings[0] + paddings[1] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \ + ((input_shape[3] + paddings[2] + paddings[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1: + return False + + if data_format == "NCHW": + if input_shape[1] != filter_shape[1] * groups: + return False + if filter_shape[0] % groups != 0: + return False + else: + if input_shape[3] != filter_shape[1] * groups: + return False + if filter_shape[0] % groups != 0: + return False + + bn_scale = np.array(prog_config.weights["scale_in"].data) + bn_bias = np.array(prog_config.weights["bias_in"].data) + bn_mean = np.array(prog_config.weights["mean_in"].data) + bn_variance = np.array(prog_config.weights["variance_in"].data) + epsilon = np.array(prog_config.ops[-1].attrs["epsilon"]) + bn_variance = bn_variance + epsilon + + if np.isnan(bn_variance).any(): + return False + bn_variance = np.sqrt(bn_variance) + if np.sum(bn_variance == 0.0) > 0: + return False + bn_variance = bn_scale / bn_variance + if np.isnan(bn_variance).any(): + return False + return True + + def sample_program_config(self, draw): + # 1. Generate shape of input:X of conv2d + x_shape = draw( + st.lists( + st.integers( + min_value=10, max_value=100), + min_size=4, + max_size=4)) + x_shape[1] = draw(st.integers(min_value=1, max_value=10)) + + # 2. Generate legal attr:data_format of conv2d + data_format = draw(st.sampled_from(["NCHW", "NHWC"])) + + # 2. Generate legal shape of input:Y of conv2d + f_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=7), min_size=4, max_size=4)) + if data_format == "NCHW": + f_shape[1] = x_shape[1] + else: + f_shape[1] = x_shape[3] + + # 3. Generate legal attr:strides of conv2d + strides = draw( + st.lists( + st.integers( + min_value=1, max_value=5), min_size=2, max_size=2)) + + # 4. Generate legal attr:padding_algorithm of conv2d + padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"])) + + # 5. Generate legal attr:padding of conv2d + padding = draw( + st.lists( + st.integers( + min_value=1, max_value=5), min_size=4, max_size=4)) + + # 6. Generate legal attr:groups of conv2d + groups = draw(st.integers(min_value=1, max_value=3)) + + # 7. Generate legal attr:dilations of conv2d + dilations = draw( + st.lists( + st.integers( + min_value=1, max_value=5), min_size=2, max_size=2)) + + # 9. Generate legal input:ResidualData of conv2d + res_shape = [] + if draw(st.booleans()): + res_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=100), + min_size=4, + max_size=4)) + + # 10. Generate legal shape of input:bias of elementwise_add + bias_shape = [f_shape[0]] + + # 11. Generate legal attr:axis of elementwise_add + axis = 1 + + # 12. Generate legal input:Scale of batch_norm + bn_scale_shape = [f_shape[0]] + + # 13. Generate legal input:Bias of batch_norm + bn_bias_shape = [f_shape[0]] + + # 14. Generate legal input:Mean of batch_norm + bn_mean_shape = [f_shape[0]] + + # 15. Generate legal input:Variance of batch_norm + bn_variance_shape = [f_shape[0]] + + # 16. Generate legal attr:epsilon of batch_norm + epsilon = draw(st.floats(min_value=0.00001, max_value=0.001)) + + def generate_batch_variance(): + return (0.1 + (1.0 - 0.1) * np.random.random(bn_variance_shape) + ).astype(np.float32) + + conv2d_op = OpConfig( + "conv2d", + inputs={ + "Input": ["input_x"], + "Filter": ["filter"], + "ResidualData": ["residualdata"] + }, + outputs={"Output": ["conv2d_out"]}, + strides=strides, + padding_algorithm=padding_algorithm, + paddings=padding, + groups=groups, + dilations=dilations, + data_format=data_format) + add_op = OpConfig( + "elementwise_add", + inputs={"X": ["conv2d_out"], + "Y": ["bias"]}, + outputs={"Out": ["add_out"]}, + axis=axis) + + bn_op = OpConfig( + "batch_norm", + inputs={ + "X": ["add_out"], + "Scale": ["scale_in"], + "Bias": ["bias_in"], + "Mean": ["mean_in"], + "Variance": ["variance_in"] + }, + outputs={ + "Y": ["y_out"], + "MeanOut": ["mean_in"], + "VarianceOut": ["variance_in"], + "SavedMean": ["SavedMean_out"], + "SavedVariance": ["SavedVariance_out"], + "ReserveSpace": ["ReserveSpace_out"] + }, + epsilon=epsilon, + is_test=True, + trainable_statistics=False, + data_layout=data_format) + + ops = [conv2d_op, add_op, bn_op] + + # 17. if the output of bias is more than one + if draw(st.booleans()): + outputs = ops[-1].outputs["Y"] + else: + outputs = ops[-1].outputs["Y"] + ["bias"] + + program_config = ProgramConfig( + ops=ops, + weights={ + "filter": TensorConfig(shape=f_shape), + "bias": TensorConfig(shape=bias_shape), + "scale_in": TensorConfig(shape=bn_scale_shape), + "bias_in": TensorConfig(shape=bn_bias_shape), + "mean_in": TensorConfig(shape=bn_mean_shape), + "variance_in": TensorConfig(data_gen=generate_batch_variance), + }, + inputs={ + "input_x": TensorConfig(shape=x_shape), + "residualdata": TensorConfig(shape=res_shape) + }, + outputs=outputs) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=300, + passes=["conv_eltwiseadd_bn_fuse_pass"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py old mode 100644 new mode 100755 index dba411daade..e302d967983 --- a/tools/parallel_UT_rule.py +++ b/tools/parallel_UT_rule.py @@ -431,7 +431,7 @@ HIGH_PARALLEL_JOB_NEW = [ 'test_memory_usage', 'test_sysconfig', 'reader_test', - 'test_conv_bias_mkldnn_fuse_pass', + 'test_conv_bias_mkldnn_fuse_pass_cc', 'math_function_test', 'beam_search_decode_op_test', 'save_quant2_model_resnet50', @@ -1469,7 +1469,7 @@ CPU_PARALLEL_JOB = [ 'test_cpu_bfloat16_placement_pass', 'test_cpu_bfloat16_pass', 'test_conv_concat_relu_mkldnn_fuse_pass', - 'test_conv_bias_mkldnn_fuse_pass', + 'test_conv_bias_mkldnn_fuse_pass_cc', 'test_conv_batch_norm_mkldnn_fuse_pass', 'test_conv3d_transpose_layer', 'test_conv3d_mkldnn_op', diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py old mode 100644 new mode 100755 -- GitLab