Unverified · Commit 2dfa88d2 authored by YuanRisheng, committed by GitHub

[MKLDNN] Move mkldnn activation kernel to phi (#44365)

* move mkldnn activation kernel

* fix compile bugs

* fix compile bugs

* deal with conflict

* fix compile bugs

* fix windows compile bugs

* mkldnn unittest fix

* change mutable to alloc

* fix unittest bugs

* modify code according to review comments
Parent f419e341
......@@ -510,7 +510,7 @@ function(op_library TARGET)
if(WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
# Append first implemented MKLDNN activation operator
if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(gelu, MKLDNN);\n")
elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
......
......@@ -20,17 +20,18 @@
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/kernel_registry.h"
USE_OP_ITSELF(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP_ITSELF(leaky_relu);
USE_OP_DEVICE_KERNEL(leaky_relu, MKLDNN);
PD_DECLARE_KERNEL(leaky_relu, OneDNN, ALL_LAYOUT);
USE_OP_ITSELF(gelu);
USE_OP_ITSELF(relu);
USE_OP_ITSELF(tanh);
USE_OP_DEVICE_KERNEL(tanh, MKLDNN);
PD_DECLARE_KERNEL(tanh, OneDNN, ALL_LAYOUT);
PD_DECLARE_ARG_MAPPING_FN(gelu);
namespace paddle {
......
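The pattern above generalizes to every test touched by this PR: an op whose oneDNN kernel still lives in fluid keeps the old USE_OP_DEVICE_KERNEL declaration, while a migrated kernel is declared through the phi registry instead. A minimal sketch for a hypothetical operator my_act whose oneDNN kernel has been moved to phi (my_act is illustrative, not an op in this diff):

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/kernel_registry.h"

USE_OP_ITSELF(my_act);                          // operator definition stays in fluid
PD_DECLARE_KERNEL(my_act, OneDNN, ALL_LAYOUT);  // oneDNN kernel now registered in phi
// Before this PR the second line would have been:
// USE_OP_DEVICE_KERNEL(my_act, MKLDNN);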
......@@ -2200,7 +2200,9 @@ Scope* OperatorWithKernel::PrepareData(
(in_def->backend != phi::Backend::GPUDNN ||
tensor_backend != phi::Backend::GPU) &&
(in_def->backend != phi::Backend::KPS ||
tensor_backend != phi::Backend::XPU)) ||
tensor_backend != phi::Backend::XPU) &&
(in_def->backend != phi::Backend::ONEDNN ||
tensor_backend != phi::Backend::CPU)) ||
tensor_in->place().GetType() == AllocationType::GPUPINNED) {
new_expected_kernel_key = std::make_unique<OpKernelType>(
expected_kernel_key.data_type_,
......
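The new clause teaches PrepareData that a CPU-placed tensor is already acceptable input for a kernel registered on the ONEDNN backend, so no transform is forced and the expected kernel key is left unchanged. A hedged restatement of the full compatibility rule as a standalone predicate (hypothetical helper; in the real code the check is inlined above):

bool BackendCompatible(phi::Backend kernel_backend, phi::Backend tensor_backend) {
  // Identical backends are trivially compatible; the remaining pairs are the
  // special cases the inlined condition carves out.
  return kernel_backend == tensor_backend ||
         (kernel_backend == phi::Backend::GPUDNN &&
          tensor_backend == phi::Backend::GPU) ||
         (kernel_backend == phi::Backend::KPS &&
          tensor_backend == phi::Backend::XPU) ||
         (kernel_backend == phi::Backend::ONEDNN &&
          tensor_backend == phi::Backend::CPU);
}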
......@@ -259,5 +259,5 @@ TEST(test_prepare_op, test_prepare_data_cpu_mkldnn) {
USE_OP_ITSELF(split);
USE_OP_ITSELF(relu);
#ifdef PADDLE_WITH_MKLDNN
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
PD_DECLARE_KERNEL(relu, OneDNN, ALL_LAYOUT);
#endif
......@@ -196,100 +196,21 @@ struct SoftplusMKLDNNFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
using ReluMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_relu>;
template <typename T>
using Relu6MKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_bounded_relu>;
template <typename T>
using SwishMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_swish>;
template <typename T>
using HardSwishMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_hardswish>;
template <typename T>
using MishMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_mish>;
template <typename T>
using SigmoidMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_logistic>;
template <typename T>
using TanhMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_tanh>;
template <typename T>
using SqrtMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_sqrt>;
template <typename T>
using AbsMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_abs>;
template <typename T>
using EluMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;
template <typename T>
using ExpMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;
template <typename T>
using RoundMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_round>;
template <typename T>
using ReluMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
template <typename T>
using Relu6MKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_bounded_relu>;
template <typename T>
using SwishMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_swish>;
template <typename T>
using HardSwishMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_hardswish>;
template <typename T>
using MishMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_mish>;
template <typename T>
using SigmoidMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
template <typename T>
using TanhMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;
template <typename T>
using SqrtMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>;
template <typename T>
using AbsMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_abs>;
template <typename T>
using EluMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;
template <typename T>
using ExpMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;
} // namespace operators
} // namespace paddle
......@@ -318,24 +239,11 @@ namespace ops = paddle::operators;
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
__macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor); \
__macro(elu, EluMKLDNNFunctor, EluMKLDNNGradUseOutFunctor); \
__macro(exp, ExpMKLDNNFunctor, ExpMKLDNNGradUseOutFunctor); \
__macro(gelu, GeluMKLDNNFunctor, GeluMKLDNNGradFunctor); \
__macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
__macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(mish, MishMKLDNNFunctor, MishMKLDNNGradFunctor); \
__macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor); \
__macro(sigmoid, SigmoidMKLDNNFunctor, SigmoidMKLDNNGradUseOutFunctor); \
__macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradUseOutFunctor); \
__macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor); \
__macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradUseOutFunctor);
__macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor);
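// After this PR most activations lose their fluid MKLDNN registration in
// this file; they are instead registered as phi OneDNN kernels (see
// paddle/phi/kernels/onednn/ later in this diff).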
FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
// round eltwise primitive doesn't support BF16, nor does it support grad
REGISTER_ACTIVATION_MKLDNN_KERNEL_FWD_ONLY(round, RoundMKLDNNFunctor);
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(
softplus,
......
......@@ -25,13 +25,14 @@
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/kernel_registry.h"
USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP_ITSELF(elementwise_mul);
USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN);
USE_OP_ITSELF(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
PD_DECLARE_KERNEL(relu, OneDNN, ALL_LAYOUT);
USE_OP_ITSELF(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
USE_OP_ITSELF(conv2d);
......
......@@ -30,7 +30,7 @@
USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP_ITSELF(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
PD_DECLARE_KERNEL(relu, OneDNN, ALL_LAYOUT);
USE_OP_ITSELF(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
......
......@@ -30,7 +30,7 @@
USE_OP_ITSELF(pool2d);
USE_OP_DEVICE_KERNEL(pool2d, MKLDNN);
USE_OP_ITSELF(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
PD_DECLARE_KERNEL(relu, OneDNN, ALL_LAYOUT);
USE_OP_ITSELF(transpose);
USE_OP_DEVICE_KERNEL(transpose, MKLDNN);
USE_OP_ITSELF(shape);
......
......@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/operators/pool_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/kernels/funcs/onednn/mkldnn_reuse.h"
namespace paddle {
namespace platform {
......@@ -38,216 +39,8 @@ template <typename T,
typename TForward,
typename TBackward = mkldnn_dummy_primitive,
typename TBackward_params = mkldnn_dummy_primitive>
class MKLDNNHandlerNoCachingT {
public:
MKLDNNHandlerNoCachingT(dnnl::engine engine, platform::Place cpu_place)
: engine_(engine), place_(cpu_place), fwd_pd_(nullptr), bwd_pd_(nullptr) {
platform::MKLDNNDeviceContext::tls().log_lib_version();
}
std::shared_ptr<TForward> AcquireForwardPrimitive() {
return std::make_shared<TForward>(*fwd_pd_);
}
std::shared_ptr<TBackward> AcquireBackwardPrimitive() {
return std::make_shared<TBackward>(*bwd_pd_);
}
std::shared_ptr<TBackward_params> AcquireBackwardWeightsPrimitive() {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable("BWD_W_PD should be set when "
"getting BWD prim."));
return std::make_shared<TBackward_params>(*bwd_w_pd_);
}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(fwd_pd_->src_desc(),
to_void_cast<T>(input_data));
}
template <typename T_out = T>
std::shared_ptr<dnnl::memory> AcquireDstMemory(framework::Tensor* output) {
T_out* ptr =
output->mutable_data<T_out>(place_, fwd_pd_->dst_desc().get_size());
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc(), ptr);
}
template <typename T_out = T>
std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc());
}
template <typename T_out = T>
std::shared_ptr<dnnl::memory> AcquireDstMemory(
const framework::Tensor* output) {
const T_out* output_data = output->data<T_out>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_desc(),
to_void_cast<T_out>(output_data));
}
std::shared_ptr<dnnl::memory> AcquireDiffDstMemory(
const framework::Tensor* diffdst) {
const T* ptr = diffdst->data<T>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_desc(),
to_void_cast<T>(ptr));
}
std::shared_ptr<dnnl::memory> AcquireDiffSrcMemory(
framework::Tensor* diffsrc) {
T* ptr =
diffsrc->mutable_data<T>(place_, bwd_pd_->diff_src_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_desc(), ptr);
}
// Buffer of given Tensor is used for oneDNN computation
std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(
framework::Tensor* diff_weights) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable(
"BWD_W_PD should be set when getting BWD grad of weights."));
T* ptr = diff_weights->mutable_data<T>(
place_, bwd_w_pd_->diff_weights_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(),
ptr);
}
// Buffer is allocated by oneDNN to store computation results
std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(void) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable(
"BWD_W_PD should be set when getting BWD grad of weights."));
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc());
}
protected:
// If your primitive descriptor requires attributes, pass them as a
// first argument and parameters to descriptor constructor in the following
// arguments. Otherwise, all arguments will be forwarded to descriptor
// constructor, including the first one.
template <typename Arg, typename... Args>
void AcquireForwardPrimitiveDescriptor(Arg&& first_arg, Args&&... args) {
CreateForwardPrimitiveDescriptor(first_arg, std::forward<Args>(args)...);
}
// Using SFINAE to specialise the variadic function; a workaround for not
// having if constexpr in C++11.
template <class First, class... Args>
typename std::enable_if<std::is_same<typename std::decay<First>::type,
dnnl::primitive_attr>::value>::type
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
auto fwd_desc = typename TForward::desc(std::forward<Args>(args)...);
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
fwd_desc, first, engine_);
}
template <class First, class... Args>
typename std::enable_if<!std::is_same<typename std::decay<First>::type,
dnnl::primitive_attr>::value>::type
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
auto fwd_desc = typename TForward::desc(std::forward<First>(first),
std::forward<Args>(args)...);
fwd_pd_ =
std::make_shared<typename TForward::primitive_desc>(fwd_desc, engine_);
}
template <typename... Args>
void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL(fwd_pd_,
platform::errors::Unavailable(
"Get MKLDNN Forward primitive %s failed."));
auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...);
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
}
template <typename... Args>
void AcquireBackwardWeightsPrimitiveDescriptor(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL(fwd_pd_,
platform::errors::Unavailable(
"Get MKLDNN Forward primitive %s failed."));
auto bwd_desc =
typename TBackward_params::desc(std::forward<Args>(args)...);
bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
}
std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
dnnl::memory::desc md, void* ptr) {
return std::make_shared<dnnl::memory>(md, engine_, ptr);
}
std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
dnnl::memory::desc md) {
return std::make_shared<dnnl::memory>(md, engine_);
}
void AcquireReorder(const std::shared_ptr<dnnl::memory>& user_memory_p,
const std::shared_ptr<dnnl::memory>& target_memory_p) {
auto reorder_p =
std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
platform::RecordEvent record_reorder("int_reorder",
platform::TracerEventType::UserDefined,
2,
platform::EventRole::kUniqueOp);
reorder_p->execute(
astream,
{{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
astream.wait();
}
template <typename F = T>
std::shared_ptr<dnnl::memory> AcquireMemoryWithReorder(
const dnnl::memory::desc& user_md,
const dnnl::memory::desc& target_md,
void* ptr,
bool is_persistent = false,
std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {}) {
std::shared_ptr<dnnl::memory> target_memory_p;
if (custom_reorder_func) {
auto reordered_data =
custom_reorder_func(reinterpret_cast<const F*>(ptr));
ptr = reinterpret_cast<void*>(reordered_data.get());
}
auto user_memory_p = std::make_shared<dnnl::memory>(user_md, engine_, ptr);
if (user_md != target_md) {
target_memory_p = std::make_shared<dnnl::memory>(target_md, engine_);
auto reorder_p =
std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
platform::RecordEvent record_reorder(
"int_reorder",
platform::TracerEventType::UserDefined,
2,
platform::EventRole::kUniqueOp);
reorder_p->execute(
astream,
{{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
astream.wait();
} else {
target_memory_p = user_memory_p;
}
return target_memory_p;
}
dnnl::engine engine_;
platform::Place place_;
std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
std::shared_ptr<typename TBackward_params::primitive_desc> bwd_w_pd_;
};
using MKLDNNHandlerNoCachingT = phi::funcs::
MKLDNNHandlerNoCachingT<T, TForward, TBackward, TBackward_params>;
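// NOTE: the fluid handler is now only an alias; the implementation has moved
// to paddle/phi/kernels/funcs/onednn/mkldnn_reuse.h.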
template <typename T,
typename TForward,
......
......@@ -89,17 +89,6 @@ copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
file(GLOB kernel_h "*.h" "selected_rows/*.h" "sparse/*.h" "strings/*.h")
file(GLOB kernel_impl_h "impl/*.h" "selected_rows/impl/*.h")
file(GLOB kernel_primitive_h "primitive/*.h")
file(
GLOB
kernel_cc
"*.cc"
"cpu/*.cc"
"selected_rows/*.cc"
"selected_rows/cpu/*.cc"
"sparse/*.cc"
"sparse/cpu/*.cc"
"strings/*.cc"
"strings/cpu/*.cc")
file(
GLOB
......@@ -113,10 +102,34 @@ file(
"strings/*.cu"
"strings/gpu/*.cu")
# file(GLOB kernel_cudnn "gpudnn/*.cu")
# file(GLOB kernel_kps "kps/*.cu")
if(WITH_MKLDNN)
file(
GLOB
kernel_cc
"*.cc"
"cpu/*.cc"
"selected_rows/*.cc"
"selected_rows/cpu/*.cc"
"sparse/*.cc"
"sparse/cpu/*.cc"
"strings/*.cc"
"strings/cpu/*.cc"
"onednn/*.cc")
else()
file(
GLOB
kernel_cc
"*.cc"
"cpu/*.cc"
"selected_rows/*.cc"
"selected_rows/cpu/*.cc"
"sparse/*.cc"
"sparse/cpu/*.cc"
"strings/*.cc"
"strings/cpu/*.cc")
endif()
file(GLOB kernel_xpu "xpu/*.cc")
file(GLOB kernel_onednn "onednn/*.cc")
add_library(phi_cpu ${kernel_cc})
kernel_declare("${kernel_cc}")
......@@ -156,12 +169,4 @@ if(WITH_XPU)
set(ADD_PHI_KERNELS ${ADD_PHI_KERNELS} phi_xpu)
endif()
if(WITH_MKLDNN)
add_library(phi_onednn ${kernel_onednn})
kernel_declare(${kernel_onednn})
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_backends)
target_link_libraries(phi_onednn ${COMMON_KERNEL_DEPS})
set(ADD_PHI_KERNELS ${ADD_PHI_KERNELS} phi_onednn)
endif()
set_property(GLOBAL PROPERTY PHI_KERNELS ${ADD_PHI_KERNELS})
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
namespace funcs {
using user_function = std::function<std::shared_ptr<float>(const float*)>;
using memory = dnnl::memory;
using Place = phi::Place;
template <typename T,
typename TForward,
typename TBackward = paddle::platform::mkldnn_dummy_primitive,
typename TBackward_params = paddle::platform::mkldnn_dummy_primitive>
class MKLDNNHandlerNoCachingT {
public:
MKLDNNHandlerNoCachingT(dnnl::engine engine, Place cpu_place)
: engine_(engine), place_(cpu_place), fwd_pd_(nullptr), bwd_pd_(nullptr) {
phi::OneDNNContext::tls().log_lib_version();
}
std::shared_ptr<TForward> AcquireForwardPrimitive() {
return std::make_shared<TForward>(*fwd_pd_);
}
std::shared_ptr<TBackward> AcquireBackwardPrimitive() {
return std::make_shared<TBackward>(*bwd_pd_);
}
std::shared_ptr<TBackward_params> AcquireBackwardWeightsPrimitive() {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
phi::errors::Unavailable("BWD_W_PD should be set when "
"getting BWD prim."));
return std::make_shared<TBackward_params>(*bwd_w_pd_);
}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const DenseTensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(
fwd_pd_->src_desc(), paddle::platform::to_void_cast<T>(input_data));
}
template <typename T_out = T>
std::shared_ptr<dnnl::memory> AcquireDstMemory(DenseTensor* output) {
T_out* ptr =
output->mutable_data<T_out>(place_, fwd_pd_->dst_desc().get_size());
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc(), ptr);
}
template <typename T_out = T>
std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc());
}
template <typename T_out = T>
std::shared_ptr<dnnl::memory> AcquireDstMemory(const DenseTensor* output) {
const T_out* output_data = output->data<T_out>();
return this->AcquireMemoryFromPrimitive(
bwd_pd_->dst_desc(),
paddle::platform::to_void_cast<T_out>(output_data));
}
std::shared_ptr<dnnl::memory> AcquireDiffDstMemory(
const DenseTensor* diffdst) {
const T* ptr = diffdst->data<T>();
return this->AcquireMemoryFromPrimitive(
bwd_pd_->diff_dst_desc(), paddle::platform::to_void_cast<T>(ptr));
}
std::shared_ptr<dnnl::memory> AcquireDiffSrcMemory(DenseTensor* diffsrc) {
T* ptr =
diffsrc->mutable_data<T>(place_, bwd_pd_->diff_src_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_desc(), ptr);
}
// Buffer of given Tensor is used for oneDNN computation
std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(
DenseTensor* diff_weights) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
phi::errors::Unavailable(
"BWD_W_PD should be set when getting BWD grad of weights."));
T* ptr = diff_weights->mutable_data<T>(
place_, bwd_w_pd_->diff_weights_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(),
ptr);
}
// Buffer is allocated by oneDNN to store computation results
std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(void) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
phi::errors::Unavailable(
"BWD_W_PD should be set when getting BWD grad of weights."));
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc());
}
protected:
// If your primitive descriptor requires attributes, pass them as a
// first argument and parameters to descriptor constructor in the following
// arguments. Otherwise, all arguments will be forwarded to descriptor
// constructor, including the first one.
template <typename Arg, typename... Args>
void AcquireForwardPrimitiveDescriptor(Arg&& first_arg, Args&&... args) {
CreateForwardPrimitiveDescriptor(first_arg, std::forward<Args>(args)...);
}
// Using SFINAE to specialise the variadic function; a workaround for not
// having if constexpr in C++11.
template <class First, class... Args>
typename std::enable_if<std::is_same<typename std::decay<First>::type,
dnnl::primitive_attr>::value>::type
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
auto fwd_desc = typename TForward::desc(std::forward<Args>(args)...);
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
fwd_desc, first, engine_);
}
template <class First, class... Args>
typename std::enable_if<!std::is_same<typename std::decay<First>::type,
dnnl::primitive_attr>::value>::type
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
auto fwd_desc = typename TForward::desc(std::forward<First>(first),
std::forward<Args>(args)...);
fwd_pd_ =
std::make_shared<typename TForward::primitive_desc>(fwd_desc, engine_);
}
template <typename... Args>
void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
phi::errors::Unavailable("Get MKLDNN Forward primitive %s failed."));
auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...);
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
}
template <typename... Args>
void AcquireBackwardWeightsPrimitiveDescriptor(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
phi::errors::Unavailable("Get MKLDNN Forward primitive %s failed."));
auto bwd_desc =
typename TBackward_params::desc(std::forward<Args>(args)...);
bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
}
std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
dnnl::memory::desc md, void* ptr) {
return std::make_shared<dnnl::memory>(md, engine_, ptr);
}
std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
dnnl::memory::desc md) {
return std::make_shared<dnnl::memory>(md, engine_);
}
void AcquireReorder(const std::shared_ptr<dnnl::memory>& user_memory_p,
const std::shared_ptr<dnnl::memory>& target_memory_p) {
auto reorder_p =
std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
auto& astream = phi::OneDNNContext::tls().get_stream();
paddle::platform::RecordEvent record_reorder(
"int_reorder",
paddle::platform::TracerEventType::UserDefined,
2,
paddle::platform::EventRole::kUniqueOp);
reorder_p->execute(
astream,
{{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
astream.wait();
}
template <typename F = T>
std::shared_ptr<dnnl::memory> AcquireMemoryWithReorder(
const dnnl::memory::desc& user_md,
const dnnl::memory::desc& target_md,
void* ptr,
bool is_persistent = false,
std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {}) {
std::shared_ptr<dnnl::memory> target_memory_p;
if (custom_reorder_func) {
auto reordered_data =
custom_reorder_func(reinterpret_cast<const F*>(ptr));
ptr = reinterpret_cast<void*>(reordered_data.get());
}
auto user_memory_p = std::make_shared<dnnl::memory>(user_md, engine_, ptr);
if (user_md != target_md) {
target_memory_p = std::make_shared<dnnl::memory>(target_md, engine_);
auto reorder_p =
std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
auto& astream = phi::OneDNNContext::tls().get_stream();
paddle::platform::RecordEvent record_reorder(
"int_reorder",
paddle::platform::TracerEventType::UserDefined,
2,
paddle::platform::EventRole::kUniqueOp);
reorder_p->execute(
astream,
{{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
astream.wait();
} else {
target_memory_p = user_memory_p;
}
return target_memory_p;
}
dnnl::engine engine_;
Place place_;
std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
std::shared_ptr<typename TBackward_params::primitive_desc> bwd_w_pd_;
};
template <typename T>
class ActivationMKLDNNHandler
: public MKLDNNHandlerNoCachingT<T,
dnnl::eltwise_forward,
dnnl::eltwise_backward> {
public:
ActivationMKLDNNHandler(dnnl::algorithm algorithm,
float alpha,
float beta,
const dnnl::engine engine,
Place cpu_place,
const DenseTensor* x)
: MKLDNNHandlerNoCachingT<T,
dnnl::eltwise_forward,
dnnl::eltwise_backward>(engine, cpu_place) {
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
algorithm,
x->mem_desc(),
alpha,
beta);
}
ActivationMKLDNNHandler(dnnl::algorithm algorithm,
float alpha,
float beta,
const dnnl::engine engine,
Place cpu_place,
const DenseTensor* x,
const DenseTensor* dout)
: MKLDNNHandlerNoCachingT<T,
dnnl::eltwise_forward,
dnnl::eltwise_backward>(engine, cpu_place) {
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
algorithm,
x->mem_desc(),
alpha,
beta);
this->AcquireBackwardPrimitiveDescriptor(
algorithm, dout->mem_desc(), x->mem_desc(), alpha, beta);
}
std::shared_ptr<dnnl::memory> AcquireBackwardSrcMemory(
const DenseTensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(
this->bwd_pd_->src_desc(),
paddle::platform::to_void_cast<T>(input_data));
}
};
} // namespace funcs
} // namespace phi
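Driving the migrated handler looks the same as it did in fluid. A minimal forward sketch, assuming a OneDNNContext dev_ctx and DenseTensors x and out (this mirrors EltwiseForward in the new activation_kernel.cc below, trimmed of the in-place branch):

phi::funcs::ActivationMKLDNNHandler<float> handler(
    dnnl::algorithm::eltwise_relu, /*alpha=*/0.0f, /*beta=*/0.0f,
    dev_ctx.GetEngine(), dev_ctx.GetPlace(), &x);
auto src_memory_p = handler.AcquireSrcMemory(&x);
auto dst_memory_p = handler.AcquireDstMemory(&out);  // allocates out's buffer
auto activation_p = handler.AcquireForwardPrimitive();

auto& astream = phi::OneDNNContext::tls().get_stream();
activation_p->execute(
    astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
astream.wait();
out.set_mem_desc(dst_memory_p->get_desc());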
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/activation_grad_kernel.h"
#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/onednn/mkldnn_reuse.h"
namespace phi {
#define DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPX(name, functor_class) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& dout, \
DenseTensor* dx) { \
functor_class<T> functor; \
functor(dev_ctx, x, dout, 0, 0, dx); \
}
#define DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX( \
name, functor_class, attr) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& dout, \
float attr, \
DenseTensor* dx) { \
functor_class<T> functor; \
functor(dev_ctx, x, dout, attr, 0, dx); \
}
#define DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(name, functor_class) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& out, \
const DenseTensor& dout, \
DenseTensor* dx) { \
functor_class<T> functor; \
functor(dev_ctx, out, dout, 0, 0, dx); \
}
#define DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT( \
name, functor_class, attr) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& out, \
const DenseTensor& dout, \
float attr, \
DenseTensor* dx) { \
functor_class<T> functor; \
functor(dev_ctx, out, dout, attr, 0, dx); \
}
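// For reference, the one-attribute DEPX variant above expands mechanically;
// e.g. DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
// ReluMKLDNNGradFunctor, alpha), used further below, produces:
//
//   template <typename T, typename Context>
//   void LeakyReluGradKernel(const Context& dev_ctx,
//                            const DenseTensor& x,
//                            const DenseTensor& dout,
//                            float alpha,
//                            DenseTensor* dx) {
//     ReluMKLDNNGradFunctor<T> functor;
//     functor(dev_ctx, x, dout, alpha, 0, dx);
//   }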
template <typename T>
void eltwise_grad(const OneDNNContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
float alpha,
float beta,
DenseTensor* dx,
dnnl::algorithm algorithm) {
const auto& mkldnn_engine = dev_ctx.GetEngine();
funcs::ActivationMKLDNNHandler<T> handler(
algorithm, alpha, beta, mkldnn_engine, dev_ctx.GetPlace(), &x, &dout);
auto src_memory_p = handler.AcquireBackwardSrcMemory(&x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(&dout);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
auto activation_backward_p = handler.AcquireBackwardPrimitive();
auto& astream = OneDNNContext::tls().get_stream();
activation_backward_p->execute(astream,
{{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
dx->set_mem_desc(diff_src_memory_p->get_desc());
}
template <typename T>
void eltwise_grad_use_out(const OneDNNContext& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
float alpha,
float beta,
DenseTensor* dx,
dnnl::algorithm algorithm) {
const auto& mkldnn_engine = dev_ctx.GetEngine();
funcs::ActivationMKLDNNHandler<T> handler(
algorithm, alpha, beta, mkldnn_engine, dev_ctx.GetPlace(), &out, &dout);
auto dst_memory_p = handler.AcquireBackwardSrcMemory(&out);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(&dout);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
auto activation_backward_p = handler.AcquireBackwardPrimitive();
auto& astream = OneDNNContext::tls().get_stream();
activation_backward_p->execute(astream,
{{DNNL_ARG_DST, *dst_memory_p},
{DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
dx->set_mem_desc(diff_src_memory_p->get_desc());
}
template <typename T, dnnl::algorithm algorithm>
struct MKLDNNActivationGradFunc : public funcs::BaseActivationFunctor<T> {
void operator()(const OneDNNContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
float alpha,
float beta,
DenseTensor* dx) const {
eltwise_grad<T>(dev_ctx, x, dout, alpha, beta, dx, algorithm);
}
};
template <typename T, dnnl::algorithm algorithm>
struct MKLDNNActivationGradUseOutFunc : public funcs::BaseActivationFunctor<T> {
void operator()(const OneDNNContext& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
float alpha,
float beta,
DenseTensor* dx) const {
eltwise_grad_use_out<T>(dev_ctx, out, dout, alpha, beta, dx, algorithm);
}
};
template <typename T>
using ReluMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
template <typename T>
using SwishMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_swish>;
template <typename T>
using HardSwishMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_hardswish>;
template <typename T>
using MishMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_mish>;
template <typename T>
using SigmoidMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
template <typename T>
using TanhMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;
template <typename T>
using SqrtMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>;
template <typename T>
using EluMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;
template <typename T>
using ExpMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
T,
dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhMKLDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtMKLDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid,
SigmoidMKLDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpMKLDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluMKLDNNGradFunctor);
DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
ReluMKLDNNGradFunctor,
alpha);
DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
MishMKLDNNGradFunctor,
threshold);
DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish,
SwishMKLDNNGradFunctor,
beta);
template <typename T, typename Context>
void HardSwishGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
float threshold,
float scale,
float offset,
DenseTensor* dx) {
HardSwishMKLDNNGradFunctor<T> functor;
functor(dev_ctx, x, dout, threshold, 0, dx);
}
template <typename T, typename Context>
void EluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& dout,
float alpha,
DenseTensor* dx) {
EluMKLDNNGradUseOutFunctor<T> functor;
functor(dev_ctx, out, dout, alpha, 0, dx);
}
} // namespace phi
PD_REGISTER_KERNEL(relu_grad,
OneDNN,
ALL_LAYOUT,
phi::ReluGradKernel,
float,
phi::dtype::bfloat16) {}
#define PD_REGISTER_ACTIVATION_GRAD_KERNEL(name, func) \
PD_REGISTER_KERNEL( \
name, OneDNN, ALL_LAYOUT, phi::func, float, phi::dtype::bfloat16) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_grad, TanhGradKernel)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/onednn/mkldnn_reuse.h"
namespace phi {
#define DEFINE_ONEDNN_ACTIVATION_KERNEL(name, functor_class) \
template <typename T, typename Context> \
void name##Kernel( \
const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
functor_class<T> functor; \
functor(dev_ctx, x, 0, 0, out); \
}
#define DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \
template <typename T, typename Context> \
void name##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
float attr, \
DenseTensor* out) { \
functor_class<T> functor; \
functor(dev_ctx, x, attr, 0, out); \
}
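// For reference, the one-attribute forward macro above expands mechanically;
// e.g. DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, ReluMKLDNNFunctor,
// alpha), used further below, produces:
//
//   template <typename T, typename Context>
//   void LeakyReluKernel(const Context& dev_ctx,
//                        const DenseTensor& x,
//                        float alpha,
//                        DenseTensor* out) {
//     ReluMKLDNNFunctor<T> functor;
//     functor(dev_ctx, x, alpha, 0, out);
//   }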
template <typename T>
void EltwiseForward(const OneDNNContext& dev_ctx,
const DenseTensor& x,
float alpha,
float beta,
DenseTensor* out,
dnnl::algorithm algorithm) {
PADDLE_ENFORCE_EQ(paddle::platform::is_cpu_place(dev_ctx.GetPlace()),
true,
phi::errors::PreconditionNotMet(
"Operator DNNL eltwise_forward must use CPUPlace"));
const auto& mkldnn_engine = dev_ctx.GetEngine();
bool is_inplaced = x.IsSharedBufferWith(*out);
funcs::ActivationMKLDNNHandler<T> handler(
algorithm, alpha, beta, mkldnn_engine, dev_ctx.GetPlace(), &x);
auto src_memory_p = handler.AcquireSrcMemory(&x);
std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
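// When x shares its buffer with out (in-place execution), reuse the source
// memory object as the destination; Alloc<T> only establishes out's
// allocation metadata over the shared buffer.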
if (is_inplaced) {
dst_memory_p = src_memory_p;
dev_ctx.template Alloc<T>(out);
} else {
dst_memory_p = handler.AcquireDstMemory(out);
}
auto activation_p = handler.AcquireForwardPrimitive();
auto& astream = OneDNNContext::tls().get_stream();
activation_p->execute(
astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
astream.wait();
out->set_mem_desc(dst_memory_p->get_desc());
}
template <typename T, dnnl::algorithm algorithm>
struct MKLDNNActivationFunc : public funcs::BaseActivationFunctor<T> {
void operator()(const OneDNNContext& dev_ctx,
const DenseTensor& x,
float alpha,
float beta,
DenseTensor* out) const {
EltwiseForward<T>(dev_ctx, x, alpha, beta, out, algorithm);
}
};
template <typename T>
using ReluMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_relu>;
template <typename T>
using SwishMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_swish>;
template <typename T>
using HardSwishMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_hardswish>;
template <typename T>
using MishMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_mish>;
template <typename T>
using SigmoidMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_logistic>;
template <typename T>
using TanhMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_tanh>;
template <typename T>
using SqrtMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_sqrt>;
template <typename T>
using EluMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;
template <typename T>
using ExpMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;
template <typename T>
using RoundMKLDNNFunctor =
MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_round>;
DEFINE_ONEDNN_ACTIVATION_KERNEL(Relu, ReluMKLDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Tanh, TanhMKLDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Exp, ExpMKLDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Sqrt, SqrtMKLDNNFunctor)
DEFINE_ONEDNN_ACTIVATION_KERNEL(Sigmoid, SigmoidMKLDNNFunctor)
// round eltwise primitive doesn't support BF16, nor does it support grad
DEFINE_ONEDNN_ACTIVATION_KERNEL(Round, RoundMKLDNNFunctor)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, ReluMKLDNNFunctor, alpha)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishMKLDNNFunctor, threshold)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluMKLDNNFunctor, alpha)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishMKLDNNFunctor, beta)
template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx,
const DenseTensor& x,
float threshold,
float scale,
float offset,
DenseTensor* out) {
HardSwishMKLDNNFunctor<T> functor;
functor(dev_ctx, x, threshold, 0, out);
}
} // namespace phi
PD_REGISTER_KERNEL(round, OneDNN, ALL_LAYOUT, phi::RoundKernel, float) {}
#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL( \
name, OneDNN, ALL_LAYOUT, phi::func, float, phi::dtype::bfloat16) {}
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu, ReluKernel)