You need to sign in or sign up before continuing.
未验证 提交 62aff0a7 编写于 作者: A Adam 提交者: GitHub

Add DNNL GELU kernels (#22426)

上级 009c049e
...@@ -44,8 +44,19 @@ class GeluOp : public framework::OperatorWithKernel { ...@@ -44,8 +44,19 @@ class GeluOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
auto it = this->Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain &&
it != this->Attrs().end() && platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
} }
}; };
...@@ -73,8 +84,19 @@ class GeluGradOp : public framework::OperatorWithKernel { ...@@ -73,8 +84,19 @@ class GeluGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
auto it = this->Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain &&
it != this->Attrs().end() && platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
} }
}; };
......
...@@ -162,6 +162,30 @@ struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> { ...@@ -162,6 +162,30 @@ struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
} }
}; };
template <typename T>
struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
const bool approximate = ctx.Attr<bool>("approximate");
if (approximate) {
eltwise_forward<T>(ctx, mkldnn::algorithm::eltwise_gelu_tanh);
} else {
eltwise_forward<T>(ctx, mkldnn::algorithm::eltwise_gelu_erf);
}
}
};
template <typename T>
struct GeluMKLDNNGradFunctor : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
const bool approximate = ctx.Attr<bool>("approximate");
if (approximate) {
eltwise_grad<T>(ctx, mkldnn::algorithm::eltwise_gelu_tanh);
} else {
eltwise_grad<T>(ctx, mkldnn::algorithm::eltwise_gelu_erf);
}
}
};
template <typename T> template <typename T>
using ReluMKLDNNFunctor = using ReluMKLDNNFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>; MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;
...@@ -216,6 +240,7 @@ namespace ops = paddle::operators; ...@@ -216,6 +240,7 @@ namespace ops = paddle::operators;
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \ #define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
__macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \ __macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \ __macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(gelu, GeluMKLDNNFunctor, GeluMKLDNNGradFunctor); \
__macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor); \ __macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor); \
__macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor); \ __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor); \
__macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor); \ __macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor); \
......
...@@ -19,7 +19,8 @@ import numpy as np ...@@ -19,7 +19,8 @@ import numpy as np
from scipy.special import expit from scipy.special import expit
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_activation_op import TestRelu, TestTanh, TestSqrt, TestAbs, TestLeakyRelu, TestSwish from paddle.fluid.tests.unittests.test_activation_op import TestActivation, TestRelu, TestTanh, TestSqrt, TestAbs, TestLeakyRelu, TestSwish
from paddle.fluid.tests.unittests.test_gelu_op import gelu
from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd
...@@ -59,6 +60,32 @@ class TestMKLDNNLeakyReluDim2(TestLeakyRelu): ...@@ -59,6 +60,32 @@ class TestMKLDNNLeakyReluDim2(TestLeakyRelu):
['X'], 'Out', max_relative_error=0.007, check_dygraph=False) ['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
class TestMKLDNNGeluDim2(TestActivation):
def setUp(self):
self.op_type = "gelu"
self.dtype = np.float32
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = gelu(x, False)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {"use_mkldnn": True}
class TestMKLDNNGeluDim2Approx(TestActivation):
def setUp(self):
self.op_type = "gelu"
self.dtype = np.float32
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = gelu(x, True)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {"use_mkldnn": True, "approximate": True}
class TestMKLDNNTanhDim2(TestTanh): class TestMKLDNNTanhDim2(TestTanh):
def setUp(self): def setUp(self):
super(TestMKLDNNTanhDim2, self).setUp() super(TestMKLDNNTanhDim2, self).setUp()
...@@ -185,6 +212,32 @@ class TestMKLDNNLeakyReluDim4(TestLeakyRelu): ...@@ -185,6 +212,32 @@ class TestMKLDNNLeakyReluDim4(TestLeakyRelu):
['X'], 'Out', max_relative_error=0.007, check_dygraph=False) ['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
class TestMKLDNNGeluDim4(TestActivation):
def setUp(self):
self.op_type = "gelu"
self.dtype = np.float32
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype(self.dtype)
out = gelu(x, False)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {"use_mkldnn": True}
class TestMKLDNNGeluDim4Approx(TestActivation):
def setUp(self):
self.op_type = "gelu"
self.dtype = np.float32
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype(self.dtype)
out = gelu(x, True)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {"use_mkldnn": True, "approximate": True}
class TestMKLDNNTanhDim4(TestTanh): class TestMKLDNNTanhDim4(TestTanh):
def setUp(self): def setUp(self):
super(TestMKLDNNTanhDim4, self).setUp() super(TestMKLDNNTanhDim4, self).setUp()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册