diff --git a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc index 3875d856d20bd6b0d8047e5f7d876c17d6f8b040..af6773042b67870d715ed37894e0321015485d0d 100644 --- a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -130,7 +130,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() { .IsType>() .End() .AddAttr("data_format") - .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) .End(); AddOpCompat(OpCompat("affine_channel")) @@ -148,7 +148,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() { .IsTensor() .End() .AddAttr("data_layout") - .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) .End(); AddOpCompat(OpCompat("elementwise_add")) @@ -197,19 +197,23 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { GET_CONV_BN_NODES(conv_ac_pattern); + // Get affine_channel bias for resizing eltwise_y! + auto* ac_bias_tensor = + scope->FindVar(ac_bias->Name())->GetMutable(); + // Create eltwise_y (conv bias) variable VarDesc eltwise_y_in_desc( patterns::PDNodeName(name_scope_, "eltwise_y_in")); + // Set shape && datatype manually + eltwise_y_in_desc.SetShape(framework::vectorize(ac_bias_tensor->dims())); + eltwise_y_in_desc.SetDataType(ac_bias_tensor->type()); + eltwise_y_in_desc.SetLoDLevel(ac_bias->Var()->GetLoDLevel()); eltwise_y_in_desc.SetPersistable(true); + + // Initialize eltwise_y auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc); auto* eltwise_y_in_tensor = scope->Var(eltwise_y_in_node->Name())->GetMutable(); - - // Get affine_channel bias - auto* ac_bias_tensor = - scope->FindVar(ac_bias->Name())->GetMutable(); - - // Initialize eltwise_y eltwise_y_in_tensor->Resize(ac_bias_tensor->dims()); std::fill_n(eltwise_y_in_tensor->mutable_data(platform::CPUPlace()), eltwise_y_in_tensor->numel(), 0.0f); @@ -278,7 +282,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() { .IsType>() .End() .AddAttr("data_format") - .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) .End(); AddOpCompat(OpCompat("affine_channel")) .AddInput("X") @@ -295,7 +299,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() { .IsTensor() .End() .AddAttr("data_layout") - .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) .End(); AddOpCompat(OpCompat("elementwise_add")) .AddInput("X") diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc index 248a71ede14beb35db0580b879891d5b3b614157..439b85ffb9f10dd9e50aab6353cf0df33cc6f166 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -57,7 +57,7 @@ ConvElementwiseAddFusePass::ConvElementwiseAddFusePass() { .AddAttr("dilations") .End() .AddAttr("data_format") - .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) .End(); AddOpCompat(OpCompat("elementwise_add")) @@ -87,7 +87,7 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const { patterns::ConvElementwiseadd pattern(gpd.mutable_pattern(), pattern_name); pattern(x); - + int found_conv_eltwise_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { if (!IsCompat(subgraph, g)) { @@ -135,9 +135,12 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const { // Delete the unneeded nodes. GraphSafeRemoveNodes(graph, {conv_op, conv_out, elementwise_add_op}); + found_conv_eltwise_count++; }; gpd(graph, handler); + // check if detect conv2d_fusion subgraph! + AddStatis(found_conv_eltwise_count); } } // namespace ir diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 67d300fe186a8b82c013e197a72160e16fff912e..0d7299fa989ed3b866a5f3a66650a4a58b38f6c6 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -80,6 +80,7 @@ if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU) set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 240) set_tests_properties(test_simplify_with_basic_ops_pass_autoscan PROPERTIES TIMEOUT 60) set_tests_properties(test_adaptive_pool2d_convert_global_pass_autoscan PROPERTIES TIMEOUT 60) + set_tests_properties(test_conv_eltwiseadd_affine_channel_fuse_pass PROPERTIES TIMEOUT 100) endif() if (WITH_MKLDNN) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_affine_channel_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_affine_channel_fuse_pass.py index ec0bd52e9261017335f0bf424d32f26d4c465029..5afaf08eec3b1324df312920bd9e8c8970fd7dbc 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_affine_channel_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_affine_channel_fuse_pass.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,216 +12,148 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +from auto_scan_test import PassAutoScanTest, IgnoreReasons +from program_config import TensorConfig, ProgramConfig, OpConfig import numpy as np -from inference_pass_test import InferencePassTest -import paddle.fluid as fluid -import paddle.fluid.core as core -from paddle.fluid.core import PassVersionChecker - - -class ConvAffineChannelFusePassExplicitPaddingTest(InferencePassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data( - name="data", shape=[-1, 3, 64, 64], dtype="float32") - conv_out = fluid.layers.conv2d( - input=data, - num_filters=3, - filter_size=3, - groups=3, - padding=[1, 1, 1, 1], - bias_attr=False, - act=None) - input_scale = fluid.layers.create_parameter( - shape=[3], dtype="float32") - input_bias = fluid.layers.create_parameter( - shape=[3], dtype="float32") - ac_out = fluid.layers.affine_channel( - x=conv_out, scale=input_scale, bias=input_bias) - - self.feeds = { - "data": np.random.random([1, 3, 64, 64]).astype("float32"), - } - self.fetch_list = [ac_out] - - def test_check_output(self): - self.check_output() - - self.assertTrue( - PassVersionChecker.IsCompatible('conv_affine_channel_fuse_pass')) - - -class ConvAffineChannelFusePassValidPaddingTest(InferencePassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data( - name="data", shape=[-1, 3, 64, 64], dtype="float32") - conv_out = fluid.layers.conv2d( - input=data, - num_filters=3, - filter_size=3, - groups=3, - padding='VALID', - bias_attr=False, - act=None) - input_scale = fluid.layers.create_parameter( - shape=[3], dtype="float32") - input_bias = fluid.layers.create_parameter( - shape=[3], dtype="float32") - ac_out = fluid.layers.affine_channel( - x=conv_out, scale=input_scale, bias=input_bias) - - self.feeds = { - "data": np.random.random([1, 3, 64, 64]).astype("float32"), - } - self.fetch_list = [ac_out] - - def test_check_output(self): - self.check_output() - - self.assertTrue( - PassVersionChecker.IsCompatible('conv_affine_channel_fuse_pass')) - - -class ConvAffineChannelFusePassSamePaddingTest(InferencePassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data( - name="data", shape=[-1, 3, 64, 64], dtype="float32") - conv_out = fluid.layers.conv2d( - input=data, - num_filters=3, - filter_size=3, - groups=3, - padding='SAME', - bias_attr=False, - act=None) - input_scale = fluid.layers.create_parameter( - shape=[3], dtype="float32") - input_bias = fluid.layers.create_parameter( - shape=[3], dtype="float32") - ac_out = fluid.layers.affine_channel( - x=conv_out, scale=input_scale, bias=input_bias) - - self.feeds = { - "data": np.random.random([1, 3, 64, 64]).astype("float32"), - } - self.fetch_list = [ac_out] - - def test_check_output(self): - self.check_output() - - self.assertTrue( - PassVersionChecker.IsCompatible('conv_affine_channel_fuse_pass')) - - -class ConvEltwiseAddAffineChannelFusePassExplicitPaddingTest(InferencePassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data( - name="data", shape=[-1, 3, 64, 64], dtype="float32") - param_attr = fluid.ParamAttr( - initializer=fluid.initializer.Xavier(uniform=False), - learning_rate=0.001) - conv_out = fluid.layers.conv2d( - input=data, - num_filters=3, - filter_size=3, - groups=3, - padding=[1, 1, 1, 1], - bias_attr=param_attr, - act=None) - input_scale = fluid.layers.create_parameter( - shape=[3], dtype="float32") - input_bias = fluid.layers.create_parameter( - shape=[3], dtype="float32") - ac_out = fluid.layers.affine_channel( - x=conv_out, scale=input_scale, bias=input_bias) - - self.feeds = { - "data": np.random.random([1, 3, 64, 64]).astype("float32"), - } - self.fetch_list = [ac_out] - - def test_check_output(self): - self.check_output() - - self.assertTrue( - PassVersionChecker.IsCompatible( - 'conv_eltwiseadd_affine_channel_fuse_pass')) - - -class ConvEltwiseAddAffineChannelFusePassValidPaddingTest(InferencePassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data( - name="data", shape=[-1, 3, 64, 64], dtype="float32") - param_attr = fluid.ParamAttr( - initializer=fluid.initializer.Xavier(uniform=False), - learning_rate=0.001) - conv_out = fluid.layers.conv2d( - input=data, - num_filters=3, - filter_size=3, - groups=3, - padding='VALID', - bias_attr=param_attr, - act=None) - input_scale = fluid.layers.create_parameter( - shape=[3], dtype="float32") - input_bias = fluid.layers.create_parameter( - shape=[3], dtype="float32") - ac_out = fluid.layers.affine_channel( - x=conv_out, scale=input_scale, bias=input_bias) - - self.feeds = { - "data": np.random.random([1, 3, 64, 64]).astype("float32"), - } - self.fetch_list = [ac_out] - - def test_check_output(self): - self.check_output() - - self.assertTrue( - PassVersionChecker.IsCompatible( - 'conv_eltwiseadd_affine_channel_fuse_pass')) - - -class ConvEltwiseAddAffineChannelFusePassSamePaddingTest(InferencePassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data( - name="data", shape=[-1, 3, 64, 64], dtype="float32") - param_attr = fluid.ParamAttr( - initializer=fluid.initializer.Xavier(uniform=False), - learning_rate=0.001) - conv_out = fluid.layers.conv2d( - input=data, - num_filters=3, - filter_size=3, - groups=3, - padding='Same', - bias_attr=param_attr, - act=None) - input_scale = fluid.layers.create_parameter( - shape=[3], dtype="float32") - input_bias = fluid.layers.create_parameter( - shape=[3], dtype="float32") - ac_out = fluid.layers.affine_channel( - x=conv_out, scale=input_scale, bias=input_bias) - - self.feeds = { - "data": np.random.random([1, 3, 64, 64]).astype("float32"), - } - self.fetch_list = [ac_out] - - def test_check_output(self): - self.check_output() - - self.assertTrue( - PassVersionChecker.IsCompatible( - 'conv_eltwiseadd_affine_channel_fuse_pass')) +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume, reproduce_failure +import hypothesis.strategies as st + + +class TestConvAffineChannelFusePass(PassAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_config(self, draw): + padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"])) + groups = draw(st.integers(min_value=1, max_value=3)) + data_format = draw(st.sampled_from(["NCHW", "NHWC"])) + axis = draw(st.sampled_from([1])) + filter_channel = draw(st.integers(min_value=1, max_value=16)) * 4 + filter_size = draw(st.integers(min_value=1, max_value=4)) + in_channel = groups * filter_channel + out_channel_factor = draw(st.integers(min_value=1, max_value=16)) * 4 + out_channel = groups * out_channel_factor + batch_size = draw(st.integers(min_value=1, max_value=4)) + dilations = draw( + st.lists( + st.integers( + min_value=1, max_value=2), min_size=2, max_size=2)) + paddings = draw( + st.lists( + st.integers( + min_value=0, max_value=2), min_size=2, max_size=2)) + strides = draw( + st.lists( + st.integers( + min_value=1, max_value=2), min_size=2, max_size=2)) + has_bias = draw(st.booleans()) + + x_shape = [ + batch_size, in_channel, 64, 64 + ] if data_format == "NCHW" else [batch_size, 64, 64, in_channel] + w_shape = [out_channel, filter_channel, filter_size, filter_size] + scale_shape = [out_channel] + bias_shape = [out_channel] + + def generate_input(): + return np.random.random(x_shape).astype(np.float32) + + def generate_weight(): + return np.random.random(w_shape).astype(np.float32) + + def generate_bias(): + return np.random.random(bias_shape).astype(np.float32) + + def generate_scale_bias(): + return np.random.random(bias_shape).astype(np.float32) + + conv2d_op = OpConfig( + "conv2d", + inputs={ + "Input": ["input_data"], + "Filter": ["conv2d_weight"], + }, + outputs={"Output": ["conv_output"]}, + data_format=data_format, + dilations=dilations, + padding_algorithm=padding_algorithm, + groups=groups, + paddings=paddings, + strides=strides, + has_bias=has_bias, + is_test=True) + ac_op = OpConfig( + "affine_channel", + inputs={ + "X": ["conv_output"], + "Scale": ["affine_channel_scale"], + "Bias": ["affine_channel_bias"] + }, + outputs={"Out": ["affine_channel_ouput"]}, + data_layout=data_format) + if has_bias == True: + conv2d_op.inputs["Bias"] = ["conv2d_bias"] + ops = [conv2d_op, ac_op] + + program_config = ProgramConfig( + ops=ops, + inputs={ + "input_data": TensorConfig(data_gen=partial(generate_input)), + }, + weights={ + "conv2d_weight": + TensorConfig(data_gen=partial(generate_weight)), + "affine_channel_scale": + TensorConfig(data_gen=partial(generate_scale_bias)), + "affine_channel_bias": + TensorConfig(data_gen=partial(generate_scale_bias)), + }, + outputs=["affine_channel_ouput"]) + if has_bias == True: + program_config.weights["conv2d_bias"] = TensorConfig( + data_gen=partial(generate_bias)) + return program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_gpu=True) + yield config, ['conv2d', 'elementwise_add'], (1e-4, 1e-4) + + config = self.create_inference_config(use_mkldnn=True) + yield config, ['conv2d', 'elementwise_add'], (1e-4, 1e-4) + + def add_ignore_pass_case(self): + # If the problem has been fixed, the judgment + # in is_program_valid needs to be deleted!!! + def teller1(program_config, predictor_config): + if program_config.ops[0].attrs['data_format'] == "NHWC": + return True + return False + + # mkldnn Output has diff with bias! + def teller2(program_config, predictor_config): + return predictor_config.mkldnn_enabled() and program_config.ops[ + 0].attrs['has_bias'] == True + + self.add_ignore_check_case( + teller1, IgnoreReasons.PASS_ACCURACY_ERROR, + "The output format of conv2d is wrong when data_format attribute is NHWC, \ + because currently its fused op (Conv2DFusion) only supports data format of channel first (NCHW)." + ) + + self.add_ignore_check_case( + teller2, IgnoreReasons.PASS_ACCURACY_ERROR, + "Currently mkldnn Output has diff with bias!") + + def test(self): + self.run_and_statis( + quant=False, + passes=["conv_affine_channel_fuse_pass"], ) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_elementwise_add_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_elementwise_add_fuse_pass.py index 96b046edaec49038d6c8e137494c13bd7484cb7f..0bcee474d1394aa6633af6b3e9473a1c005d098b 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_elementwise_add_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_elementwise_add_fuse_pass.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,41 +12,144 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - -import unittest +from auto_scan_test import PassAutoScanTest, IgnoreReasons +from program_config import TensorConfig, ProgramConfig, OpConfig import numpy as np -from inference_pass_test import InferencePassTest -import paddle.fluid as fluid -import paddle.fluid.core as core -from paddle.fluid.core import PassVersionChecker -from paddle.fluid.core import AnalysisConfig -"""Test for fusion of conv and elementwise_add.""" - - -class ConvElementwiseAddFusePassTest(InferencePassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data( - name="data", shape=[-1, 3, 100, 100], dtype="float32") - param_attr = fluid.ParamAttr( - initializer=fluid.initializer.Xavier(uniform=False), - learning_rate=0.001) - conv_out = fluid.layers.conv2d( - input=data, num_filters=3, filter_size=3, bias_attr=param_attr) - - self.feeds = { - "data": np.random.random((1, 3, 100, 100)).astype("float32") - } - self.fetch_list = [conv_out] - self.enable_mkldnn = False - - def test_check_output(self): - if core.is_compiled_with_cuda(): - use_gpu = True - self.check_output_with_option(use_gpu) - self.assertTrue( - PassVersionChecker.IsCompatible('conv_elementwise_add_fuse_pass')) +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume +import hypothesis.strategies as st + + +class TestConvEltwiseAddFusePass(PassAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + if attrs[0]['data_format'] == "NHWC" and attrs[1]['axis'] != 3: + return False + + return True + + def sample_program_config(self, draw): + padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"])) + groups = draw(st.integers(min_value=1, max_value=3)) + data_format = draw(st.sampled_from(["NCHW", "NHWC"])) + axis = draw(st.sampled_from([1])) + filter_channel = draw(st.integers(min_value=1, max_value=16)) * 4 + filter_size = draw(st.integers(min_value=1, max_value=4)) + in_channel = groups * filter_channel + out_channel_factor = draw(st.integers(min_value=1, max_value=16)) * 4 + out_channel = groups * out_channel_factor + batch_size = draw(st.integers(min_value=1, max_value=4)) + dilations = draw( + st.lists( + st.integers( + min_value=1, max_value=2), min_size=2, max_size=2)) + paddings = draw( + st.lists( + st.integers( + min_value=0, max_value=2), min_size=2, max_size=2)) + strides = draw( + st.lists( + st.integers( + min_value=1, max_value=2), min_size=2, max_size=2)) + + x_shape = [ + batch_size, in_channel, 64, 64 + ] if data_format == "NCHW" else [batch_size, 64, 64, in_channel] + w_shape = [out_channel, filter_channel, filter_size, filter_size] + scale_shape = [out_channel] + bias_shape = [out_channel] + + def generate_input(): + return np.random.random(x_shape).astype(np.float32) + + def generate_weight(): + return np.random.random(w_shape).astype(np.float32) + + def generate_bias(): + return np.random.random(bias_shape).astype(np.float32) + + def generate_scale_bias(): + return np.random.random(bias_shape).astype(np.float32) + + conv2d_op = OpConfig( + "conv2d", + inputs={ + "Input": ["input_data"], + "Filter": ["conv2d_weight"], + }, + outputs={"Output": ["conv_output"]}, + data_format=data_format, + dilations=dilations, + padding_algorithm=padding_algorithm, + groups=groups, + paddings=paddings, + strides=strides, + is_test=True) + eltwise_op = OpConfig( + "elementwise_add", + inputs={"X": ["conv_output"], + "Y": ["conv2d_bias"]}, + outputs={"Out": ["elementwise_output"]}, + axis=axis) + ops = [conv2d_op, eltwise_op] + + program_config = ProgramConfig( + ops=ops, + inputs={ + "input_data": TensorConfig(data_gen=partial(generate_input)), + }, + weights={ + "conv2d_weight": + TensorConfig(data_gen=partial(generate_weight)), + "conv2d_bias": + TensorConfig(data_gen=partial(generate_scale_bias)), + }, + outputs=["elementwise_output"]) + return program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_gpu=True) + yield config, ['conv2d_fusion'], (1e-4, 1e-4) + + # # TRT + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + workspace_size=1 << 20, + max_batch_size=4, + min_subgraph_size=1, + precision_mode=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False) + yield config, ['conv2d_fusion'], (1e-4, 1e-4) + + def add_ignore_pass_case(self): + # If the problem has been fixed, the judgment + # in is_program_valid needs to be deleted!!! + def teller1(program_config, predictor_config): + if program_config.ops[0].attrs['data_format'] == "NHWC": + return True + return False + + self.add_ignore_check_case( + teller1, IgnoreReasons.PASS_ACCURACY_ERROR, + "The output format of conv2d is wrong when data_format attribute is NHWC, \ + it will trigger Broadcast dimension mismatch bug \ + when data_format attribute is NHWC and axis of eltwise op is 1 for this pass." + ) + + def test(self): + self.run_and_statis( + quant=False, + passes=["conv_elementwise_add_fuse_pass"], ) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_eltwiseadd_affine_channel_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_eltwiseadd_affine_channel_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..a8bfdb79ca1daa5caa0cffb945fee76fdef36c36 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_eltwiseadd_affine_channel_fuse_pass.py @@ -0,0 +1,183 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest, IgnoreReasons +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume +import hypothesis.strategies as st + + +class TestConvEltwiseAddAffineChannelFusePass(PassAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + if attrs[0]['data_format'] == "NHWC" and attrs[1]['axis'] != 3: + return False + + return True + + def sample_program_config(self, draw): + padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"])) + groups = draw(st.integers(min_value=1, max_value=3)) + data_format = draw(st.sampled_from(["NCHW", "NHWC"])) + axis = draw(st.sampled_from([1])) + filter_channel = draw(st.integers(min_value=1, max_value=16)) * 4 + filter_size = draw(st.integers(min_value=1, max_value=4)) + in_channel = groups * filter_channel + out_channel_factor = draw(st.integers(min_value=1, max_value=16)) * 4 + out_channel = groups * out_channel_factor + batch_size = draw(st.integers(min_value=1, max_value=4)) + dilations = draw( + st.lists( + st.integers( + min_value=1, max_value=2), min_size=2, max_size=2)) + paddings = draw( + st.lists( + st.integers( + min_value=0, max_value=2), min_size=2, max_size=2)) + strides = draw( + st.lists( + st.integers( + min_value=1, max_value=2), min_size=2, max_size=2)) + has_bias = draw(st.booleans()) + + x_shape = [ + batch_size, in_channel, 64, 64 + ] if data_format == "NCHW" else [batch_size, 64, 64, in_channel] + w_shape = [out_channel, filter_channel, filter_size, filter_size] + scale_shape = [out_channel] + bias_shape = [out_channel] + + def generate_input(): + return np.random.random(x_shape).astype(np.float32) + + def generate_weight(): + return np.random.random(w_shape).astype(np.float32) + + def generate_bias(): + return np.random.random(bias_shape).astype(np.float32) + + def generate_scale_bias(): + return np.random.random(bias_shape).astype(np.float32) + + conv2d_op = OpConfig( + "conv2d", + inputs={ + "Input": ["input_data"], + "Filter": ["conv2d_weight"], + }, + outputs={"Output": ["conv_output"]}, + data_format=data_format, + dilations=dilations, + padding_algorithm=padding_algorithm, + groups=groups, + paddings=paddings, + strides=strides, + has_bias=has_bias, + is_test=True) + eltwise_op = OpConfig( + "elementwise_add", + inputs={"X": ["conv_output"], + "Y": ["conv2d_bias"]}, + outputs={"Out": ["elementwise_output"]}, + axis=axis) + ac_op = OpConfig( + "affine_channel", + inputs={ + "X": ["elementwise_output"], + "Scale": ["affine_channel_scale"], + "Bias": ["affine_channel_bias"] + }, + outputs={"Out": ["affine_channel_ouput"]}, + data_layout=data_format) + if has_bias == True: + conv2d_op.inputs["Bias"] = ["conv2d_bias"] + ops = [conv2d_op, eltwise_op, ac_op] + program_config = ProgramConfig( + ops=ops, + inputs={ + "input_data": TensorConfig(data_gen=partial(generate_input)), + }, + weights={ + "conv2d_weight": + TensorConfig(data_gen=partial(generate_weight)), + "conv2d_bias": TensorConfig(data_gen=partial(generate_bias)), + "affine_channel_scale": + TensorConfig(data_gen=partial(generate_scale_bias)), + "affine_channel_bias": + TensorConfig(data_gen=partial(generate_scale_bias)), + }, + outputs=["affine_channel_ouput"]) + return program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_gpu=True) + yield config, ['conv2d', 'elementwise_add'], (1e-4, 1e-4) + + config = self.create_inference_config(use_mkldnn=True) + yield config, ['conv2d', 'elementwise_add'], (1e-4, 1e-4) + + # TRT + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + workspace_size=1 << 20, + max_batch_size=4, + min_subgraph_size=1, + precision_mode=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False) + yield config, ['conv2d', 'elementwise_add'], (1e-4, 1e-4) + + def add_ignore_pass_case(self): + # If the problem has been fixed, the judgment + # in is_program_valid needs to be deleted!!! + def teller1(program_config, predictor_config): + if program_config.ops[0].attrs['data_format'] == "NHWC": + return True + return False + + # mkldnn Output has diff with bias! + def teller2(program_config, predictor_config): + return predictor_config.mkldnn_enabled() and program_config.ops[ + 0].attrs['has_bias'] == True + + self.add_ignore_check_case( + teller1, IgnoreReasons.PASS_ACCURACY_ERROR, + "The output format of conv2d is wrong when data_format attribute is NHWC, \ + it will trigger Broadcast dimension mismatch bug \ + when data_format attribute is NHWC and axis of eltwise op is 1 for this pass." + ) + + self.add_ignore_check_case( + teller2, IgnoreReasons.PASS_ACCURACY_ERROR, + "Currently mkldnn Output has diff with bias!") + + def test(self): + self.run_and_statis( + quant=False, + passes=["conv_eltwiseadd_affine_channel_fuse_pass"], ) + + +if __name__ == "__main__": + unittest.main()