未验证 提交 fdcfa04f 编写于 作者: S Sławomir Siwek 提交者: GitHub

Fused softplus (#51087)

* mkldnn->onednn

* fused softplus op + kernel

* remove extra attributes

* add missing handler

* change var name
上级 b780a3ff
...@@ -172,7 +172,7 @@ if(WITH_MKLDNN) ...@@ -172,7 +172,7 @@ if(WITH_MKLDNN)
pass_library(cpu_bfloat16_pass inference DIR mkldnn) pass_library(cpu_bfloat16_pass inference DIR mkldnn)
pass_library(fc_mkldnn_pass inference DIR mkldnn) pass_library(fc_mkldnn_pass inference DIR mkldnn)
pass_library(interpolate_mkldnn_pass inference DIR mkldnn) pass_library(interpolate_mkldnn_pass inference DIR mkldnn)
pass_library(softplus_activation_mkldnn_fuse_pass inference DIR mkldnn) pass_library(softplus_activation_onednn_fuse_pass inference DIR mkldnn)
pass_library(shuffle_channel_mkldnn_detect_pass inference DIR mkldnn) pass_library(shuffle_channel_mkldnn_detect_pass inference DIR mkldnn)
pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn) pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn) pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn)
......
...@@ -160,7 +160,8 @@ inline void ConvertToFusedOp(OpDesc* op) { ...@@ -160,7 +160,8 @@ inline void ConvertToFusedOp(OpDesc* op) {
{"conv2d", "fused_conv2d"}, {"conv2d", "fused_conv2d"},
{"depthwise_conv2d", "fused_conv2d"}, {"depthwise_conv2d", "fused_conv2d"},
{"matmul", "fused_matmul"}, {"matmul", "fused_matmul"},
{"matmul_v2", "fused_matmul"}}; {"matmul_v2", "fused_matmul"},
{"softplus", "fused_softplus"}};
if (op->Type() == "matmul") { if (op->Type() == "matmul") {
op->SetAttr("trans_x", op->GetAttr("transpose_X")); op->SetAttr("trans_x", op->GetAttr("transpose_X"));
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,10 +12,11 @@ ...@@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/softplus_activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/utils/string/pretty_log.h" #include "paddle/utils/string/pretty_log.h"
...@@ -58,6 +59,7 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation( ...@@ -58,6 +59,7 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
activation, activation, softplus_activation_pattern); activation, activation, softplus_activation_pattern);
ConvertToFusedOp(softplus->Op());
SetActivationAttrs(softplus->Op(), activation->Op(), act_type); SetActivationAttrs(softplus->Op(), activation->Op(), act_type);
softplus->Op()->SetOutput("Out", {activation_out->Name()}); softplus->Op()->SetOutput("Out", {activation_out->Name()});
...@@ -78,9 +80,9 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation( ...@@ -78,9 +80,9 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(softplus_activation_mkldnn_fuse_pass, REGISTER_PASS(softplus_activation_onednn_fuse_pass,
paddle::framework::ir::SoftplusActivationOneDNNPass); paddle::framework::ir::SoftplusActivationOneDNNPass);
REGISTER_PASS_CAPABILITY(softplus_activation_mkldnn_fuse_pass) REGISTER_PASS_CAPABILITY(softplus_activation_onednn_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.LE("softplus", 1) .LE("softplus", 1)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
...@@ -375,7 +375,7 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -375,7 +375,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"fc_act_mkldnn_fuse_pass", "fc_act_mkldnn_fuse_pass",
"fc_elementwise_add_mkldnn_fuse_pass", // "fc_elementwise_add_mkldnn_fuse_pass", //
"batch_norm_act_fuse_pass", // "batch_norm_act_fuse_pass", //
"softplus_activation_mkldnn_fuse_pass", // "softplus_activation_onednn_fuse_pass", //
"shuffle_channel_mkldnn_detect_pass", // "shuffle_channel_mkldnn_detect_pass", //
"elt_act_mkldnn_fuse_pass", // "elt_act_mkldnn_fuse_pass", //
"layer_norm_onednn_optimization_pass", // "layer_norm_onednn_optimization_pass", //
...@@ -467,7 +467,7 @@ void CpuPassStrategy::EnableMkldnnInt8() { ...@@ -467,7 +467,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("fc_elementwise_add_mkldnn_fuse_pass"); passes_.push_back("fc_elementwise_add_mkldnn_fuse_pass");
passes_.push_back("matmul_transpose_reshape_mkldnn_fuse_pass"); passes_.push_back("matmul_transpose_reshape_mkldnn_fuse_pass");
passes_.push_back("batch_norm_act_fuse_pass"); passes_.push_back("batch_norm_act_fuse_pass");
passes_.push_back("softplus_activation_mkldnn_fuse_pass"); passes_.push_back("softplus_activation_onednn_fuse_pass");
passes_.push_back("compute_propagate_scales_mkldnn_pass"); passes_.push_back("compute_propagate_scales_mkldnn_pass");
passes_.push_back("scale_matmul_fuse_pass"); passes_.push_back("scale_matmul_fuse_pass");
passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass"); passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
......
# Op compatibility description for the "fused_softplus" operator.
type: "fused_softplus"
# "def" lists the op's input X, output Out, and the attributes beta/threshold.
def {
inputs {
name: "X"
}
outputs {
name: "Out"
}
attrs {
name: "beta"
type: FLOAT
}
attrs {
name: "threshold"
type: FLOAT
}
}
# "extra" lists the fuse_* attributes (activation name plus two float
# parameters) separately from the attributes in "def".
extra {
attrs {
name: "fuse_activation"
type: STRING
}
attrs {
name: "fuse_alpha"
type: FLOAT
}
attrs {
name: "fuse_beta"
type: FLOAT
}
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
// Operator class for fused_softplus: reads input "X" and produces output
// "Out".
class FusedSoftplusOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // "Out" receives both its dims and its LoD from "X".
  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  // Builds the kernel key from the data type indicated by variable "X" and
  // the place reported by the execution context.
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return phi::KernelKey(this->IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};
// Declares the proto of fused_softplus: input "X", output "Out", the two
// softplus attributes (beta, threshold) and three fuse_* attributes that
// describe an activation fused after softplus.
class FusedSoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of softplus operator");
AddOutput("Out", "Output of softplus operator");
// Softplus attributes; defaults are beta = 1.0 and threshold = 20.0.
AddAttr<float>("beta", "Beta value for the softplus formulation")
.SetDefault(1.0f);
AddAttr<float>("threshold", "Values above this revert to a linear function")
.SetDefault(20.0f);
// Fusion attributes; per their descriptions they are filled in by
// softplus_activation_onednn_fuse_pass. The defaults are an empty
// activation name and zeros for alpha/beta.
AddAttr<std::string>(
"fuse_activation",
"Activation type from softplus_activation_onednn_fuse_pass")
.SetDefault("");
AddAttr<float>("fuse_alpha",
"Activation alpha from softplus_activation_onednn_fuse_pass")
.SetDefault(0.0f);
AddAttr<float>("fuse_beta",
"Activation beta from softplus_activation_onednn_fuse_pass")
.SetDefault(0.0f);
AddComment(R"DOC(Softplus extended with oneDNN-specific fusion logic.)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Registers the fused_softplus operator with its op class and proto maker.
// Both grad-op-maker slots are EmptyGradOpMaker (one for OpDesc, one for the
// imperative OpBase) -- presumably meaning no gradient op is generated for
// this forward-only fused op; confirm against EmptyGradOpMaker.
REGISTER_OPERATOR(
fused_softplus,
ops::FusedSoftplusOp,
ops::FusedSoftplusOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
...@@ -91,11 +91,6 @@ const std::unordered_map<std::string, ExtraAttrPropertySet> ...@@ -91,11 +91,6 @@ const std::unordered_map<std::string, ExtraAttrPropertySet>
{"data_format", ExtraAttrProperty::ONEDNN}, {"data_format", ExtraAttrProperty::ONEDNN},
{"force_fp32_output", ExtraAttrProperty::ONEDNN}, {"force_fp32_output", ExtraAttrProperty::ONEDNN},
{"fuse_activation", ExtraAttrProperty::ONEDNN}, {"fuse_activation", ExtraAttrProperty::ONEDNN},
{"fuse_activation_type", ExtraAttrProperty::ONEDNN},
{"fuse_activation_alpha", ExtraAttrProperty::ONEDNN},
{"fuse_activation_beta", ExtraAttrProperty::ONEDNN},
{"fuse_activation_scale", ExtraAttrProperty::ONEDNN},
{"fused_output_scale", ExtraAttrProperty::ONEDNN},
{"fuse_alpha", ExtraAttrProperty::ONEDNN}, {"fuse_alpha", ExtraAttrProperty::ONEDNN},
{"fuse_beta", ExtraAttrProperty::ONEDNN}, {"fuse_beta", ExtraAttrProperty::ONEDNN},
{"fuse_relu", ExtraAttrProperty::ONEDNN}, {"fuse_relu", ExtraAttrProperty::ONEDNN},
......
...@@ -1622,8 +1622,7 @@ ...@@ -1622,8 +1622,7 @@
outputs : outputs :
out : Out out : Out
extra : extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false, str fuse_activation_type = "", float fuse_activation_alpha = 0.0f, attrs : [bool use_mkldnn = false, bool use_cudnn = false]
float fuse_activation_beta = 0.0f, float fuse_activation_scale = 1.0f]
- op : softshrink - op : softshrink
backward : softshrink_grad backward : softshrink_grad
......
...@@ -1623,6 +1623,47 @@ class PoolingOneDNNHandler ...@@ -1623,6 +1623,47 @@ class PoolingOneDNNHandler
} }
}; };
// oneDNN handler that expresses softplus as a dnnl::binary primitive:
// a binary_mul of `x` with a one-element "beta" operand, followed by a chain
// of post-ops (eltwise_soft_relu, an optional eltwise_linear scaling by
// 1 / beta, and whatever AppendActivation adds for the fused activation).
//
// Constructor parameters:
//   dev_ctx          - supplies the oneDNN engine and the place.
//   x                - input tensor; its dims and memory descriptor are used
//                      for both the first source and the destination.
//   beta             - softplus beta; only used here to decide whether the
//                      1 / beta linear post-op is needed.
//   fuse_activation,
//   fuse_alpha,
//   fuse_beta        - forwarded unchanged to AppendActivation.
template <typename T>
class SoftplusOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
 public:
  SoftplusOneDNNHandler(const OneDNNContext& dev_ctx,
                        const phi::DenseTensor* x,
                        const float beta,
                        const std::string& fuse_activation = "",
                        const float fuse_alpha = 0.0f,
                        const float fuse_beta = 0.0f)
      : OneDNNHandlerNoCachingT<T, dnnl::binary>(dev_ctx.GetEngine(),
                                                 dev_ctx.GetPlace()) {
    // Post-op chain. The order of appends is the order of execution, so it
    // must stay: soft_relu -> (linear 1/beta) -> fused activation.
    dnnl::post_ops fused_ops;
    fused_ops.append_eltwise(
        1.0f, dnnl::algorithm::eltwise_soft_relu, 0.0f, 0.0f);
    const bool needs_rescale = (beta != 1.0f);
    if (needs_rescale) {
      fused_ops.append_eltwise(
          1.0f, dnnl::algorithm::eltwise_linear, 1.0f / beta, 0.0f);
    }
    AppendActivation(
        dev_ctx, fused_ops, 1.0f, fuse_activation, fuse_alpha, fuse_beta);

    dnnl::primitive_attr primitive_attrs;
    primitive_attrs.set_post_ops(fused_ops);

    // The beta operand has the same rank as x with every dimension equal
    // to 1, in the plain format for that rank.
    // NOTE(review): beta_md is declared with element type T, while
    // AcquireBetaMemory wraps a `const float*`. Confirm the two agree when
    // T is not float.
    auto src_dims = phi::vectorize(x->dims());
    std::vector<int64_t> beta_dims(src_dims.size(), 1);
    auto beta_md = dnnl::memory::desc(beta_dims,
                                      OneDNNGetDataType<T>(),
                                      GetPlainOneDNNFormat(src_dims.size()));

    // src0 and dst both use x's memory descriptor; src1 is the beta operand.
    this->AcquireForwardPrimitiveDescriptor(primitive_attrs,
                                            dnnl::algorithm::binary_mul,
                                            x->mem_desc(),
                                            beta_md,
                                            x->mem_desc());
  }

  // Wraps the caller-owned `beta` value in a dnnl::memory described by the
  // primitive's src1 descriptor. The pointer is passed through
  // to_void_cast, so no copy of the value is made in this function.
  std::shared_ptr<dnnl::memory> AcquireBetaMemory(const float* beta) {
    const auto src1_md = this->fwd_pd_->src1_desc();
    return this->AcquireMemoryFromPrimitive(src1_md,
                                            to_void_cast<float>(beta));
  }
};
static void SetOutMemDescWithUnsqueeze2FuseSupport( static void SetOutMemDescWithUnsqueeze2FuseSupport(
const std::vector<int> fused_unsqueeze2_axes, const std::vector<int> fused_unsqueeze2_axes,
phi::DenseTensor* out, phi::DenseTensor* out,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
// oneDNN kernel for fused_softplus.
//
// Builds a SoftplusOneDNNHandler from `x`, `beta` and the fuse_* arguments,
// runs its forward primitive with `x` as source 0 and `beta` as source 1,
// and stores the destination memory descriptor on `out`.
//
// `threshold` is part of the kernel signature but is not read in this body.
template <typename T, typename Context>
void FusedSoftplusKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         float beta,
                         float threshold,
                         const std::string& fuse_activation,
                         const float fuse_alpha,
                         const float fuse_beta,
                         DenseTensor* out) {
  funcs::SoftplusOneDNNHandler<T> handler(
      dev_ctx, &x, beta, fuse_activation, fuse_alpha, fuse_beta);

  auto src_mem = handler.AcquireSrcMemory(&x);
  auto beta_mem = handler.AcquireBetaMemory(&beta);

  std::shared_ptr<dnnl::memory> dst_mem = nullptr;
  if (!x.IsSharedBufferWith(*out)) {
    dst_mem = handler.AcquireDstMemory(out);
  } else {
    // x and out share a buffer: reuse the source memory as the destination
    // and allocate `out` through the device context.
    dst_mem = src_mem;
    dev_ctx.template Alloc<T>(out);
  }

  auto forward_prim = handler.AcquireForwardPrimitive();
  auto& stream = OneDNNContext::tls().get_stream();

  const std::unordered_map<int, dnnl::memory> exec_args = {
      {DNNL_ARG_SRC_0, *src_mem},
      {DNNL_ARG_SRC_1, *beta_mem},
      {DNNL_ARG_DST, *dst_mem}};
  forward_prim->execute(stream, exec_args);
  stream.wait();

  out->set_mem_desc(dst_mem->get_desc());
}
} // namespace phi
// Registers phi::FusedSoftplusKernel under the name fused_softplus for the
// OneDNN backend, instantiated for float and phi::dtype::bfloat16.
PD_REGISTER_KERNEL(fused_softplus,
OneDNN,
ONEDNN,
phi::FusedSoftplusKernel,
float,
phi::dtype::bfloat16) {}
...@@ -19,52 +19,13 @@ ...@@ -19,52 +19,13 @@
namespace phi { namespace phi {
// Handler for the softplus oneDNN kernel, local to this file (shown as
// removed lines in this diff). It sets up a dnnl::binary binary_mul
// primitive over `x` and a one-element beta operand, with eltwise_soft_relu
// and, when beta != 1, an eltwise_linear (scale 1 / beta) post-op, followed
// by whatever funcs::AppendActivation(dev_ctx, post_ops) appends.
template <typename T>
class SoftplusOneDNNHandler
: public funcs::OneDNNHandlerNoCachingT<T, dnnl::binary> {
public:
SoftplusOneDNNHandler(const OneDNNContext& dev_ctx,
const phi::DenseTensor* x,
const float beta)
: funcs::OneDNNHandlerNoCachingT<T, dnnl::binary>(dev_ctx.GetEngine(),
dev_ctx.GetPlace()) {
// Post-ops are applied in the order they are appended.
dnnl::post_ops post_ops;
post_ops.append_eltwise(
1.0f, dnnl::algorithm::eltwise_soft_relu, 0.0f, 0.0f);
if (beta != 1.0f) {
post_ops.append_eltwise(
1.0f, dnnl::algorithm::eltwise_linear, 1.0f / beta, 0.0f);
}
funcs::AppendActivation(dev_ctx, post_ops);
dnnl::primitive_attr attrs;
attrs.set_post_ops(post_ops);
// The beta operand has the rank of x with every dimension set to 1.
auto x_tz = phi::vectorize(x->dims());
auto beta_tz = std::vector<int64_t>(x_tz.size(), 1);
auto beta_md = dnnl::memory::desc(beta_tz,
funcs::OneDNNGetDataType<T>(),
funcs::GetPlainOneDNNFormat(x_tz.size()));
// src0 and dst share x's memory descriptor; src1 is the beta operand.
this->AcquireForwardPrimitiveDescriptor(attrs,
dnnl::algorithm::binary_mul,
x->mem_desc(),
beta_md,
x->mem_desc());
}
// Wraps the caller-provided beta pointer in a dnnl::memory that uses the
// primitive's src1 descriptor.
std::shared_ptr<dnnl::memory> AcquireBetaMemory(const float* beta) {
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src1_desc(),
funcs::to_void_cast<float>(beta));
}
};
template <typename T, typename Context> template <typename T, typename Context>
void SoftplusKernel(const Context& dev_ctx, void SoftplusKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
float beta, float beta,
float threshold, float threshold,
DenseTensor* out) { DenseTensor* out) {
SoftplusOneDNNHandler<T> handler(dev_ctx, &x, beta); funcs::SoftplusOneDNNHandler<T> handler(dev_ctx, &x, beta);
auto src_memory_p = handler.AcquireSrcMemory(&x); auto src_memory_p = handler.AcquireSrcMemory(&x);
auto beta_memory_p = handler.AcquireBetaMemory(&beta); auto beta_memory_p = handler.AcquireBetaMemory(&beta);
...@@ -75,7 +36,7 @@ void SoftplusKernel(const Context& dev_ctx, ...@@ -75,7 +36,7 @@ void SoftplusKernel(const Context& dev_ctx,
} else { } else {
dst_memory_p = handler.AcquireDstMemory(out); dst_memory_p = handler.AcquireDstMemory(out);
} }
auto binary_p = handler.AcquireForwardPrimitive(); auto softplus_p = handler.AcquireForwardPrimitive();
auto& astream = OneDNNContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
...@@ -84,7 +45,7 @@ void SoftplusKernel(const Context& dev_ctx, ...@@ -84,7 +45,7 @@ void SoftplusKernel(const Context& dev_ctx,
{DNNL_ARG_SRC_1, *beta_memory_p}, {DNNL_ARG_SRC_1, *beta_memory_p},
{DNNL_ARG_DST, *dst_memory_p}}; {DNNL_ARG_DST, *dst_memory_p}};
binary_p->execute(astream, args); softplus_p->execute(astream, args);
astream.wait(); astream.wait();
out->set_mem_desc(dst_memory_p->get_desc()); out->set_mem_desc(dst_memory_p->get_desc());
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
// Maps the fused_softplus operator onto the "fused_softplus" kernel
// signature: input X; attributes beta, threshold, fuse_activation,
// fuse_alpha, fuse_beta (in that order); output Out.
// The mapping is fixed, so `ctx` is not consulted.
KernelSignature FusedSoftplusOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  KernelSignature signature(
      "fused_softplus",
      {"X"},
      {"beta", "threshold", "fuse_activation", "fuse_alpha", "fuse_beta"},
      {"Out"});
  return signature;
}
} // namespace phi
// Registers FusedSoftplusOpArgumentMapping as the argument-mapping function
// for the fused_softplus op.
PD_REGISTER_ARG_MAPPING_FN(fused_softplus, phi::FusedSoftplusOpArgumentMapping);
...@@ -116,13 +116,13 @@ class TestSoftplusActivationOneDNNFusePass(PassAutoScanTest): ...@@ -116,13 +116,13 @@ class TestSoftplusActivationOneDNNFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True) config = self.create_inference_config(use_mkldnn=True)
yield config, ['softplus'], (1e-5, 1e-5) yield config, ['fused_softplus'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
quant=False, quant=False,
max_examples=40, max_examples=40,
passes=['softplus_activation_mkldnn_fuse_pass'], passes=['softplus_activation_onednn_fuse_pass'],
) )
......
...@@ -486,7 +486,7 @@ class Quant2Int8MkldnnPass: ...@@ -486,7 +486,7 @@ class Quant2Int8MkldnnPass:
) )
graph = self._apply_pass(graph, 'matmul_activation_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'matmul_activation_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'batch_norm_act_fuse_pass') graph = self._apply_pass(graph, 'batch_norm_act_fuse_pass')
graph = self._apply_pass(graph, 'softplus_activation_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'softplus_activation_onednn_fuse_pass')
graph = self._apply_pass(graph, 'scale_matmul_fuse_pass') graph = self._apply_pass(graph, 'scale_matmul_fuse_pass')
graph = self._apply_pass( graph = self._apply_pass(
graph, 'reshape_transpose_matmul_mkldnn_fuse_pass' graph, 'reshape_transpose_matmul_mkldnn_fuse_pass'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册