From b4d3597a346eb18afa177a963144f10cf5348b08 Mon Sep 17 00:00:00 2001 From: baoachun <962571062@qq.com> Date: Thu, 17 Feb 2022 10:47:02 +0800 Subject: [PATCH] update inference ut to support nhwc format (#39551) * update inference ut to support nhwc format * update ut and pass OpCompat * update ut * update ut --- .../conv_activation_mkldnn_fuse_pass.cc | 2 +- .../ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc | 4 ++-- .../conv_concat_relu_mkldnn_fuse_pass.cc | 2 +- .../test_conv_act_mkldnn_fuse_pass.py | 8 +++---- .../test_conv_bias_mkldnn_fuse_pass.py | 14 +++++------- ...kldnn_conv_concat_relu_mkldnn_fuse_pass.py | 9 -------- .../test_mkldnn_conv_gelu_fuse_pass.py | 22 ------------------- ...test_mkldnn_conv_hard_sigmoid_fuse_pass.py | 9 -------- .../test_mkldnn_conv_hard_swish_fuse_pass.py | 9 -------- ...st_mkldnn_conv_transpose_bias_fuse_pass.py | 12 +++++----- 10 files changed, 20 insertions(+), 71 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc index 6370d338036..453cfb85554 100755 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc @@ -157,7 +157,7 @@ ConvActivationFusePass::ConvActivationFusePass() { // IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute .AddAttr("data_format") .IsOptional() - .IsStringIn({"NCHW", "AnyLayout"}) + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .End(); AddOpCompat(OpCompat("relu")) diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 9f70c829e1f..5d325037ad2 100755 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -115,7 +115,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { .IsStringIn({"EXPLICIT", "SAME", "VALID"}) .End() .AddAttr("data_format") - .IsStringIn({"NCHW"}) + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .End(); AddOpCompat(OpCompat("elementwise_add")) @@ -129,7 +129,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { .IsTensor() .End() .AddAttr("axis") - .IsIntIn({1}) + .IsIntIn({1, 3}) .End(); } diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc index 0947f4756ad..5fbfef08b72 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc @@ -59,7 +59,7 @@ ConvConcatReLUFusePass::ConvConcatReLUFusePass() { .IsType>() .End() .AddAttr("data_format") - .IsStringIn({"NCHW"}) + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .End(); AddOpCompat(OpCompat("concat")) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py index d9b3b8e6017..d029bcd6a7f 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py @@ -53,8 +53,6 @@ class TestConvActMkldnnFusePass(PassAutoScanTest): data_format = prog_config.ops[0].attrs["data_format"] filter_shape = prog_config.weights["filter"].shape input_shape = prog_config.inputs["input_x"].shape - if data_format != "NCHW": - return False if padding_algorithm == "VALID": if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \ ((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1: @@ -80,8 +78,8 @@ class TestConvActMkldnnFusePass(PassAutoScanTest): x_shape = draw( st.lists( st.integers( - min_value=1, max_value=100), min_size=4, max_size=4)) - x_shape[1] = draw(st.integers(min_value=1, max_value=10)) + min_value=5, max_value=100), min_size=4, max_size=4)) + x_shape[1] = draw(st.integers(min_value=5, max_value=10)) # 2. Generate legal attr:data_format of conv2d data_format = draw(st.sampled_from(["NCHW", "NHWC"])) @@ -90,7 +88,7 @@ class TestConvActMkldnnFusePass(PassAutoScanTest): f_shape = draw( st.lists( st.integers( - min_value=1, max_value=7), min_size=4, max_size=4)) + min_value=1, max_value=5), min_size=4, max_size=4)) if data_format == "NCHW": f_shape[1] = x_shape[1] else: diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py index 40fd9a418b9..a0213c5b1f4 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py @@ -53,8 +53,6 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): data_format = prog_config.ops[0].attrs["data_format"] filter_shape = prog_config.weights["filter"].shape input_shape = prog_config.inputs["input_x"].shape - if data_format != "NCHW": - return False if padding_algorithm == "VALID": if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \ ((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1: @@ -80,8 +78,8 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): x_shape = draw( st.lists( st.integers( - min_value=1, max_value=100), min_size=4, max_size=4)) - x_shape[1] = draw(st.integers(min_value=1, max_value=10)) + min_value=5, max_value=100), min_size=4, max_size=4)) + x_shape[1] = draw(st.integers(min_value=5, max_value=10)) # 2. Generate legal attr:data_format of conv2d data_format = draw(st.sampled_from(["NCHW", "NHWC"])) @@ -90,7 +88,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): f_shape = draw( st.lists( st.integers( - min_value=1, max_value=7), min_size=4, max_size=4)) + min_value=1, max_value=4), min_size=4, max_size=4)) if data_format == "NCHW": f_shape[1] = x_shape[1] else: @@ -100,7 +98,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): strides = draw( st.lists( st.integers( - min_value=1, max_value=5), min_size=2, max_size=2)) + min_value=1, max_value=4), min_size=2, max_size=2)) # 5. Generate legal attr:padding_algorithm of conv2d padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"])) @@ -109,7 +107,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): padding = draw( st.lists( st.integers( - min_value=1, max_value=5), min_size=4, max_size=4)) + min_value=1, max_value=4), min_size=4, max_size=4)) # 7. Generate legal attr:groups of conv2d groups = draw(st.integers(min_value=1, max_value=3)) @@ -118,7 +116,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): dilations = draw( st.lists( st.integers( - min_value=1, max_value=5), min_size=2, max_size=2)) + min_value=1, max_value=4), min_size=2, max_size=2)) # 9. Generate legal shape of input:bias of elementwise_add bias_shape = [f_shape[0]] diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py index 3c823c73d37..6654fbba264 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py @@ -27,15 +27,6 @@ import hypothesis.strategies as st class TestConvConcatReluMkldnnFusePass(PassAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: - attrs = [ - program_config.ops[i].attrs - for i in range(len(program_config.ops)) - ] - # If the problem has been fixed, the judgment - # needs to be deleted!!! - if attrs[0]['data_format'] == "NHWC": - return False - return True def sample_program_config(self, draw): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py index aa779f6ecbc..33df4283888 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py @@ -27,15 +27,6 @@ import hypothesis.strategies as st class TestConvGeluMkldnnFusePass(PassAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: - attrs = [ - program_config.ops[i].attrs - for i in range(len(program_config.ops)) - ] - # If the problem has been fixed, the judgment - # needs to be deleted!!! - if attrs[0]['data_format'] == "NHWC": - return False - return True def sample_program_config(self, draw): @@ -108,19 +99,6 @@ class TestConvGeluMkldnnFusePass(PassAutoScanTest): config = self.create_inference_config(use_mkldnn=True) yield config, ["conv2d"], (1e-5, 1e-5) - # If the problem has been fixed, the judgment - # needs to be deleted!!! - def add_ignore_pass_case(self): - def teller1(program_config, predictor_config): - if program_config.ops[0].attrs['data_format'] == "NHWC": - return True - return False - - self.add_ignore_check_case( - teller1, SkipReasons.PASS_ACCURACY_ERROR, - "The output format of conv2d is wrong when data_format attribute is NHWC" - ) - def test(self): self.run_and_statis(quant=False, passes=["conv_gelu_mkldnn_fuse_pass"]) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py index a0c4e183930..2eb071d6eb8 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py @@ -27,15 +27,6 @@ import hypothesis.strategies as st class TestConvHardSigmoidMkldnnFusePass(PassAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: - attrs = [ - program_config.ops[i].attrs - for i in range(len(program_config.ops)) - ] - # If the problem has been fixed, the judgment - # needs to be deleted!!! - if attrs[0]['data_format'] == "NHWC": - return False - return True def sample_program_config(self, draw): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py index 17bfb625fd3..990489c3213 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py @@ -27,15 +27,6 @@ import hypothesis.strategies as st class TestConvHardSwishMkldnnFusePass(PassAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: - attrs = [ - program_config.ops[i].attrs - for i in range(len(program_config.ops)) - ] - # If the problem has been fixed, the judgment - # needs to be deleted!!! - if attrs[0]['data_format'] == "NHWC": - return False - return True def sample_program_config(self, draw): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py index 5df7cb8d8ce..c5cedac2261 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py @@ -32,9 +32,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest): for i in range(len(program_config.ops)) ] - # If the problem has been fixed, the judgment - # needs to be deleted!!! - if attrs[0]['data_format'] == "NHWC": + if attrs[0]['data_format'] == "NCHW" and attrs[1]["axis"] == 3: + return False + if attrs[0]['data_format'] == "NHWC" and attrs[1]["axis"] == 1: return False return True @@ -46,7 +46,7 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest): groups = draw(st.sampled_from([1, 2, 4, 8])) paddings = draw(st.sampled_from([[0, 3], [1, 2, 3, 4]])) strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]])) - axis = draw(st.sampled_from([1])) + axis = draw(st.sampled_from([1, 3])) batch_size = draw(st.integers(min_value=1, max_value=4)) def generate_input(): @@ -110,7 +110,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest): def test(self): self.run_and_statis( - quant=False, passes=["conv_transpose_bias_mkldnn_fuse_pass"]) + quant=False, + max_duration=300, + passes=["conv_transpose_bias_mkldnn_fuse_pass"]) if __name__ == "__main__": -- GitLab