Unverified commit 87d24878, authored by houj04, committed by GitHub

[XPU] support approximate for gelu activation. (#54376)

Parent: 58fe161f
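For context: before this change, the XPU GeluKernel ignored `approximate` and always called the exact `xpu::gelu`, logging a one-time INFO message; this commit dispatches to the dedicated XDNN `approximate_gelu` / `approximate_gelu_grad` ops instead (the `XPU_BASE_DATE` bump below presumably pulls in an XDNN build that ships them). A minimal user-facing sketch, assuming a Paddle build compiled with XPU support and an attached XPU device:

    import paddle
    import paddle.nn.functional as F

    paddle.set_device('xpu')  # assumes a build with XPU support

    x = paddle.uniform([11, 17], min=-1.0, max=1.0)

    # Exact GELU: x * Phi(x), served by xpu::gelu.
    y_exact = F.gelu(x, approximate=False)

    # Tanh-approximated GELU: now served by xpu::approximate_gelu
    # instead of silently falling back to the exact kernel.
    y_approx = F.gelu(x, approximate=True)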
@@ -8,7 +8,7 @@ set(XPU_API_LIB_NAME "libxpuapi.so")
 set(XPU_RT_LIB_NAME "libxpurt.so")
 set(XPU_XFT_LIB_NAME "libxft.so")
-set(XPU_BASE_DATE "20230529")
+set(XPU_BASE_DATE "20230602")
 set(XPU_XCCL_BASE_VERSION "1.0.49.2")
 set(XPU_XFT_BASE_VERSION "latest")
......
@@ -28,16 +28,30 @@ void GeluGradKernel(const Context& dev_ctx,
                     bool approximate,
                     DenseTensor* x_grad) {
   using XPUType = typename XPUTypeTrait<T>::Type;
   dev_ctx.template Alloc<T>(x_grad);
-  int r = xpu::gelu_grad<XPUType>(
-      dev_ctx.x_context(),
-      reinterpret_cast<const XPUType*>(x.data<T>()),
-      nullptr,
-      reinterpret_cast<const XPUType*>(out_grad.data<T>()),
-      reinterpret_cast<XPUType*>(x_grad->data<T>()),
-      x_grad->numel());
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu_grad");
+  if (approximate) {
+    // int approximate_gelu_grad(Context* ctx, const T* x, const T* y, const T*
+    // dy, T* dx, int64_t len);
+    int r = xpu::approximate_gelu_grad<XPUType>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(x.data<T>()),
+        nullptr,
+        reinterpret_cast<const XPUType*>(out_grad.data<T>()),
+        reinterpret_cast<XPUType*>(x_grad->data<T>()),
+        x_grad->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "approximate_gelu_grad");
+  } else {
+    // int gelu_grad(Context* ctx, const T* x, const T* y, const T* dy, T* dx,
+    // int64_t len);
+    int r = xpu::gelu_grad<XPUType>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(x.data<T>()),
+        nullptr,
+        reinterpret_cast<const XPUType*>(out_grad.data<T>()),
+        reinterpret_cast<XPUType*>(x_grad->data<T>()),
+        x_grad->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu_grad");
+  }
 }
 }  // namespace phi
......
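For reference, the XDNN signature comments above only document the argument lists; the math `xpu::approximate_gelu_grad` is presumed to implement is the derivative of the tanh approximation. A NumPy sketch under that assumption (`approximate_gelu_grad_ref` is a hypothetical helper, not part of the patch):

    import numpy as np

    def approximate_gelu_grad_ref(x, dy):
        # Assumed formula: gelu(x) ≈ 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x**3)))
        c = np.sqrt(2.0 / np.pi)
        u = c * (x + 0.044715 * x**3)  # inner tanh argument
        t = np.tanh(u)
        du_dx = c * (1.0 + 3.0 * 0.044715 * x**2)
        # Product and chain rules applied to 0.5 * x * (1 + tanh(u))
        dgelu_dx = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t**2) * du_dx
        return dy * dgelu_dx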
@@ -27,16 +27,26 @@ void GeluKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 bool approximate,
                 DenseTensor* out) {
-  if (approximate) {
-    LOG_FIRST_N(INFO, 1) << "XPU does not support gelu with approximate.";
-  }
   using XPUType = typename XPUTypeTrait<T>::Type;
   dev_ctx.template Alloc<T>(out);
-  int r = xpu::gelu<XPUType>(dev_ctx.x_context(),
-                             reinterpret_cast<const XPUType*>(x.data<T>()),
-                             reinterpret_cast<XPUType*>(out->data<T>()),
-                             out->numel());
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu");
+  if (approximate) {
+    // int approximate_gelu(Context* ctx, const T* x, T* y, int64_t len, const
+    // float* max_x = nullptr, float* max_y = nullptr);
+    int r = xpu::approximate_gelu<XPUType>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(x.data<T>()),
+        reinterpret_cast<XPUType*>(out->data<T>()),
+        out->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "approximate_gelu");
+  } else {
+    // int gelu(Context* ctx, const T* x, T* y, int64_t len, const float* max_x
+    // = nullptr, float* max_y = nullptr);
+    int r = xpu::gelu<XPUType>(dev_ctx.x_context(),
+                               reinterpret_cast<const XPUType*>(x.data<T>()),
+                               reinterpret_cast<XPUType*>(out->data<T>()),
+                               out->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu");
+  }
 }
 }  // namespace phi
......
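For the forward pass, the two branches correspond to the exact and tanh-approximated GELU formulas. A NumPy sketch (the test below calls a `gelu(x, approximate)` reference helper; the body here is an assumption about what it computes, written under the hypothetical name `gelu_ref`):

    import numpy as np
    from scipy.special import erf  # exact branch uses the Gaussian CDF

    def gelu_ref(x, approximate):
        if approximate:
            # xpu::approximate_gelu: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x**3)))
            return 0.5 * x * (1.0 + np.tanh(
                np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))
        # xpu::gelu: x * Phi(x), with Phi the standard normal CDF
        return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))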
@@ -377,6 +377,19 @@ class XPUTestGeluOP(XPUOpTestWrapper):
             self.outputs = {'Out': out}
             self.attrs = {"approximate": approximate, 'use_xpu': True}
 
+    class XPUTestGeluApproximate(TestActivationOPBase):
+        def set_case(self):
+            self.op_type = "gelu"
+            self.dtype = self.in_type
+
+            approximate = True
+            x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
+            out = gelu(x, approximate)
+
+            self.inputs = {'X': x}
+            self.outputs = {'Out': out}
+            self.attrs = {"approximate": approximate, 'use_xpu': True}
+
 
 support_types = get_xpu_op_support_types('gelu')
 for stype in support_types:
......
@@ -11,9 +11,13 @@
 # limitations under the License.
 
+import random
 import sys
 import unittest
 
 import numpy as np
 
+sys.path.append('../rnn')
+from convert import get_params_for_net
+
 from get_test_cover_info import (
     XPUOpTestWrapper,
......