Unverified commit 14abafa1, authored by houj04, committed by GitHub

[XPU] layer_norm support fp16 input of scale and bias. (#52091)

Parent: 019e1cf5
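Before this change the XPU layer_norm kernels assumed that Scale and Bias were always float32; with this commit they may also be float16, and the kernels cast them to float32 internally before calling the XDNN primitives. A minimal usage sketch of what the change enables (hypothetical shapes and values; assumes a Paddle build compiled with XPU support and dynamic-graph mode):

    import paddle
    import paddle.nn.functional as F

    paddle.set_device('xpu')  # assumes an XPU (Kunlun) build of Paddle

    # float16 activations together with float16 scale/bias,
    # which this commit makes valid on XPU.
    x = paddle.rand([8, 64]).astype('float16')
    weight = paddle.rand([64]).astype('float16')
    bias = paddle.rand([64]).astype('float16')

    y = F.layer_norm(x, normalized_shape=[64], weight=weight, bias=bias, epsilon=1e-5)
    print(y.dtype)  # paddle.float16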
@@ -41,18 +41,56 @@ void LayerNormGradKernel(const Context& ctx,
   const auto* out_grad_data = out_grad.data<T>();
   const auto* mean_data = mean.data<float>();
   const auto* variance_data = variance.data<float>();
-  const auto* scale_data =
-      (scale.get_ptr() == nullptr ? nullptr : scale.get_ptr()->data<float>());
-  auto* scale_grad_data =
-      (scale_grad == nullptr ? nullptr : ctx.template Alloc<float>(scale_grad));
-  auto* bias_grad_data =
-      (bias_grad == nullptr ? nullptr : ctx.template Alloc<float>(bias_grad));
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+  // scale
+  const float* scale_data_fp32 = nullptr;
+  float* scale_grad_data_fp32 = nullptr;
+  const auto* scale_ptr = scale.get_ptr();
+  bool need_cast_scale = false;
+  if (scale_ptr == nullptr) {
+    // no scale, do nothing
+  } else if (scale_ptr->dtype() ==
+             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
+    float* scale_data_temp =
+        RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
+    int r = xpu::cast<XPUType, float>(
+        ctx.x_context(),
+        reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
+        scale_data_temp,
+        scale_ptr->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    scale_data_fp32 = scale_data_temp;
+    need_cast_scale = true;
+    scale_grad_data_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
+  } else {
+    // no need to cast
+    scale_data_fp32 = scale_ptr->data<float>();
+    scale_grad_data_fp32 = ctx.template Alloc<float>(scale_grad);
+  }
+  // bias
+  float* bias_grad_data_fp32 = nullptr;
+  const auto* bias_ptr = bias.get_ptr();
+  bool need_cast_bias = false;
+  if (bias_ptr == nullptr) {
+    // no bias, do nothing
+  } else if (bias_ptr->dtype() ==
+             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
+    need_cast_bias = true;
+    bias_grad_data_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
+  } else {
+    // no need to cast
+    bias_grad_data_fp32 = ctx.template Alloc<float>(bias_grad);
+  }
   auto* x_grad_data =
       (x_grad == nullptr ? nullptr : ctx.template Alloc<T>(x_grad));
-  // int layer_norm_grad(Context* ctx, const T* x, const T* dy, T* dx, int m,
-  // int n, float eps, const float* scale, const float* mean, const float*
-  // var, float* dscale, float* dbias);
+  // int layer_norm_grad(Context* ctx, const T* x, const T* dy, T* dx, int64_t
+  // m, int64_t n, float eps, const float* scale, const float* mean, const
+  // float* var, float* dscale, float* dbias, bool is_rstd = false);
   int r = xpu::layer_norm_grad(ctx.x_context(),
                                reinterpret_cast<const XPUType*>(x_data),
                                reinterpret_cast<const XPUType*>(out_grad_data),
@@ -60,12 +98,29 @@ void LayerNormGradKernel(const Context& ctx,
                                left,
                                right,
                                epsilon,
-                               scale_data,
+                               scale_data_fp32,
                                mean_data,
                                variance_data,
-                               scale_grad_data,
-                               bias_grad_data);
+                               scale_grad_data_fp32,
+                               bias_grad_data_fp32);
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
+  if (need_cast_scale) {
+    int r = xpu::cast<float, XPUType>(
+        ctx.x_context(),
+        scale_grad_data_fp32,
+        reinterpret_cast<XPUType*>(ctx.template Alloc<T>(scale_grad)),
+        scale.get_ptr()->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+  }
+  if (need_cast_bias) {
+    int r = xpu::cast<float, XPUType>(
+        ctx.x_context(),
+        bias_grad_data_fp32,
+        reinterpret_cast<XPUType*>(ctx.template Alloc<T>(bias_grad)),
+        bias.get_ptr()->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+  }
 }
 }  // namespace phi
...
@@ -35,24 +35,65 @@ void LayerNormKernel(const Context& ctx,
   int left = static_cast<int>(matrix_dim[0]);
   int right = static_cast<int>(matrix_dim[1]);
   const auto* x_data = x.data<T>();
-  const auto* scale_data =
-      (scale.get_ptr() == nullptr ? nullptr : scale.get_ptr()->data<float>());
-  const auto* bias_data =
-      (bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>());
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+  // scale
+  const float* scale_data_fp32 = nullptr;
+  const auto* scale_ptr = scale.get_ptr();
+  if (scale_ptr == nullptr) {
+    // no scale, do nothing
+  } else if (scale_ptr->dtype() ==
+             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
+    float* scale_data_temp =
+        RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
+    int r = xpu::cast<XPUType, float>(
+        ctx.x_context(),
+        reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
+        scale_data_temp,
+        scale_ptr->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    scale_data_fp32 = scale_data_temp;
+  } else {
+    // no need to cast
+    scale_data_fp32 = scale_ptr->data<float>();
+  }
+  // bias
+  const float* bias_data_fp32 = nullptr;
+  const auto* bias_ptr = bias.get_ptr();
+  if (bias_ptr == nullptr) {
+    // no bias, do nothing
+  } else if (bias_ptr->dtype() ==
+             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
+    float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
+    int r = xpu::cast<XPUType, float>(
+        ctx.x_context(),
+        reinterpret_cast<const XPUType*>(bias_ptr->data<T>()),
+        bias_data_temp,
+        bias_ptr->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    bias_data_fp32 = bias_data_temp;
+  } else {
+    // no need to cast
+    bias_data_fp32 = bias_ptr->data<float>();
+  }
   auto* out_data = ctx.template Alloc<T>(out);
   auto* mean_data = ctx.template Alloc<float>(mean);
   auto* variance_data = ctx.template Alloc<float>(variance);
-  // int layer_norm(Context* ctx, const T* x, T* y, int m, int n, float eps,
-  // const float* scale, const float* bias, float* mean, float* var);
+  // int layer_norm(Context* ctx, const T* x, T* y, int64_t m, int64_t n, float
+  // eps, const float* scale, const float* bias, float* mean, float* var, bool
+  // is_rstd = false);
   int r = xpu::layer_norm(ctx.x_context(),
                           reinterpret_cast<const XPUType*>(x_data),
                           reinterpret_cast<XPUType*>(out_data),
                           left,
                           right,
                           epsilon,
-                          scale_data,
-                          bias_data,
+                          scale_data_fp32,
+                          bias_data_fp32,
                           mean_data,
                           variance_data);
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
...
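The forward kernel above keeps the layer_norm math in float32 even when scale and bias arrive as float16: the fp16 parameters are cast into scratch fp32 buffers (RAII_GUARD.alloc_l3_or_gm plus xpu::cast) and the fp32 pointers are handed to xpu::layer_norm. The NumPy sketch below mirrors that numeric behaviour for reference only; it is not the XPU implementation, and `ref_layer_norm_fp16_params` is a hypothetical helper name, not part of the Paddle test suite:

    import numpy as np

    def ref_layer_norm_fp16_params(x, scale, bias, eps=1e-5, begin_norm_axis=1):
        # Flatten to [left, right], as the kernel does via matrix_dim.
        left = int(np.prod(x.shape[:begin_norm_axis]))
        right = int(np.prod(x.shape[begin_norm_axis:]))
        x2d = x.reshape(left, right).astype(np.float32)
        # Mirror of the kernel: upcast fp16 scale/bias to fp32 before the math.
        scale32 = scale.astype(np.float32)
        bias32 = bias.astype(np.float32)
        mean = x2d.mean(axis=1, keepdims=True)
        var = x2d.var(axis=1, keepdims=True)
        y = (x2d - mean) / np.sqrt(var + eps) * scale32 + bias32
        # Output keeps the input dtype; mean/var stay fp32, as in the kernel.
        return y.reshape(x.shape).astype(x.dtype), mean.ravel(), var.ravel()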
@@ -63,6 +63,7 @@ class XPUTestLayerNormOp(XPUOpTestWrapper):
             self.shape = [2, 3, 4, 5]
             self.epsilon = 1e-05
             self.begin_norm_axis = 1
+            self.use_fp16_scale_bias = False
             self.set_attrs()
             self.atol = 1e-4
@@ -76,6 +77,9 @@ class XPUTestLayerNormOp(XPUOpTestWrapper):
             x_np = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
             scale_np = np.random.uniform(0.1, 1, [right]).astype('float32')
             bias_np = np.random.uniform(0.1, 1, [right]).astype('float32')
+            if self.dtype == np.float16 and self.use_fp16_scale_bias:
+                scale_np = scale_np.astype('float16')
+                bias_np = bias_np.astype('float16')
             ref_y_np, ref_mean_np, ref_variance_np = ref_layer_norm(
                 x_np, scale_np, bias_np, self.epsilon, self.begin_norm_axis
             )
@@ -119,6 +123,20 @@ class XPUTestLayerNormOp(XPUOpTestWrapper):
         def set_attrs(self):
             self.shape = [4, 5, 6]

+    class TestXPULayerNormOpFP16(TestXPULayerNormOp):
+        def set_attrs(self):
+            self.use_fp16_scale_bias = True
+
+    class TestXPULayerNormOpFP16_2D(TestXPULayerNormOp):
+        def set_attrs(self):
+            self.shape = [10, 12]
+            self.use_fp16_scale_bias = True
+
+    class TestXPULayerNormOpFP16_3D(TestXPULayerNormOp):
+        def set_attrs(self):
+            self.shape = [4, 5, 6]
+            self.use_fp16_scale_bias = True
+
 support_types = get_xpu_op_support_types('layer_norm')
 for stype in support_types:
...