diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc index 6370d3380361c0c3b3434bfb9b73f429174f4cfa..453cfb85554ec53549bcbfd1f9be4566deb47d54 100755 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc @@ -157,7 +157,7 @@ ConvActivationFusePass::ConvActivationFusePass() { // IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute .AddAttr("data_format") .IsOptional() - .IsStringIn({"NCHW", "AnyLayout"}) + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .End(); AddOpCompat(OpCompat("relu")) diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 9f70c829e1fb5ea8c13c45bdacbc408d8cc40173..5d325037ad20ed7d1a27d6af0b302304bf8ff8e6 100755 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -115,7 +115,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { .IsStringIn({"EXPLICIT", "SAME", "VALID"}) .End() .AddAttr("data_format") - .IsStringIn({"NCHW"}) + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .End(); AddOpCompat(OpCompat("elementwise_add")) @@ -129,7 +129,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { .IsTensor() .End() .AddAttr("axis") - .IsIntIn({1}) + .IsIntIn({1, 3}) .End(); } diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc index 0947f4756ad78c0e2b4b255e96f75ea440b022ef..5fbfef08b7209bc695f90ff9188b8e9a7db029a7 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc @@ -59,7 +59,7 @@ ConvConcatReLUFusePass::ConvConcatReLUFusePass() { .IsType>() .End() .AddAttr("data_format") - .IsStringIn({"NCHW"}) + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .End(); AddOpCompat(OpCompat("concat")) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py index d9b3b8e601755592609c4008b17cbefac81a7caf..d029bcd6a7f17b670ec544d9aa0279d5b4177310 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py @@ -53,8 +53,6 @@ class TestConvActMkldnnFusePass(PassAutoScanTest): data_format = prog_config.ops[0].attrs["data_format"] filter_shape = prog_config.weights["filter"].shape input_shape = prog_config.inputs["input_x"].shape - if data_format != "NCHW": - return False if padding_algorithm == "VALID": if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \ ((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1: @@ -80,8 +78,8 @@ class TestConvActMkldnnFusePass(PassAutoScanTest): x_shape = draw( st.lists( st.integers( - min_value=1, max_value=100), min_size=4, max_size=4)) - x_shape[1] = draw(st.integers(min_value=1, max_value=10)) + min_value=5, max_value=100), min_size=4, max_size=4)) + x_shape[1] = draw(st.integers(min_value=5, max_value=10)) # 2. Generate legal attr:data_format of conv2d data_format = draw(st.sampled_from(["NCHW", "NHWC"])) @@ -90,7 +88,7 @@ class TestConvActMkldnnFusePass(PassAutoScanTest): f_shape = draw( st.lists( st.integers( - min_value=1, max_value=7), min_size=4, max_size=4)) + min_value=1, max_value=5), min_size=4, max_size=4)) if data_format == "NCHW": f_shape[1] = x_shape[1] else: diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py index 40fd9a418b9b1391bb9ebf6f36d1e73f55d48c14..a0213c5b1f4dfd5f542c0a8fab32e9a8c93a66c6 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py @@ -53,8 +53,6 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): data_format = prog_config.ops[0].attrs["data_format"] filter_shape = prog_config.weights["filter"].shape input_shape = prog_config.inputs["input_x"].shape - if data_format != "NCHW": - return False if padding_algorithm == "VALID": if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \ ((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1: @@ -80,8 +78,8 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): x_shape = draw( st.lists( st.integers( - min_value=1, max_value=100), min_size=4, max_size=4)) - x_shape[1] = draw(st.integers(min_value=1, max_value=10)) + min_value=5, max_value=100), min_size=4, max_size=4)) + x_shape[1] = draw(st.integers(min_value=5, max_value=10)) # 2. Generate legal attr:data_format of conv2d data_format = draw(st.sampled_from(["NCHW", "NHWC"])) @@ -90,7 +88,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): f_shape = draw( st.lists( st.integers( - min_value=1, max_value=7), min_size=4, max_size=4)) + min_value=1, max_value=4), min_size=4, max_size=4)) if data_format == "NCHW": f_shape[1] = x_shape[1] else: @@ -100,7 +98,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): strides = draw( st.lists( st.integers( - min_value=1, max_value=5), min_size=2, max_size=2)) + min_value=1, max_value=4), min_size=2, max_size=2)) # 5. Generate legal attr:padding_algorithm of conv2d padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"])) @@ -109,7 +107,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): padding = draw( st.lists( st.integers( - min_value=1, max_value=5), min_size=4, max_size=4)) + min_value=1, max_value=4), min_size=4, max_size=4)) # 7. Generate legal attr:groups of conv2d groups = draw(st.integers(min_value=1, max_value=3)) @@ -118,7 +116,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): dilations = draw( st.lists( st.integers( - min_value=1, max_value=5), min_size=2, max_size=2)) + min_value=1, max_value=4), min_size=2, max_size=2)) # 9. Generate legal shape of input:bias of elementwise_add bias_shape = [f_shape[0]] diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py index 3c823c73d37943cb9040522b21bdf3f642c10bea..6654fbba264e0ee007ea16d0c3b8131e84da01ad 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py @@ -27,15 +27,6 @@ import hypothesis.strategies as st class TestConvConcatReluMkldnnFusePass(PassAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: - attrs = [ - program_config.ops[i].attrs - for i in range(len(program_config.ops)) - ] - # If the problem has been fixed, the judgment - # needs to be deleted!!! - if attrs[0]['data_format'] == "NHWC": - return False - return True def sample_program_config(self, draw): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py index aa779f6ecbc4f99ba0ae20f4a6a1a7f674f46781..33df428388882f2e536ecebb15a1f5dae6a6afc5 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py @@ -27,15 +27,6 @@ import hypothesis.strategies as st class TestConvGeluMkldnnFusePass(PassAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: - attrs = [ - program_config.ops[i].attrs - for i in range(len(program_config.ops)) - ] - # If the problem has been fixed, the judgment - # needs to be deleted!!! - if attrs[0]['data_format'] == "NHWC": - return False - return True def sample_program_config(self, draw): @@ -108,19 +99,6 @@ class TestConvGeluMkldnnFusePass(PassAutoScanTest): config = self.create_inference_config(use_mkldnn=True) yield config, ["conv2d"], (1e-5, 1e-5) - # If the problem has been fixed, the judgment - # needs to be deleted!!! - def add_ignore_pass_case(self): - def teller1(program_config, predictor_config): - if program_config.ops[0].attrs['data_format'] == "NHWC": - return True - return False - - self.add_ignore_check_case( - teller1, SkipReasons.PASS_ACCURACY_ERROR, - "The output format of conv2d is wrong when data_format attribute is NHWC" - ) - def test(self): self.run_and_statis(quant=False, passes=["conv_gelu_mkldnn_fuse_pass"]) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py index a0c4e183930a570e25bb6aed893868f69a4edcd8..2eb071d6eb83bd89762145d4212732b6d6ad54f7 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py @@ -27,15 +27,6 @@ import hypothesis.strategies as st class TestConvHardSigmoidMkldnnFusePass(PassAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: - attrs = [ - program_config.ops[i].attrs - for i in range(len(program_config.ops)) - ] - # If the problem has been fixed, the judgment - # needs to be deleted!!! - if attrs[0]['data_format'] == "NHWC": - return False - return True def sample_program_config(self, draw): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py index 17bfb625fd37bc547a2c3eaece361f7cce5d6615..990489c32136a20e88f29591db7193aa194d00a9 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py @@ -27,15 +27,6 @@ import hypothesis.strategies as st class TestConvHardSwishMkldnnFusePass(PassAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: - attrs = [ - program_config.ops[i].attrs - for i in range(len(program_config.ops)) - ] - # If the problem has been fixed, the judgment - # needs to be deleted!!! - if attrs[0]['data_format'] == "NHWC": - return False - return True def sample_program_config(self, draw): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py index 5df7cb8d8cec35fd62c373e46d5b799bea46a2da..c5cedac226149d8dd30e0b25d0b7681b587cf3c5 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py @@ -32,9 +32,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest): for i in range(len(program_config.ops)) ] - # If the problem has been fixed, the judgment - # needs to be deleted!!! - if attrs[0]['data_format'] == "NHWC": + if attrs[0]['data_format'] == "NCHW" and attrs[1]["axis"] == 3: + return False + if attrs[0]['data_format'] == "NHWC" and attrs[1]["axis"] == 1: return False return True @@ -46,7 +46,7 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest): groups = draw(st.sampled_from([1, 2, 4, 8])) paddings = draw(st.sampled_from([[0, 3], [1, 2, 3, 4]])) strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]])) - axis = draw(st.sampled_from([1])) + axis = draw(st.sampled_from([1, 3])) batch_size = draw(st.integers(min_value=1, max_value=4)) def generate_input(): @@ -110,7 +110,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest): def test(self): self.run_and_statis( - quant=False, passes=["conv_transpose_bias_mkldnn_fuse_pass"]) + quant=False, + max_duration=300, + passes=["conv_transpose_bias_mkldnn_fuse_pass"]) if __name__ == "__main__":