Unverified commit 87d24878, authored by houj04, committed by GitHub

[XPU] support approximate for gelu activation. (#54376)

Parent commit: 58fe161f
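This change routes Paddle's gelu on XPU through a dedicated kernel when approximate=True, instead of logging a notice and silently falling back to the exact erf-based form. For reference, a minimal NumPy sketch of the two formulas (reference math only, not the XPU kernel code; assumes SciPy is available for erf):

import numpy as np
from scipy.special import erf

def gelu_exact(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

def gelu_tanh_approx(x):
    # Tanh-based approximation, used when approximate=True.
    inner = np.sqrt(2.0 / np.pi) * (x + 0.044715 * np.power(x, 3))
    return 0.5 * x * (1.0 + np.tanh(inner))

x = np.linspace(-3.0, 3.0, 601)
print(np.max(np.abs(gelu_exact(x) - gelu_tanh_approx(x))))  # small, on the order of 1e-3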
@@ -8,7 +8,7 @@ set(XPU_API_LIB_NAME "libxpuapi.so")
 set(XPU_RT_LIB_NAME "libxpurt.so")
 set(XPU_XFT_LIB_NAME "libxft.so")
-set(XPU_BASE_DATE "20230529")
+set(XPU_BASE_DATE "20230602")
 set(XPU_XCCL_BASE_VERSION "1.0.49.2")
 set(XPU_XFT_BASE_VERSION "latest")
...
@@ -28,16 +28,30 @@ void GeluGradKernel(const Context& dev_ctx,
                     bool approximate,
                     DenseTensor* x_grad) {
   using XPUType = typename XPUTypeTrait<T>::Type;
   dev_ctx.template Alloc<T>(x_grad);
-  int r = xpu::gelu_grad<XPUType>(
-      dev_ctx.x_context(),
-      reinterpret_cast<const XPUType*>(x.data<T>()),
-      nullptr,
-      reinterpret_cast<const XPUType*>(out_grad.data<T>()),
-      reinterpret_cast<XPUType*>(x_grad->data<T>()),
-      x_grad->numel());
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu_grad");
+  if (approximate) {
+    // int approximate_gelu_grad(Context* ctx, const T* x, const T* y, const T*
+    // dy, T* dx, int64_t len);
+    int r = xpu::approximate_gelu_grad<XPUType>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(x.data<T>()),
+        nullptr,
+        reinterpret_cast<const XPUType*>(out_grad.data<T>()),
+        reinterpret_cast<XPUType*>(x_grad->data<T>()),
+        x_grad->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "approximate_gelu_grad");
+  } else {
+    // int gelu_grad(Context* ctx, const T* x, const T* y, const T* dy, T* dx,
+    // int64_t len);
+    int r = xpu::gelu_grad<XPUType>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(x.data<T>()),
+        nullptr,
+        reinterpret_cast<const XPUType*>(out_grad.data<T>()),
+        reinterpret_cast<XPUType*>(x_grad->data<T>()),
+        x_grad->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu_grad");
+  }
 }
 }  // namespace phi
...
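The approximate branch of the backward kernel delegates to xpu::approximate_gelu_grad; the quantity it is expected to produce is the derivative of the tanh approximation shown earlier. A minimal NumPy sketch of that derivative (reference math only, not the device code):

import numpy as np

def gelu_tanh_approx_grad(x, dy):
    # d/dx [0.5 * x * (1 + tanh(g(x)))], with g(x) = sqrt(2/pi) * (x + 0.044715 * x**3)
    k = np.sqrt(2.0 / np.pi)
    g = k * (x + 0.044715 * np.power(x, 3))
    t = np.tanh(g)
    dg = k * (1.0 + 3.0 * 0.044715 * np.square(x))
    dx_local = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - np.square(t)) * dg
    return dy * dx_local  # chain rule with the incoming gradient dy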
@@ -27,16 +27,26 @@ void GeluKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 bool approximate,
                 DenseTensor* out) {
-  if (approximate) {
-    LOG_FIRST_N(INFO, 1) << "XPU does not support gelu with approximate.";
-  }
   using XPUType = typename XPUTypeTrait<T>::Type;
   dev_ctx.template Alloc<T>(out);
-  int r = xpu::gelu<XPUType>(dev_ctx.x_context(),
-                             reinterpret_cast<const XPUType*>(x.data<T>()),
-                             reinterpret_cast<XPUType*>(out->data<T>()),
-                             out->numel());
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu");
+  if (approximate) {
+    // int approximate_gelu(Context* ctx, const T* x, T* y, int64_t len, const
+    // float* max_x = nullptr, float* max_y = nullptr);
+    int r = xpu::approximate_gelu<XPUType>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(x.data<T>()),
+        reinterpret_cast<XPUType*>(out->data<T>()),
+        out->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "approximate_gelu");
+  } else {
+    // int gelu(Context* ctx, const T* x, T* y, int64_t len, const float* max_x
+    // = nullptr, float* max_y = nullptr);
+    int r = xpu::gelu<XPUType>(dev_ctx.x_context(),
+                               reinterpret_cast<const XPUType*>(x.data<T>()),
+                               reinterpret_cast<XPUType*>(out->data<T>()),
+                               out->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu");
+  }
 }
 }  // namespace phi
...
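From Python, the new path is reached through the existing approximate flag of paddle.nn.functional.gelu. A minimal usage sketch (assumes a Paddle build with XPU support and an available XPU device):

import paddle
import paddle.nn.functional as F

paddle.set_device('xpu')  # assumption: XPU-enabled build and device present
x = paddle.uniform([11, 17], min=-1.0, max=1.0)
y_exact = F.gelu(x, approximate=False)   # dispatches to xpu::gelu
y_approx = F.gelu(x, approximate=True)   # with this commit, dispatches to xpu::approximate_gelu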
@@ -377,6 +377,19 @@ class XPUTestGeluOP(XPUOpTestWrapper):
             self.outputs = {'Out': out}
             self.attrs = {"approximate": approximate, 'use_xpu': True}
 
+    class XPUTestGeluApproximate(TestActivationOPBase):
+        def set_case(self):
+            self.op_type = "gelu"
+            self.dtype = self.in_type
+            approximate = True
+
+            x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
+            out = gelu(x, approximate)
+
+            self.inputs = {'X': x}
+            self.outputs = {'Out': out}
+            self.attrs = {"approximate": approximate, 'use_xpu': True}
+
 
 support_types = get_xpu_op_support_types('gelu')
 for stype in support_types:
...
@@ -11,9 +11,13 @@
 # limitations under the License.
 import random
+import sys
 import unittest
 
 import numpy as np
 
+sys.path.append('../rnn')
 from convert import get_params_for_net
 from get_test_cover_info import (
     XPUOpTestWrapper,
...