diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 336ccc9b77df8fd5a92f03be9ba83bf76ae67096..2b6d9f98abba02a9131ff4bb519838da355d6e5a 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -218,6 +218,7 @@ if(WITH_MKLDNN)
   pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
+  pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
   pass_library(cpu_quantize_placement_pass base DIR mkldnn)
   pass_library(cpu_quantize_pass inference DIR mkldnn)
   pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
diff --git a/paddle/fluid/framework/ir/mkldnn/operator_scale_onednn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/operator_scale_onednn_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2759c79b7a7d93526e2b90342d82191e71c25c83
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/operator_scale_onednn_fuse_pass.cc
@@ -0,0 +1,108 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/operator_scale_onednn_fuse_pass.h"
+
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+#include "paddle/fluid/string/pretty_log.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+using string::PrettyLogDetail;
+
+void FuseOperatorScaleOneDNNPass::ApplyImpl(Graph *graph) const {
+  const std::vector<std::string> fusable_ops{"fc", "matmul", "matmul_v2"};
+  for (const auto &op : fusable_ops) FuseScale(graph, op);
+}
+
+void FuseOperatorScaleOneDNNPass::FuseScale(Graph *graph,
+                                            const std::string &op_type) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+  FusePassBase::Init(op_type + "_scale_onednn_fuse_pass", graph);
+
+  GraphPatternDetector gpd;
+  patterns::OperatorActivation op_scale_pattern(
+      gpd.mutable_pattern(), op_type + "_scale_onednn_fuse_pass");
+  op_scale_pattern(op_type, "scale");
+
+  int found_operator_scale_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
+                     Graph *g) {
+    GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(operator_out, preceding_op_out, op_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_op, activation, op_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_out, activation_out, op_scale_pattern);
+
+    if (operator_op->Op()->HasAttr("use_mkldnn") &&
+        !(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn")))) {
+      VLOG(4) << "Only oneDNN version of " << op_type
+              << " can be fused with scale.";
+      return;
+    }
+
+    if (scale_op->Op()->GetAttrIfExists<float>("bias") != 0.0) {
+      VLOG(4) << op_type << " can be fused only with unbiased scale.";
+      return;
+    }
+
+    float scale = PADDLE_GET_CONST(float, scale_op->Op()->GetAttr("scale"));
+
+    auto *scope = param_scope();
+    auto const &names = scale_op->Op()->InputNames();
+    bool has_scale_tensor =
+        std::find(names.begin(), names.end(), "ScaleTensor") != names.end();
+
+    if (has_scale_tensor && scale_op->Op()->Input("ScaleTensor").size() > 0) {
+      std::string scale_var_name = scale_op->Op()->Input("ScaleTensor").front();
+      auto *scale_var = scope->FindVar(scale_var_name);
+      // ScaleTensor must be weight
+      if (scale_var == nullptr) return;
+      auto *scale_tensor = scale_var->GetMutable<LoDTensor>();
+      scale = *(scale_tensor->data<float>());
+    }
+
+    operator_op->Op()->SetAttr("fused_output_scale", scale);
+    operator_op->Op()->SetOutput("Out", {scale_out->Name()});
+
+    IR_OP_VAR_LINK(operator_op, scale_out);
+    GraphSafeRemoveNodes(g, {scale_op, operator_out});
+    found_operator_scale_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_operator_scale_count);
+  if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
+      found_operator_scale_count > 0)
+    PrettyLogDetail(
+        "--- fused %d %s with scale", found_operator_scale_count, op_type);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(operator_scale_onednn_fuse_pass,
+              paddle::framework::ir::FuseOperatorScaleOneDNNPass);
+REGISTER_PASS_CAPABILITY(operator_scale_onednn_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .EQ("fc", 0)
+            .LE("matmul", 1)
+            .EQ("matmul_v2", 0)
+            .EQ("scale", 0));
diff --git a/paddle/fluid/framework/ir/mkldnn/operator_scale_onednn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/operator_scale_onednn_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..e4e0295bf5604b506042e387f5dd30a05db42a37
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/operator_scale_onednn_fuse_pass.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class FuseOperatorScaleOneDNNPass : public FusePassBase {
+ public:
+  virtual ~FuseOperatorScaleOneDNNPass() {}
+
+ protected:
+  void ApplyImpl(Graph *graph) const override;
+
+  void FuseScale(Graph *graph, const std::string &op_type) const;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 31e3f453fc6b936c1abdfabcc335a4f9da21f0e4..2a6f47be48723d41f1ef3aea9de3ba7c2e67a7dd 100755
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -326,6 +326,7 @@ void CpuPassStrategy::EnableMKLDNN() {
              "softplus_activation_mkldnn_fuse_pass",    //
              "shuffle_channel_mkldnn_detect_pass",      //
              "elt_act_mkldnn_fuse_pass",                //
+             "operator_scale_onednn_fuse_pass",         //
              // TODO(intel): Please fix the bug on windows.
              // https://github.com/PaddlePaddle/Paddle/issues/29710
              // "mkldnn_inplace_pass",  // This pass should be activated after
@@ -419,6 +420,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("scale_matmul_fuse_pass");
     passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
     passes_.push_back("matmul_elementwise_add_mkldnn_fuse_pass");
+    passes_.push_back("operator_scale_onednn_fuse_pass");
     passes_.push_back("cpu_quantize_placement_pass");
     passes_.push_back("cpu_quantize_pass");
     passes_.push_back("cpu_quantize_squash_pass");
diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
index 7404972ea7cca0177a157c127055abaaf7e91046..93b27b2caac02e4ef37b909ae5b9e821a6d72a94 100644
--- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
@@ -533,6 +533,12 @@ class FCPrimitiveFactory {
           scale, dnnl::algorithm::eltwise_hardswish, alpha, beta);
     }
 
+    if (ctx.HasAttr("fused_output_scale")) {
+      float scale_alpha = ctx.Attr<float>("fused_output_scale");
+      post_operations.append_eltwise(
+          1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
+    }
+
     attributes.set_post_ops(post_operations);
     return attributes;
   }
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index 959329995613023a8c5d210da91ca11c76fab82d..9b870af90a1782596616b3a4158a6d57247b0b80 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -250,6 +250,12 @@ class MatMulV2MKLDNNHandler
     AppendActivation(ctx, post_operations);
 
+    if (ctx.HasAttr("fused_output_scale")) {
+      float scale_alpha = ctx.Attr<float>("fused_output_scale");
+      post_operations.append_eltwise(
+          1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
+    }
+
     matmul_attrs.set_post_ops(post_operations);
     return matmul_attrs;
   }
 
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py
index b894fc708b4243b802788981d7c7a34f43d6ad0d..80b13cb82dd017ae13776af29dc71b1e1be770cb 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py
@@ -21,7 +21,6 @@ import hypothesis.strategies as st
 
 
 class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
-
     def sample_program_config(self, draw):
         transpose_X = draw(st.booleans())
         transpose_Y = draw(st.booleans())
@@ -30,11 +29,25 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
         channel = draw(st.sampled_from([8]))
         input_dim = draw(st.sampled_from([32]))
         activation_type = draw(
-            st.sampled_from([
-                'relu', 'gelu', 'swish', 'mish', 'sqrt', 'hard_swish',
-                'sigmoid', 'abs', 'relu6', 'clip', 'tanh', 'hard_sigmoid',
-                'leaky_relu'
-            ]))
+            st.sampled_from(
+                [
+                    'relu',
+                    'gelu',
+                    'swish',
+                    'mish',
+                    'sqrt',
+                    'hard_swish',
+                    'sigmoid',
+                    'abs',
+                    'relu6',
+                    'clip',
+                    'tanh',
+                    'hard_sigmoid',
+                    'leaky_relu',
+                    'scale',
+                ]
+            )
+        )
 
         def generate_input(type):
             if transpose_X and transpose_Y:
@@ -55,50 +68,60 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
             else:
                 return np.random.random(shape_y).astype(np.float32)
 
-        matmul_op = OpConfig(type='matmul',
-                             inputs={
-                                 'X': ['matmul_X'],
-                                 'Y': ['matmul_Y']
-                             },
-                             outputs={'Out': ['matmul_output']},
-                             attrs={
-                                 'transpose_X': transpose_X,
-                                 'transpose_Y': transpose_Y,
-                                 'alpha': alpha
-                             })
+        matmul_op = OpConfig(
+            type='matmul',
+            inputs={'X': ['matmul_X'], 'Y': ['matmul_Y']},
+            outputs={'Out': ['matmul_output']},
+            attrs={
+                'transpose_X': transpose_X,
+                'transpose_Y': transpose_Y,
+                'alpha': alpha,
+                'use_mkldnn': True,
+            },
+        )
 
         if activation_type == "relu6":
-            activation_op = OpConfig(activation_type,
-                                     inputs={"X": ["matmul_output"]},
-                                     outputs={"Out": ["activation_output"]},
-                                     threshold=draw(
-                                         st.floats(min_value=1.0,
-                                                   max_value=10.0)))
+            activation_op = OpConfig(
+                activation_type,
+                inputs={"X": ["matmul_output"]},
+                outputs={"Out": ["activation_output"]},
+                threshold=draw(st.floats(min_value=1.0, max_value=10.0)),
+            )
         elif activation_type == "leaky_relu":
-            activation_op = OpConfig(activation_type,
-                                     inputs={"X": ["matmul_output"]},
-                                     outputs={"Out": ["activation_output"]},
-                                     alpha=draw(
-                                         st.floats(min_value=0.1,
-                                                   max_value=1.0)))
+            activation_op = OpConfig(
+                activation_type,
+                inputs={"X": ["matmul_output"]},
+                outputs={"Out": ["activation_output"]},
+                alpha=draw(st.floats(min_value=0.1, max_value=1.0)),
+            )
+        elif activation_type == "scale":
+            activation_op = OpConfig(
+                activation_type,
+                inputs={"X": ["matmul_output"]},
+                outputs={"Out": ["activation_output"]},
+                scale=draw(st.sampled_from([0.125, 0.4, 0.875, 2])),
+            )
         elif activation_type == "swish":
-            activation_op = OpConfig(activation_type,
-                                     inputs={"X": ["matmul_output"]},
-                                     outputs={"Out": ["activation_output"]},
-                                     beta=draw(
-                                         st.floats(min_value=0.1,
-                                                   max_value=1.0)))
+            activation_op = OpConfig(
+                activation_type,
+                inputs={"X": ["matmul_output"]},
+                outputs={"Out": ["activation_output"]},
+                beta=draw(st.floats(min_value=0.1, max_value=1.0)),
+            )
         elif activation_type == "clip":
             activation_op = OpConfig(
                 activation_type,
                 inputs={"X": ["matmul_output"]},
                 outputs={"Out": ["activation_output"]},
                 min=draw(st.floats(min_value=0.1, max_value=0.49)),
-                max=draw(st.floats(min_value=0.5, max_value=1.0)))
+                max=draw(st.floats(min_value=0.5, max_value=1.0)),
+            )
         else:
-            activation_op = OpConfig(activation_type,
-                                     inputs={"X": ["matmul_output"]},
-                                     outputs={"Out": ["activation_output"]})
+            activation_op = OpConfig(
+                activation_type,
+                inputs={"X": ["matmul_output"]},
+                outputs={"Out": ["activation_output"]},
+            )
 
         model_net = [matmul_op, activation_op]
 
@@ -107,20 +130,32 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
             weights={},
             inputs={
                 'matmul_X': TensorConfig(data_gen=partial(generate_input, 'x')),
-                'matmul_Y': TensorConfig(data_gen=partial(generate_input, 'y'))
+                'matmul_Y': TensorConfig(data_gen=partial(generate_input, 'y')),
             },
-            outputs=['activation_output'])
+            outputs=['activation_output'],
+        )
 
         return program_config
 
     def sample_predictor_configs(self, program_config):
-        config = self.create_inference_config(use_mkldnn=True)
+        config = self.create_inference_config(
+            use_mkldnn=True,
+            passes=[
+                'matmul_activation_mkldnn_fuse_pass',
+                'operator_scale_onednn_fuse_pass',
+            ],
+        )
         yield config, ['matmul'], (1e-5, 1e-5)
 
     def test(self):
-        self.run_and_statis(quant=False,
-                            max_examples=30,
-                            passes=['matmul_activation_mkldnn_fuse_pass'])
+        self.run_and_statis(
+            quant=False,
+            max_examples=50,
+            passes=[
+                'matmul_activation_mkldnn_fuse_pass',
+                'operator_scale_onednn_fuse_pass',
+            ],
+        )
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py
index 153b81fa797af560fa56898db6c6a1ce54719215..84fc91e01620be045157cc8218f9925d45c0ba9d 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py
@@ -21,7 +21,6 @@ import hypothesis.strategies as st
 
 
 class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
-
     def sample_program_config(self, draw):
         transpose_X = draw(st.booleans())
         transpose_Y = draw(st.booleans())
@@ -29,11 +28,25 @@ class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
         channel = draw(st.sampled_from([16, 32, 64]))
         input_dim = draw(st.sampled_from([16, 32, 64]))
         activation_type = draw(
-            st.sampled_from([
-                'relu', 'gelu', 'swish', 'mish', 'sqrt', 'hard_swish',
-                'sigmoid', 'abs', 'relu6', 'clip', 'tanh', 'hard_sigmoid',
-                'leaky_relu'
-            ]))
+            st.sampled_from(
+                [
+                    'relu',
+                    'gelu',
+                    'swish',
+                    'mish',
+                    'sqrt',
+                    'hard_swish',
+                    'sigmoid',
+                    'abs',
+                    'relu6',
+                    'clip',
+                    'tanh',
+                    'hard_sigmoid',
+                    'leaky_relu',
+                    'scale',
+                ]
+            )
+        )
 
         def generate_input(type):
             broadcast_X = st.booleans()
@@ -60,49 +73,59 @@ class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
            else:
                 return np.random.random(shape_y).astype(np.float32)
 
-        matmul_op = OpConfig(type='matmul_v2',
-                             inputs={
-                                 'X': ['matmul_X'],
-                                 'Y': ['matmul_Y']
-                             },
-                             outputs={'Out': ['matmul_output']},
-                             attrs={
-                                 'trans_x': transpose_X,
-                                 'trans_y': transpose_Y
-                             })
+        matmul_op = OpConfig(
+            type='matmul_v2',
+            inputs={'X': ['matmul_X'], 'Y': ['matmul_Y']},
+            outputs={'Out': ['matmul_output']},
+            attrs={
+                'trans_x': transpose_X,
+                'trans_y': transpose_Y,
+                'use_mkldnn': True,
+            },
+        )
 
         if activation_type == 'relu6':
-            activation_op = OpConfig(activation_type,
-                                     inputs={'X': ['matmul_output']},
-                                     outputs={'Out': ['activation_output']},
-                                     threshold=draw(
-                                         st.floats(min_value=1.0,
-                                                   max_value=10.0)))
+            activation_op = OpConfig(
+                activation_type,
+                inputs={'X': ['matmul_output']},
+                outputs={'Out': ['activation_output']},
+                threshold=draw(st.floats(min_value=1.0, max_value=10.0)),
+            )
         elif activation_type == 'leaky_relu':
-            activation_op = OpConfig(activation_type,
-                                     inputs={'X': ['matmul_output']},
-                                     outputs={'Out': ['activation_output']},
-                                     alpha=draw(
-                                         st.floats(min_value=0.1,
-                                                   max_value=1.0)))
+            activation_op = OpConfig(
+                activation_type,
+                inputs={'X': ['matmul_output']},
+                outputs={'Out': ['activation_output']},
+                alpha=draw(st.floats(min_value=0.1, max_value=1.0)),
+            )
+        elif activation_type == "scale":
+            activation_op = OpConfig(
+                activation_type,
+                inputs={"X": ["matmul_output"]},
+                outputs={"Out": ["activation_output"]},
+                scale=draw(st.sampled_from([0.125, 0.4, 0.875, 2])),
+            )
         elif activation_type == 'swish':
-            activation_op = OpConfig(activation_type,
-                                     inputs={'X': ['matmul_output']},
-                                     outputs={'Out': ['activation_output']},
-                                     beta=draw(
-                                         st.floats(min_value=0.1,
-                                                   max_value=1.0)))
+            activation_op = OpConfig(
+                activation_type,
+                inputs={'X': ['matmul_output']},
+                outputs={'Out': ['activation_output']},
+                beta=draw(st.floats(min_value=0.1, max_value=1.0)),
+            )
         elif activation_type == 'clip':
             activation_op = OpConfig(
                 activation_type,
                 inputs={'X': ['matmul_output']},
                 outputs={'Out': ['activation_output']},
                 min=draw(st.floats(min_value=0.1, max_value=0.49)),
-                max=draw(st.floats(min_value=0.5, max_value=1.0)))
+                max=draw(st.floats(min_value=0.5, max_value=1.0)),
+            )
         else:
-            activation_op = OpConfig(activation_type,
-                                     inputs={'X': ['matmul_output']},
-                                     outputs={'Out': ['activation_output']})
+            activation_op = OpConfig(
+                activation_type,
+                inputs={'X': ['matmul_output']},
+                outputs={'Out': ['activation_output']},
+            )
 
         model_net = [matmul_op, activation_op]
 
@@ -111,20 +134,32 @@ class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
             weights={},
             inputs={
                 'matmul_X': TensorConfig(data_gen=partial(generate_input, 'X')),
-                'matmul_Y': TensorConfig(data_gen=partial(generate_input, 'Y'))
+                'matmul_Y': TensorConfig(data_gen=partial(generate_input, 'Y')),
             },
-            outputs=['activation_output'])
+            outputs=['activation_output'],
+        )
 
         return program_config
 
     def sample_predictor_configs(self, program_config):
-        config = self.create_inference_config(use_mkldnn=True)
+        config = self.create_inference_config(
+            use_mkldnn=True,
+            passes=[
+                'matmul_activation_mkldnn_fuse_pass',
+                'operator_scale_onednn_fuse_pass',
+            ],
+        )
         yield config, ['matmul_v2'], (1e-5, 1e-5)
 
     def test(self):
-        self.run_and_statis(quant=False,
-                            max_examples=30,
-                            passes=['matmul_activation_mkldnn_fuse_pass'])
+        self.run_and_statis(
+            quant=False,
+            max_examples=50,
+            passes=[
+                'matmul_activation_mkldnn_fuse_pass',
+                'operator_scale_onednn_fuse_pass',
+            ],
+        )
 
 
 if __name__ == '__main__':