diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc index a5beec87c399d35130d7aa11ee6fdb89e604c6bf..c33398553ecd2cbe291e9cc605aa23ce318e9efe 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h" #include +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -103,12 +104,32 @@ REGISTER_PASS(conv_activation_mkldnn_fuse_pass, REGISTER_PASS(conv_relu_mkldnn_fuse_pass, paddle::framework::ir::ConvActivationFusePass); +REGISTER_PASS_CAPABILITY(conv_relu_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("conv2d", 0) + .EQ("relu", 0)); REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass, paddle::framework::ir::Conv2DLeakyReLUFusePass); +REGISTER_PASS_CAPABILITY(conv_leaky_relu_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("conv2d", 0) + .LE("leaky_relu", 1)); REGISTER_PASS(conv_relu6_mkldnn_fuse_pass, paddle::framework::ir::Conv2DReLU6FusePass); +REGISTER_PASS_CAPABILITY(conv_relu6_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("conv2d", 0) + .EQ("relu6", 0)); REGISTER_PASS(conv_swish_mkldnn_fuse_pass, paddle::framework::ir::Conv2DSwishFusePass); +REGISTER_PASS_CAPABILITY(conv_swish_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("conv2d", 0) + .EQ("swish", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc index 5fadd9607e9250c0bb890b0239c18b3a9096b55f..76e102125501144cbfd06ced2c88b4f1e02e261b 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h" #include +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -123,3 +124,10 @@ void ConvConcatReLUFusePass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(conv_concat_relu_mkldnn_fuse_pass, paddle::framework::ir::ConvConcatReLUFusePass); + +REGISTER_PASS_CAPABILITY(conv_concat_relu_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("conv2d", 0) + .EQ("concat", 0) + .EQ("relu", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc index 37c14e1d8e3b90f223c8dff7396d96594b9286d7..41b859f0af665eae6d9ccb6a08cd29db5ce67fdf 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h" #include #include +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -98,3 +99,10 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { REGISTER_PASS(matmul_transpose_reshape_fuse_pass, paddle::framework::ir::MatmulTransposeReshapeMKLDNNPass); + +REGISTER_PASS_CAPABILITY(matmul_transpose_reshape_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("matmul", 0) + .EQ("transpose", 0) + .EQ("reshape", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc index 790821e3fa4bbbcff23266f734d641169e231b70..0784a1a024cfd31cfb2d2a3ea205518416c2ad13 100644 --- a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/string/pretty_log.h" namespace paddle { @@ -90,3 +91,9 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(scale_matmul_fuse_pass, paddle::framework::ir::ScaleMatmulFusePass); + +REGISTER_PASS_CAPABILITY(scale_matmul_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("scale", 0) + .EQ("matmul", 0)); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..5d96994a33b2c05446b67df44bd8999352373d43 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_activation_fuse_pass.py @@ -0,0 +1,106 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import AnalysisConfig +from paddle.fluid.core import PassVersionChecker + + +class ConvActivationMkldnnFusePassTest(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 3, 100, 100], dtype="float32") + conv_out = fluid.layers.conv2d( + data, + num_filters=self.conv_num_filters, + filter_size=self.conv_filter_size, + bias_attr=self.conv_bias_attr, + act=self.act) + + self.feeds = { + "data": np.random.random((1, 3, 100, 100)).astype("float32") + } + self.fetch_list = [conv_out] + self.enable_mkldnn = True + + def set_params(self): + self.conv_num_filters = 3 + self.conv_filter_size = 3 + self.conv_bias_attr = False + self.act = "relu" + self.pass_name = 'conv_relu_mkldnn_fuse_pass' + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + def test_pass_compatible(self): + self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name)) + + +class ConvActivationMkldnnFusePassTest_1(ConvActivationMkldnnFusePassTest): + def set_params(self): + self.conv_num_filters = 5 + self.conv_filter_size = 5 + self.conv_bias_attr = True + self.act = "relu" + self.pass_name = 'conv_relu_mkldnn_fuse_pass' + + +class ConvActivationMkldnnFusePassTest_2(ConvActivationMkldnnFusePassTest): + def set_params(self): + self.conv_num_filters = 3 + self.conv_filter_size = 3 + self.conv_bias_attr = False + self.act = "leaky_relu" + self.pass_name = 'conv_leaky_relu_mkldnn_fuse_pass' + + +class ConvActivationMkldnnFusePassTest_3(ConvActivationMkldnnFusePassTest): + def set_params(self): + self.conv_num_filters = 5 + self.conv_filter_size = 5 + self.conv_bias_attr = True + self.act = "leaky_relu" + self.pass_name = 'conv_leaky_relu_mkldnn_fuse_pass' + + +class ConvActivationMkldnnFusePassTest_4(ConvActivationMkldnnFusePassTest): + def set_params(self): + self.conv_num_filters = 3 + self.conv_filter_size = 3 + self.conv_bias_attr = False + self.act = "relu6" + self.pass_name = 'conv_relu6_mkldnn_fuse_pass' + + +class ConvActivationMkldnnFusePassTest_4(ConvActivationMkldnnFusePassTest): + def set_params(self): + self.conv_num_filters = 5 + self.conv_filter_size = 5 + self.conv_bias_attr = True + self.act = "swish" + self.pass_name = 'conv_swish_mkldnn_fuse_pass' + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..45097f6b8191d045d0665d7478e4090c0ae20cb3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py @@ -0,0 +1,92 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import AnalysisConfig +from paddle.fluid.core import PassVersionChecker + + +class ConvConcatReluMkldnnFusePassTest_0(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data_1 = fluid.data( + name="data_1", shape=[-1, 3, 100, 100], dtype="float32") + data_2 = fluid.data( + name="data_2", shape=[-1, 3, 100, 100], dtype="float32") + conv_1 = fluid.layers.conv2d( + data_1, + num_filters=self.conv1_num_filters, + filter_size=self.conv1_filter_size, + padding=self.conv1_padding, + bias_attr=self.conv1_bias_attr) + conv_2 = fluid.layers.conv2d( + data_2, + num_filters=self.conv2_num_filters, + filter_size=self.conv2_filter_size, + padding=self.conv2_padding, + bias_attr=self.conv2_bias_attr) + concat = fluid.layers.concat( + [conv_1, conv_2], axis=self.concat_axis) + out = fluid.layers.relu(concat) + + self.feeds = { + "data_1": np.random.random((1, 3, 100, 100)).astype("float32"), + "data_2": np.random.random((1, 3, 100, 100)).astype("float32") + } + self.fetch_list = [out] + self.enable_mkldnn = True + + def set_params(self): + self.conv1_num_filters = 3 + self.conv1_filter_size = 3 + self.conv1_padding = 0 + self.conv1_bias_attr = False + self.conv2_num_filters = 3 + self.conv2_filter_size = 3 + self.conv2_padding = 0 + self.conv2_bias_attr = False + self.concat_axis = 0 + self.pass_name = "conv_concat_relu_mkldnn_fuse_pass" + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + def test_pass_compatible(self): + self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name)) + + +class ConvConcatReluMkldnnFusePassTest_1(ConvConcatReluMkldnnFusePassTest_0): + def set_params(self): + self.conv1_num_filters = 3 + self.conv1_filter_size = 3 + self.conv1_padding = 0 + self.conv1_bias_attr = False + self.conv2_num_filters = 5 + self.conv2_filter_size = 5 + self.conv2_padding = 1 + self.conv2_bias_attr = True + self.concat_axis = 1 + self.pass_name = "conv_concat_relu_mkldnn_fuse_pass" + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_transpose_reshape_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_transpose_reshape_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..a6b5e0e54739b37b5f2e490fc890fbe56c2f83f2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_transpose_reshape_fuse_pass.py @@ -0,0 +1,81 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import AnalysisConfig +from paddle.fluid.core import PassVersionChecker + + +class MatmulTransposeReshapeMkldnnFusePassTest(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=self.data_shape, dtype="float32") + weight = fluid.layers.create_parameter( + shape=self.weight_shape, dtype="float32") + matmul = fluid.layers.matmul( + data, + weight, + transpose_x=self.transpose_x, + transpose_y=self.transpose_y) + transpose = fluid.layers.transpose(matmul, self.tranpose_perm) + reshape = fluid.layers.reshape(transpose, shape=self.reshape_shape) + + self.fetch_list = [reshape] + self.enable_mkldnn = True + + def set_params(self): + self.data_shape = [-1, 3, 100, 110] + self.weight_shape = [1, 3, 110, 100] + self.feeds = { + "data": np.random.random((1, 3, 100, 110)).astype("float32") + } + self.transpose_x = False + self.transpose_y = False + self.tranpose_perm = [0, 2, 1, 3] + self.reshape_shape = [3, 100, 100] + self.pass_name = 'matmul_transpose_reshape_fuse_pass' + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + def test_pass_compatible(self): + self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name)) + + +class MatmulTransposeReshapeMkldnnFusePassTest_1( + MatmulTransposeReshapeMkldnnFusePassTest): + def set_params(self): + self.data_shape = [-1, 3, 100, 100] + self.weight_shape = [1, 3, 100, 100] + self.feeds = { + "data": np.random.random((1, 3, 100, 100)).astype("float32") + } + self.transpose_x = True + self.transpose_y = True + self.tranpose_perm = [0, 2, 1, 3] + self.reshape_shape = [6, 50, 100] + self.pass_name = 'matmul_transpose_reshape_fuse_pass' + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_relu_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_scale_matmul_fuse_pass.py similarity index 50% rename from python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_relu_fuse_pass.py rename to python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_scale_matmul_fuse_pass.py index 2346e93d64dce21d9bdd7687bd8d5ed38ff5f188..55a6b543f0aeafe75940255565e3f02ae9194b99 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_relu_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_scale_matmul_fuse_pass.py @@ -20,26 +20,54 @@ from inference_pass_test import InferencePassTest import paddle.fluid as fluid import paddle.fluid.core as core from paddle.fluid.core import AnalysisConfig +from paddle.fluid.core import PassVersionChecker -class ConvBnFusePassMKLDNNTest(InferencePassTest): +class ScaleMatmulMkldnnFusePassTest(InferencePassTest): def setUp(self): + self.set_params() with fluid.program_guard(self.main_program, self.startup_program): data = fluid.data( - name="data", shape=[-1, 3, 100, 100], dtype="float32") - conv_out = fluid.layers.conv2d( - data, num_filters=3, filter_size=3, bias_attr=False, act="relu") + name="data", shape=[1, 3, 100, 100], dtype="float32") + weight = fluid.layers.create_parameter( + shape=[1, 3, 100, 100], dtype="float32") + scale = fluid.layers.scale(data, scale=self.scale_scale) + matmul = fluid.layers.matmul( + scale, + weight, + transpose_x=self.transpose_x, + transpose_y=self.transpose_y) + self.fetch_list = [matmul] + self.enable_mkldnn = True + + def set_params(self): self.feeds = { "data": np.random.random((1, 3, 100, 100)).astype("float32") } - self.fetch_list = [conv_out] - self.enable_mkldnn = True + self.scale_scale = 2.0 + self.transpose_x = False + self.transpose_y = False + self.pass_name = "scale_matmul_fuse_pass" def test_check_output(self): use_gpu = False self.check_output_with_option(use_gpu) + def test_pass_compatible(self): + self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name)) + + +class ScaleMatmulMkldnnFusePassTest_1(ScaleMatmulMkldnnFusePassTest): + def set_params(self): + self.feeds = { + "data": np.random.random((1, 3, 100, 100)).astype("float32") + } + self.scale_scale = 5.0 + self.transpose_x = True + self.transpose_y = True + self.pass_name = "scale_matmul_fuse_pass" + if __name__ == "__main__": unittest.main()