Unverified commit ae40ee32, authored by J jakpiase, committed by GitHub

added onednn elu kernel (#37149)

Parent a9e7a854
@@ -568,6 +568,10 @@ class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
               "The output is a multi-dimensional Tensor which has same "
               "dimension and data type as the ``x``.");
     AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false)
+        .AsExtra();
     AddComment(R"DOC(
 ELU Activation Operator.
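For reference, the ELU activation computed by this operator (and by the new oneDNN kernel registered further down) can be sketched in a few lines of numpy; the formula mirrors the expected output used in the Python test added later in this commit, with alpha defaulting to 1.0:

import numpy as np

def elu_reference(x, alpha=1.0):
    # ELU: x for x > 0, alpha * (exp(x) - 1) for x <= 0
    return np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype="float32")
print(elu_reference(x))              # default alpha = 1.0
print(elu_reference(x, alpha=2.5))   # custom alpha, as in the second test class below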
@@ -743,6 +747,10 @@ class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(6.0f);
     AddAttr<float>("offset", "The offset parameter of HardSwish operator")
         .SetDefault(3.0f);
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false)
+        .AsExtra();
     AddComment(R"DOC(
 HardSwish Activation Operator.
...
@@ -209,6 +209,10 @@ template <typename T>
 using AbsMKLDNNFunctor =
     MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;
 
+template <typename T>
+using EluMKLDNNFunctor =
+    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_elu>;
+
 template <typename T>
 using ReluMKLDNNGradFunctor =
     MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;
@@ -240,6 +244,10 @@ using SqrtMKLDNNGradFunctor =
 template <typename T>
 using AbsMKLDNNGradFunctor =
     MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
 
+template <typename T>
+using EluMKLDNNGradFunctor =
+    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_elu>;
+
 }  // namespace operators
 }  // namespace paddle
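The new EluMKLDNNGradFunctor maps the ELU backward pass onto oneDNN's eltwise_elu algorithm. As a small numpy sketch (not part of this diff), the input gradient that an ELU backward kernel computed from the forward input x is expected to match is:

import numpy as np

def elu_grad_reference(x, dout, alpha=1.0):
    # dELU/dx is 1 for x > 0 and alpha * exp(x) for x <= 0,
    # so the input gradient scales dout elementwise by that derivative.
    return dout * np.where(x > 0, 1.0, alpha * np.exp(x))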
@@ -264,14 +272,15 @@ namespace ops = paddle::operators;
           ops::MKLDNNActivationGradKernel<                                  \
               ops::grad_functor<paddle::platform::bfloat16>>);
 
 #define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)                             \
   __macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor);               \
   __macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);            \
   __macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);               \
-  __macro(hardswish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor);   \
+  __macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor);  \
   __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor);                  \
   __macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor);                  \
-  __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);
+  __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);                     \
+  __macro(elu, EluMKLDNNFunctor, EluMKLDNNGradFunctor);
 
 FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
 REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
...
@@ -326,6 +326,29 @@ class TestMKLDNNSigmoidDim4(TestSigmoid):
         self.attrs = {"use_mkldnn": True}
 
+class TestMKLDNNEluDefaultAlpha(TestActivation):
+    def setUp(self):
+        self.op_type = "elu"
+        self.set_alpha()
+
+        x = np.random.random((5, 5, 4)).astype("float32")
+
+        self.inputs = {'X': x}
+        self.attrs = {'use_mkldnn': True, 'alpha': self.alpha}
+        self.outputs = {
+            'Out':
+            np.maximum(0, x) + np.minimum(0, self.alpha * (np.exp(x) - 1))
+        }
+
+    def set_alpha(self):
+        self.alpha = 1.0
+
+
+class TestMKLDNNEluCustomAlpha(TestMKLDNNEluDefaultAlpha):
+    def set_alpha(self):
+        self.alpha = 2.5
+
+
 # Check if primitives already exist in backward
 class TestMKLDNNAbsPrimitivesAlreadyExist(unittest.TestCase):
     def setUp(self):
...
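Following the same pattern as the two test classes above, further alpha values could be covered simply by overriding set_alpha() in another subclass; a hypothetical example (not part of this commit):

class TestMKLDNNEluSmallAlpha(TestMKLDNNEluDefaultAlpha):
    # Hypothetical extra case exercising a sub-unity alpha value.
    def set_alpha(self):
        self.alpha = 0.2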