未验证 提交 0e67fb63 编写于 作者: L Leo Guo 提交者: GitHub

Fix bugs of dropout and dropout_grad. test=kunlun. (#55184)

* Fix bugs of dropout and dropout_grad. test=kunlun

* Modify the code style of dropout_grad_kernel. test=kunlun
上级 bf669a5c
...@@ -39,14 +39,20 @@ void DropoutGradRawKernel(const Context& dev_ctx, ...@@ -39,14 +39,20 @@ void DropoutGradRawKernel(const Context& dev_ctx,
auto* grad_y = &out_grad; auto* grad_y = &out_grad;
dev_ctx.template Alloc<T>(grad_x); dev_ctx.template Alloc<T>(grad_x);
float dropout_prob = p.to<float>(); float dropout_prob = p.to<float>();
const T* mask_data = mask.data<T>(); const uint8_t* mask_data = mask.data<uint8_t>();
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* mask_tmp_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(mask.numel());
int r = xpu::cast<uint8_t, XPUType>(
dev_ctx.x_context(), mask_data, mask_tmp_data, mask.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
if (mode != "upscale_in_train") { if (mode != "upscale_in_train") {
int r = xpu::mul(dev_ctx.x_context(), r = xpu::mul(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(grad_y->data<T>()), reinterpret_cast<const XPUType*>(grad_y->data<T>()),
reinterpret_cast<const XPUType*>(mask_data), reinterpret_cast<const XPUType*>(mask_tmp_data),
reinterpret_cast<XPUType*>(grad_x->data<T>()), reinterpret_cast<XPUType*>(grad_x->data<T>()),
grad_y->numel()); grad_y->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
return; return;
} }
...@@ -54,28 +60,25 @@ void DropoutGradRawKernel(const Context& dev_ctx, ...@@ -54,28 +60,25 @@ void DropoutGradRawKernel(const Context& dev_ctx,
auto version = auto version =
phi::backends::xpu::get_xpu_version(dev_ctx.GetPlace().GetDeviceId()); phi::backends::xpu::get_xpu_version(dev_ctx.GetPlace().GetDeviceId());
if (version == phi::backends::xpu::XPUVersion::XPU1) { if (version == phi::backends::xpu::XPUVersion::XPU1) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* mask_new = RAII_GUARD.alloc_l3_or_gm<XPUType>(mask.numel());
float scale = float scale =
(dropout_prob == 1.0f) ? (1.0f) : (1.0f / (1.0f - dropout_prob)); (dropout_prob == 1.0f) ? (1.0f) : (1.0f / (1.0f - dropout_prob));
int r = xpu::scale(dev_ctx.x_context(), r = xpu::scale(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(mask.data<T>()), reinterpret_cast<const XPUType*>(mask_tmp_data),
reinterpret_cast<XPUType*>(mask_new), reinterpret_cast<XPUType*>(mask_tmp_data),
mask.numel(), mask.numel(),
false, false,
scale, scale,
0.0f); 0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
r = xpu::mul(dev_ctx.x_context(), r = xpu::mul(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(grad_y->data<T>()), reinterpret_cast<const XPUType*>(grad_y->data<T>()),
reinterpret_cast<const XPUType*>(mask_new), reinterpret_cast<const XPUType*>(mask_tmp_data),
reinterpret_cast<XPUType*>(grad_x->data<T>()), reinterpret_cast<XPUType*>(grad_x->data<T>()),
grad_y->numel()); grad_y->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
} else { } else {
int r = r = xpu::dropout_grad(dev_ctx.x_context(),
xpu::dropout_grad(dev_ctx.x_context(), reinterpret_cast<const XPUType*>(mask_tmp_data),
reinterpret_cast<const XPUType*>(mask.data<T>()),
reinterpret_cast<const XPUType*>(grad_y->data<T>()), reinterpret_cast<const XPUType*>(grad_y->data<T>()),
reinterpret_cast<XPUType*>(grad_x->data<T>()), reinterpret_cast<XPUType*>(grad_x->data<T>()),
dropout_prob, dropout_prob,
......
...@@ -62,7 +62,10 @@ void DropoutRawKernel(const Context& dev_ctx, ...@@ -62,7 +62,10 @@ void DropoutRawKernel(const Context& dev_ctx,
seed_data = dev_ctx.GetGenerator()->Random64(); seed_data = dev_ctx.GetGenerator()->Random64();
} }
auto* mask_data = dev_ctx.template Alloc<T>(mask); auto* mask_data = dev_ctx.template Alloc<uint8_t>(mask);
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* mask_tmp_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(mask->numel());
// Special case when dropout_prob is 1.0 // Special case when dropout_prob is 1.0
if (dropout_prob == 1.0f) { if (dropout_prob == 1.0f) {
int r = xpu::constant(dev_ctx.x_context(), int r = xpu::constant(dev_ctx.x_context(),
...@@ -70,22 +73,26 @@ void DropoutRawKernel(const Context& dev_ctx, ...@@ -70,22 +73,26 @@ void DropoutRawKernel(const Context& dev_ctx,
y->numel(), y->numel(),
XPUType(0)); XPUType(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
r = xpu::constant(dev_ctx.x_context(), r = xpu::constant(
reinterpret_cast<XPUType*>(mask_data), dev_ctx.x_context(), mask_tmp_data, mask->numel(), XPUType(0));
mask->numel(),
XPUType(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
r = xpu::cast<XPUType, uint8_t>(
dev_ctx.x_context(), mask_tmp_data, mask_data, mask->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
return; return;
} }
int r = xpu::dropout(dev_ctx.x_context(), int r = xpu::dropout(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()), reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(y->data<T>()), reinterpret_cast<XPUType*>(y->data<T>()),
reinterpret_cast<XPUType*>(mask_data), mask_tmp_data,
seed_data, seed_data,
mask->numel(), mask->numel(),
is_upscale, is_upscale,
dropout_prob); dropout_prob);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout");
r = xpu::cast<XPUType, uint8_t>(
dev_ctx.x_context(), mask_tmp_data, mask_data, mask->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
} else { } else {
float scale = float scale =
(is_upscale) ? (1.0) : (static_cast<float>(1.0f - dropout_prob)); (is_upscale) ? (1.0) : (static_cast<float>(1.0f - dropout_prob));
...@@ -107,4 +114,6 @@ PD_REGISTER_KERNEL(dropout, ...@@ -107,4 +114,6 @@ PD_REGISTER_KERNEL(dropout,
ALL_LAYOUT, ALL_LAYOUT,
phi::DropoutRawKernel, phi::DropoutRawKernel,
float, float,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册