From fdcfa04f6a6029371be2de164b7f6bfe62e204f6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C5=82awomir=20Siwek?=
Date: Mon, 13 Mar 2023 04:36:56 +0100
Subject: [PATCH] Fused softplus (#51087)

* mkldnn->onednn

* fused softplus op + kernel

* remove extra attributes

* add missing handler

* change var name
---
 paddle/fluid/framework/ir/CMakeLists.txt      |  2 +-
 .../framework/ir/mkldnn/mkldnn_pass_util.h    |  3 +-
 ...> softplus_activation_onednn_fuse_pass.cc} | 10 +--
 ...=> softplus_activation_onednn_fuse_pass.h} |  2 +-
 .../inference/api/paddle_pass_builder.cc      |  4 +-
 .../operators/compat/fused_softplus.pbtxt     | 31 +++++++++
 .../operators/fused/fused_softplus_op.cc      | 69 +++++++++++++++++++
 paddle/fluid/operators/ops_extra_info.h       |  5 --
 paddle/phi/api/yaml/op_compat.yaml            |  3 +-
 paddle/phi/backends/onednn/onednn_reuse.h     | 41 +++++++++++
 .../fusion/onednn/fused_softplus_kernel.cc    | 65 +++++++++++++++++
 paddle/phi/kernels/onednn/softplus_kernel.cc  | 45 +-----------
 paddle/phi/ops/compat/fused_softplus_sig.cc   | 30 ++++++++
 ...st_onednn_softplus_activation_fuse_pass.py |  4 +-
 .../quantization/quant2_int8_mkldnn_pass.py   |  2 +-
 15 files changed, 255 insertions(+), 61 deletions(-)
 rename paddle/fluid/framework/ir/mkldnn/{softplus_activation_mkldnn_fuse_pass.cc => softplus_activation_onednn_fuse_pass.cc} (90%)
 rename paddle/fluid/framework/ir/mkldnn/{softplus_activation_mkldnn_fuse_pass.h => softplus_activation_onednn_fuse_pass.h} (95%)
 create mode 100644 paddle/fluid/operators/compat/fused_softplus.pbtxt
 create mode 100644 paddle/fluid/operators/fused/fused_softplus_op.cc
 create mode 100644 paddle/phi/kernels/fusion/onednn/fused_softplus_kernel.cc
 create mode 100644 paddle/phi/ops/compat/fused_softplus_sig.cc

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 5dd1b4c6193..6655aa19db2 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -172,7 +172,7 @@ if(WITH_MKLDNN)
   pass_library(cpu_bfloat16_pass inference DIR mkldnn)
   pass_library(fc_mkldnn_pass inference DIR mkldnn)
   pass_library(interpolate_mkldnn_pass inference DIR mkldnn)
-  pass_library(softplus_activation_mkldnn_fuse_pass inference DIR mkldnn)
+  pass_library(softplus_activation_onednn_fuse_pass inference DIR mkldnn)
   pass_library(shuffle_channel_mkldnn_detect_pass inference DIR mkldnn)
   pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn)
diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h b/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h
index 142bb9adb68..3885ef37e2f 100644
--- a/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h
+++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h
@@ -160,7 +160,8 @@ inline void ConvertToFusedOp(OpDesc* op) {
       {"conv2d", "fused_conv2d"},
       {"depthwise_conv2d", "fused_conv2d"},
       {"matmul", "fused_matmul"},
-      {"matmul_v2", "fused_matmul"}};
+      {"matmul_v2", "fused_matmul"},
+      {"softplus", "fused_softplus"}};
 
   if (op->Type() == "matmul") {
     op->SetAttr("trans_x", op->GetAttr("transpose_X"));
diff --git a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/softplus_activation_onednn_fuse_pass.cc
similarity index 90%
rename from paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc
rename to paddle/fluid/framework/ir/mkldnn/softplus_activation_onednn_fuse_pass.cc
index 3fc9221260d..2030a7dadc0 100644
--- a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/softplus_activation_onednn_fuse_pass.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,10 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h"
+#include "paddle/fluid/framework/ir/mkldnn/softplus_activation_onednn_fuse_pass.h"
 
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
+#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/utils/string/pretty_log.h"
@@ -58,6 +59,7 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
     GET_IR_NODE_FROM_SUBGRAPH(
         activation, activation, softplus_activation_pattern);
 
+    ConvertToFusedOp(softplus->Op());
     SetActivationAttrs(softplus->Op(), activation->Op(), act_type);
     softplus->Op()->SetOutput("Out", {activation_out->Name()});
 
@@ -78,9 +80,9 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_PASS(softplus_activation_mkldnn_fuse_pass,
+REGISTER_PASS(softplus_activation_onednn_fuse_pass,
               paddle::framework::ir::SoftplusActivationOneDNNPass);
-REGISTER_PASS_CAPABILITY(softplus_activation_mkldnn_fuse_pass)
+REGISTER_PASS_CAPABILITY(softplus_activation_onednn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("softplus", 1)
diff --git a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/softplus_activation_onednn_fuse_pass.h
similarity index 95%
rename from paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h
rename to paddle/fluid/framework/ir/mkldnn/softplus_activation_onednn_fuse_pass.h
index 6368a102b0e..817797e4579 100644
--- a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/softplus_activation_onednn_fuse_pass.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index d43770e0ddb..f2326b66561 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -375,7 +375,7 @@ void CpuPassStrategy::EnableMKLDNN() {
           "fc_act_mkldnn_fuse_pass",
           "fc_elementwise_add_mkldnn_fuse_pass",   //
           "batch_norm_act_fuse_pass",              //
-          "softplus_activation_mkldnn_fuse_pass",  //
+          "softplus_activation_onednn_fuse_pass",  //
           "shuffle_channel_mkldnn_detect_pass",    //
           "elt_act_mkldnn_fuse_pass",              //
           "layer_norm_onednn_optimization_pass",   //
@@ -467,7 +467,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
       passes_.push_back("fc_elementwise_add_mkldnn_fuse_pass");
       passes_.push_back("matmul_transpose_reshape_mkldnn_fuse_pass");
       passes_.push_back("batch_norm_act_fuse_pass");
-      passes_.push_back("softplus_activation_mkldnn_fuse_pass");
+      passes_.push_back("softplus_activation_onednn_fuse_pass");
       passes_.push_back("compute_propagate_scales_mkldnn_pass");
       passes_.push_back("scale_matmul_fuse_pass");
       passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
diff --git a/paddle/fluid/operators/compat/fused_softplus.pbtxt b/paddle/fluid/operators/compat/fused_softplus.pbtxt
new file mode 100644
index 00000000000..030530e9dce
--- /dev/null
+++ b/paddle/fluid/operators/compat/fused_softplus.pbtxt
@@ -0,0 +1,31 @@
+type: "fused_softplus"
+def {
+  inputs {
+    name: "X"
+  }
+  outputs {
+    name: "Out"
+  }
+  attrs {
+    name: "beta"
+    type: FLOAT
+  }
+  attrs {
+    name: "threshold"
+    type: FLOAT
+  }
+}
+extra {
+  attrs {
+    name: "fuse_activation"
+    type: STRING
+  }
+  attrs {
+    name: "fuse_alpha"
+    type: FLOAT
+  }
+  attrs {
+    name: "fuse_beta"
+    type: FLOAT
+  }
+}
diff --git a/paddle/fluid/operators/fused/fused_softplus_op.cc b/paddle/fluid/operators/fused/fused_softplus_op.cc
new file mode 100644
index 00000000000..2e0d8ca7d91
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_softplus_op.cc
@@ -0,0 +1,69 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+class FusedSoftplusOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    ctx->ShareDim("X", /*->*/ "Out");
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+
+ protected:
+  phi::KernelKey GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto data_type = this->IndicateVarDataType(ctx, "X");
+    return phi::KernelKey(data_type, ctx.GetPlace());
+  }
+};
+
+class FusedSoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "Input of softplus operator");
+    AddOutput("Out", "Output of softplus operator");
+    AddAttr<float>("beta", "Beta value for the softplus formulation")
+        .SetDefault(1.0f);
+    AddAttr<float>("threshold", "Values above this revert to a linear function")
+        .SetDefault(20.0f);
+    AddAttr<std::string>(
+        "fuse_activation",
+        "Activation type from softplus_activation_onednn_fuse_pass")
+        .SetDefault("");
+    AddAttr<float>("fuse_alpha",
+                   "Activation alpha from softplus_activation_onednn_fuse_pass")
+        .SetDefault(0.0f);
+    AddAttr<float>("fuse_beta",
+                   "Activation beta from softplus_activation_onednn_fuse_pass")
+        .SetDefault(0.0f);
+    AddComment(R"DOC(Softplus extended with oneDNN-specific fusion logic.)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(
+    fused_softplus,
+    ops::FusedSoftplusOp,
+    ops::FusedSoftplusOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
diff --git a/paddle/fluid/operators/ops_extra_info.h b/paddle/fluid/operators/ops_extra_info.h
index 0f7f6d8b21c..2182edf43bc 100644
--- a/paddle/fluid/operators/ops_extra_info.h
+++ b/paddle/fluid/operators/ops_extra_info.h
@@ -91,11 +91,6 @@ const std::unordered_map<std::string, ExtraAttrProperty>
         {"data_format", ExtraAttrProperty::ONEDNN},
         {"force_fp32_output", ExtraAttrProperty::ONEDNN},
         {"fuse_activation", ExtraAttrProperty::ONEDNN},
-        {"fuse_activation_type", ExtraAttrProperty::ONEDNN},
-        {"fuse_activation_alpha", ExtraAttrProperty::ONEDNN},
-        {"fuse_activation_beta", ExtraAttrProperty::ONEDNN},
-        {"fuse_activation_scale", ExtraAttrProperty::ONEDNN},
-        {"fused_output_scale", ExtraAttrProperty::ONEDNN},
         {"fuse_alpha", ExtraAttrProperty::ONEDNN},
         {"fuse_beta", ExtraAttrProperty::ONEDNN},
         {"fuse_relu", ExtraAttrProperty::ONEDNN},
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index b16371895f9..d5b834404b9 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -1622,8 +1622,7 @@
   outputs :
     out : Out
   extra :
-    attrs : [bool use_mkldnn = false, bool use_cudnn = false, str fuse_activation_type = "", float fuse_activation_alpha = 0.0f,
-             float fuse_activation_beta = 0.0f, float fuse_activation_scale = 1.0f]
+    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
 
 - op : softshrink
   backward : softshrink_grad
diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h
index 5193bc6d0b6..cbd503e2498 100644
--- a/paddle/phi/backends/onednn/onednn_reuse.h
+++ b/paddle/phi/backends/onednn/onednn_reuse.h
@@ -1623,6 +1623,47 @@ class PoolingOneDNNHandler
   }
 };
 
+template <typename T>
+class SoftplusOneDNNHandler
+    : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
+ public:
+  SoftplusOneDNNHandler(const OneDNNContext& dev_ctx,
+                        const phi::DenseTensor* x,
+                        const float beta,
+                        const std::string& fuse_activation = "",
= "", + const float fuse_alpha = 0.0f, + const float fuse_beta = 0.0f) + : OneDNNHandlerNoCachingT(dev_ctx.GetEngine(), + dev_ctx.GetPlace()) { + dnnl::post_ops post_ops; + post_ops.append_eltwise( + 1.0f, dnnl::algorithm::eltwise_soft_relu, 0.0f, 0.0f); + if (beta != 1.0f) { + post_ops.append_eltwise( + 1.0f, dnnl::algorithm::eltwise_linear, 1.0f / beta, 0.0f); + } + AppendActivation( + dev_ctx, post_ops, 1.0f, fuse_activation, fuse_alpha, fuse_beta); + dnnl::primitive_attr attrs; + attrs.set_post_ops(post_ops); + + auto x_tz = phi::vectorize(x->dims()); + auto beta_tz = std::vector(x_tz.size(), 1); + auto beta_md = dnnl::memory::desc( + beta_tz, OneDNNGetDataType(), GetPlainOneDNNFormat(x_tz.size())); + + this->AcquireForwardPrimitiveDescriptor(attrs, + dnnl::algorithm::binary_mul, + x->mem_desc(), + beta_md, + x->mem_desc()); + } + + std::shared_ptr AcquireBetaMemory(const float* beta) { + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src1_desc(), + to_void_cast(beta)); + } +}; + static void SetOutMemDescWithUnsqueeze2FuseSupport( const std::vector fused_unsqueeze2_axes, phi::DenseTensor* out, diff --git a/paddle/phi/kernels/fusion/onednn/fused_softplus_kernel.cc b/paddle/phi/kernels/fusion/onednn/fused_softplus_kernel.cc new file mode 100644 index 00000000000..cbc8b37f51f --- /dev/null +++ b/paddle/phi/kernels/fusion/onednn/fused_softplus_kernel.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/phi/kernels/activation_kernel.h"
+
+#include "paddle/phi/backends/onednn/onednn_reuse.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void FusedSoftplusKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         float beta,
+                         float threshold,
+                         const std::string& fuse_activation,
+                         const float fuse_alpha,
+                         const float fuse_beta,
+                         DenseTensor* out) {
+  funcs::SoftplusOneDNNHandler<T> handler(
+      dev_ctx, &x, beta, fuse_activation, fuse_alpha, fuse_beta);
+
+  auto src_memory_p = handler.AcquireSrcMemory(&x);
+  auto beta_memory_p = handler.AcquireBetaMemory(&beta);
+  std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
+  if (x.IsSharedBufferWith(*out)) {
+    dst_memory_p = src_memory_p;
+    dev_ctx.template Alloc<T>(out);
+  } else {
+    dst_memory_p = handler.AcquireDstMemory(out);
+  }
+  auto softplus_p = handler.AcquireForwardPrimitive();
+
+  auto& astream = OneDNNContext::tls().get_stream();
+
+  const std::unordered_map<int, dnnl::memory> args = {
+      {DNNL_ARG_SRC_0, *src_memory_p},
+      {DNNL_ARG_SRC_1, *beta_memory_p},
+      {DNNL_ARG_DST, *dst_memory_p}};
+
+  softplus_p->execute(astream, args);
+  astream.wait();
+
+  out->set_mem_desc(dst_memory_p->get_desc());
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_softplus,
+                   OneDNN,
+                   ONEDNN,
+                   phi::FusedSoftplusKernel,
+                   float,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/onednn/softplus_kernel.cc b/paddle/phi/kernels/onednn/softplus_kernel.cc
index b87938e3dc1..642beac2e7a 100644
--- a/paddle/phi/kernels/onednn/softplus_kernel.cc
+++ b/paddle/phi/kernels/onednn/softplus_kernel.cc
@@ -19,52 +19,13 @@
 
 namespace phi {
 
-template <typename T>
-class SoftplusOneDNNHandler
-    : public funcs::OneDNNHandlerNoCachingT<T, dnnl::binary> {
- public:
-  SoftplusOneDNNHandler(const OneDNNContext& dev_ctx,
-                        const phi::DenseTensor* x,
-                        const float beta)
-      : funcs::OneDNNHandlerNoCachingT<T, dnnl::binary>(dev_ctx.GetEngine(),
-                                                        dev_ctx.GetPlace()) {
-    dnnl::post_ops post_ops;
-    post_ops.append_eltwise(
-        1.0f, dnnl::algorithm::eltwise_soft_relu, 0.0f, 0.0f);
-    if (beta != 1.0f) {
-      post_ops.append_eltwise(
-          1.0f, dnnl::algorithm::eltwise_linear, 1.0f / beta, 0.0f);
-    }
-    funcs::AppendActivation(dev_ctx, post_ops);
-    dnnl::primitive_attr attrs;
-    attrs.set_post_ops(post_ops);
-
-    auto x_tz = phi::vectorize(x->dims());
-    auto beta_tz = std::vector<int64_t>(x_tz.size(), 1);
-    auto beta_md = dnnl::memory::desc(beta_tz,
-                                      funcs::OneDNNGetDataType<T>(),
-                                      funcs::GetPlainOneDNNFormat(x_tz.size()));
-
-    this->AcquireForwardPrimitiveDescriptor(attrs,
-                                            dnnl::algorithm::binary_mul,
-                                            x->mem_desc(),
-                                            beta_md,
-                                            x->mem_desc());
-  }
-
-  std::shared_ptr<dnnl::memory> AcquireBetaMemory(const float* beta) {
-    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src1_desc(),
-                                            funcs::to_void_cast<float>(beta));
-  }
-};
-
 template <typename T, typename Context>
 void SoftplusKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     float beta,
                     float threshold,
                     DenseTensor* out) {
-  SoftplusOneDNNHandler<T> handler(dev_ctx, &x, beta);
+  funcs::SoftplusOneDNNHandler<T> handler(dev_ctx, &x, beta);
 
   auto src_memory_p = handler.AcquireSrcMemory(&x);
   auto beta_memory_p = handler.AcquireBetaMemory(&beta);
@@ -75,7 +36,7 @@ void SoftplusKernel(const Context& dev_ctx,
   } else {
     dst_memory_p = handler.AcquireDstMemory(out);
   }
-  auto binary_p = handler.AcquireForwardPrimitive();
+  auto softplus_p = handler.AcquireForwardPrimitive();
 
   auto& astream = OneDNNContext::tls().get_stream();
 
@@ -84,7 +45,7 @@ void SoftplusKernel(const Context& dev_ctx,
       {DNNL_ARG_SRC_1, *beta_memory_p},
       {DNNL_ARG_DST, *dst_memory_p}};
 
-  binary_p->execute(astream, args);
+  softplus_p->execute(astream, args);
   astream.wait();
 
   out->set_mem_desc(dst_memory_p->get_desc());
diff --git a/paddle/phi/ops/compat/fused_softplus_sig.cc b/paddle/phi/ops/compat/fused_softplus_sig.cc
new file mode 100644
index 00000000000..33a2bddca3f
--- /dev/null
+++ b/paddle/phi/ops/compat/fused_softplus_sig.cc
@@ -0,0 +1,30 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature FusedSoftplusOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "fused_softplus",
+      {"X"},
+      {"beta", "threshold", "fuse_activation", "fuse_alpha", "fuse_beta"},
+      {"Out"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(fused_softplus, phi::FusedSoftplusOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_softplus_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_softplus_activation_fuse_pass.py
index c36149114b9..17efc80e22b 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_softplus_activation_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_softplus_activation_fuse_pass.py
@@ -116,13 +116,13 @@ class TestSoftplusActivationOneDNNFusePass(PassAutoScanTest):
 
     def sample_predictor_configs(self, program_config):
         config = self.create_inference_config(use_mkldnn=True)
-        yield config, ['softplus'], (1e-5, 1e-5)
+        yield config, ['fused_softplus'], (1e-5, 1e-5)
 
     def test(self):
         self.run_and_statis(
             quant=False,
             max_examples=40,
-            passes=['softplus_activation_mkldnn_fuse_pass'],
+            passes=['softplus_activation_onednn_fuse_pass'],
         )
diff --git a/python/paddle/static/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/static/quantization/quant2_int8_mkldnn_pass.py
index dd0855f0268..b77e75ea71f 100644
--- a/python/paddle/static/quantization/quant2_int8_mkldnn_pass.py
+++ b/python/paddle/static/quantization/quant2_int8_mkldnn_pass.py
@@ -486,7 +486,7 @@ class Quant2Int8MkldnnPass:
         )
         graph = self._apply_pass(graph, 'matmul_activation_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'batch_norm_act_fuse_pass')
-        graph = self._apply_pass(graph, 'softplus_activation_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'softplus_activation_onednn_fuse_pass')
         graph = self._apply_pass(graph, 'scale_matmul_fuse_pass')
         graph = self._apply_pass(
             graph, 'reshape_transpose_matmul_mkldnn_fuse_pass'
-- 
GitLab
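
Reviewer note on the math behind SoftplusOneDNNHandler: the kernel builds a oneDNN
binary_mul primitive that multiplies x by a one-element beta tensor, then chains an
eltwise_soft_relu post-op (ln(1 + e^y)) and, when beta != 1, an eltwise_linear post-op
that rescales by 1/beta, which together yield softplus(x, beta) = ln(1 + e^(beta*x)) / beta;
the activation recorded by the fuse pass is appended as one more post-op. The NumPy
sketch below is illustrative only (not part of the patch) and mirrors that post-op chain:

    import numpy as np

    def softplus_reference(x, beta=1.0):
        # Textbook definition: softplus(x, beta) = ln(1 + e^(beta * x)) / beta
        return np.log1p(np.exp(beta * x)) / beta

    def softplus_via_post_ops(x, beta=1.0, fuse_activation=None):
        # binary_mul: x * beta (beta broadcast from a one-element tensor)
        y = x * np.float32(beta)
        # eltwise_soft_relu post-op: ln(1 + e^y)
        y = np.log1p(np.exp(y))
        # eltwise_linear post-op with scale 1/beta, skipped when beta == 1
        if beta != 1.0:
            y = y * (1.0 / beta)
        # optional fused activation appended by AppendActivation, e.g. relu
        if fuse_activation == "relu":
            y = np.maximum(y, 0.0)
        return y

    x = np.random.randn(2, 3).astype(np.float32)
    assert np.allclose(softplus_reference(x, beta=2.0),
                       softplus_via_post_ops(x, beta=2.0), atol=1e-6)

Note also that threshold stays in the fused_softplus signature for compatibility with
the softplus attribute set, but it does not appear in the handler's post-op chain above.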