Unverified · Commit 0e67fb63 · Authored by Leo Guo, committed by GitHub

Fix bugs of dropout and dropout_grad. test=kunlun. (#55184)

* Fix bugs of dropout and dropout_grad. test=kunlun

* Modify the code style of dropout_grad_kernel. test=kunlun
Parent: bf669a5c
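The substance of the fix: the dropout mask output is now a uint8 tensor rather than a tensor of the compute type T. Because the XPU device primitives (xpu::mul, xpu::scale, xpu::dropout, xpu::dropout_grad) operate on XPUType buffers, the kernels below cast at the boundaries: uint8 to XPUType before the mask is consumed, and XPUType back to uint8 before it is stored. As background, here is a minimal framework-free sketch of the inverted ("upscale_in_train") dropout forward pass with a byte mask; the function name and the use of float for the compute type are assumptions for illustration, not Paddle API.

#include <cstddef>
#include <cstdint>
#include <random>

// Sketch only, not Paddle API: inverted ("upscale_in_train") dropout forward
// pass that records keep/drop decisions as one byte per element.
void dropout_forward_upscale(const float* x, float* y, std::uint8_t* mask,
                             std::size_t n, float p, std::uint64_t seed) {
  std::mt19937_64 gen(seed);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  // Kept activations are scaled by 1/(1 - p) so the expectation of y matches x.
  const float scale = (p == 1.0f) ? 0.0f : 1.0f / (1.0f - p);
  for (std::size_t i = 0; i < n; ++i) {
    const bool keep = uniform(gen) >= p;
    mask[i] = keep ? 1 : 0;             // the new uint8 mask layout
    y[i] = keep ? x[i] * scale : 0.0f;  // dropped elements become zero
  }
}

Storing the mask as one byte per element keeps the saved state small and, together with the registration change at the end of this diff, makes the tensor's declared dtype match what the kernel actually writes.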
@@ -39,14 +39,20 @@ void DropoutGradRawKernel(const Context& dev_ctx,
   auto* grad_y = &out_grad;
   dev_ctx.template Alloc<T>(grad_x);
   float dropout_prob = p.to<float>();
-  const T* mask_data = mask.data<T>();
+  const uint8_t* mask_data = mask.data<uint8_t>();
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  XPUType* mask_tmp_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(mask.numel());
+  int r = xpu::cast<uint8_t, XPUType>(
+      dev_ctx.x_context(), mask_data, mask_tmp_data, mask.numel());
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
   if (mode != "upscale_in_train") {
-    int r = xpu::mul(dev_ctx.x_context(),
-                     reinterpret_cast<const XPUType*>(grad_y->data<T>()),
-                     reinterpret_cast<const XPUType*>(mask_data),
-                     reinterpret_cast<XPUType*>(grad_x->data<T>()),
-                     grad_y->numel());
+    r = xpu::mul(dev_ctx.x_context(),
+                 reinterpret_cast<const XPUType*>(grad_y->data<T>()),
+                 reinterpret_cast<const XPUType*>(mask_tmp_data),
+                 reinterpret_cast<XPUType*>(grad_x->data<T>()),
+                 grad_y->numel());
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
     return;
   }
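The hunk above hoists a single uint8-to-XPUType cast of the saved mask ahead of the mode branch, so both code paths read mask_tmp_data; previously the non-upscale branch fed xpu::mul a buffer read through the wrong element type. In "downscale_in_infer" mode the backward pass is a plain elementwise product, as in this sketch (float stands in for XPUType; the name is illustrative):

#include <cstddef>

// Sketch: grad_x = grad_y * mask, elementwise. This is what xpu::mul
// computes above once the mask has been cast to the compute type.
void dropout_backward_downscale(const float* grad_y, const float* mask_cast,
                                float* grad_x, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    grad_x[i] = grad_y[i] * mask_cast[i];
  }
}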
@@ -54,28 +60,25 @@ void DropoutGradRawKernel(const Context& dev_ctx,
   auto version =
       phi::backends::xpu::get_xpu_version(dev_ctx.GetPlace().GetDeviceId());
   if (version == phi::backends::xpu::XPUVersion::XPU1) {
-    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-    XPUType* mask_new = RAII_GUARD.alloc_l3_or_gm<XPUType>(mask.numel());
     float scale =
         (dropout_prob == 1.0f) ? (1.0f) : (1.0f / (1.0f - dropout_prob));
-    int r = xpu::scale(dev_ctx.x_context(),
-                       reinterpret_cast<const XPUType*>(mask.data<T>()),
-                       reinterpret_cast<XPUType*>(mask_new),
-                       mask.numel(),
-                       false,
-                       scale,
-                       0.0f);
+    r = xpu::scale(dev_ctx.x_context(),
+                   reinterpret_cast<const XPUType*>(mask_tmp_data),
+                   reinterpret_cast<XPUType*>(mask_tmp_data),
+                   mask.numel(),
+                   false,
+                   scale,
+                   0.0f);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
     r = xpu::mul(dev_ctx.x_context(),
                  reinterpret_cast<const XPUType*>(grad_y->data<T>()),
-                 reinterpret_cast<const XPUType*>(mask_new),
+                 reinterpret_cast<const XPUType*>(mask_tmp_data),
                  reinterpret_cast<XPUType*>(grad_x->data<T>()),
                  grad_y->numel());
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
   } else {
-    int r =
-        xpu::dropout_grad(dev_ctx.x_context(),
-                          reinterpret_cast<const XPUType*>(mask.data<T>()),
+    r = xpu::dropout_grad(dev_ctx.x_context(),
+                          reinterpret_cast<const XPUType*>(mask_tmp_data),
                           reinterpret_cast<const XPUType*>(grad_y->data<T>()),
                           reinterpret_cast<XPUType*>(grad_x->data<T>()),
                           dropout_prob,
......
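On first-generation XPU hardware the kernel composes the upscale backward from xpu::scale and xpu::mul, rescaling the cast mask in place by 1/(1 - p) and then multiplying it into the incoming gradient; newer devices take the fused xpu::dropout_grad path instead. A plain-C++ sketch of the decomposed math, again with float assumed for the compute type:

#include <cstddef>

// Sketch of the XPU1 branch: grad_x = grad_y * mask / (1 - p), built from an
// in-place scale followed by an elementwise multiply, mirroring the
// xpu::scale + xpu::mul pair above. The p == 1.0f guard matches the kernel's.
void dropout_backward_upscale_decomposed(const float* grad_y, float* mask_cast,
                                         float* grad_x, std::size_t n,
                                         float p) {
  const float scale = (p == 1.0f) ? 1.0f : 1.0f / (1.0f - p);
  for (std::size_t i = 0; i < n; ++i) {
    mask_cast[i] *= scale;  // in place: source and destination coincide
  }
  for (std::size_t i = 0; i < n; ++i) {
    grad_x[i] = grad_y[i] * mask_cast[i];
  }
}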
@@ -62,7 +62,10 @@ void DropoutRawKernel(const Context& dev_ctx,
       seed_data = dev_ctx.GetGenerator()->Random64();
     }
-    auto* mask_data = dev_ctx.template Alloc<T>(mask);
+    auto* mask_data = dev_ctx.template Alloc<uint8_t>(mask);
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+    XPUType* mask_tmp_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(mask->numel());
     // Special case when dropout_prob is 1.0
     if (dropout_prob == 1.0f) {
       int r = xpu::constant(dev_ctx.x_context(),
@@ -70,22 +73,26 @@ void DropoutRawKernel(const Context& dev_ctx,
                         y->numel(),
                         XPUType(0));
       PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
-      r = xpu::constant(dev_ctx.x_context(),
-                        reinterpret_cast<XPUType*>(mask_data),
-                        mask->numel(),
-                        XPUType(0));
+      r = xpu::constant(
+          dev_ctx.x_context(), mask_tmp_data, mask->numel(), XPUType(0));
       PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+      r = xpu::cast<XPUType, uint8_t>(
+          dev_ctx.x_context(), mask_tmp_data, mask_data, mask->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
       return;
     }
     int r = xpu::dropout(dev_ctx.x_context(),
                          reinterpret_cast<const XPUType*>(x.data<T>()),
                          reinterpret_cast<XPUType*>(y->data<T>()),
-                         reinterpret_cast<XPUType*>(mask_data),
+                         mask_tmp_data,
                          seed_data,
                          mask->numel(),
                          is_upscale,
                          dropout_prob);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout");
+    r = xpu::cast<XPUType, uint8_t>(
+        dev_ctx.x_context(), mask_tmp_data, mask_data, mask->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
   } else {
     float scale =
         (is_upscale) ? (1.0) : (static_cast<float>(1.0f - dropout_prob));
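In the forward kernel, xpu::dropout still emits its mask in the compute type, so the fixed code generates the mask into the XPUType scratch buffer and narrows it into the real uint8 output with xpu::cast; the dropout_prob == 1.0 special case gets the same constant-then-cast treatment. The narrowing step amounts to the following (illustrative sketch, float standing in for XPUType):

#include <cstddef>
#include <cstdint>

// Sketch: narrow a 0.0/1.0 mask held in the compute type into one byte per
// element, the role xpu::cast<XPUType, uint8_t> plays after xpu::dropout.
void narrow_mask_to_u8(const float* mask_compute, std::uint8_t* mask_u8,
                       std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    mask_u8[i] = static_cast<std::uint8_t>(mask_compute[i]);
  }
}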
@@ -107,4 +114,6 @@ PD_REGISTER_KERNEL(dropout,
                    ALL_LAYOUT,
                    phi::DropoutRawKernel,
                    float,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16) {
+  kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
+}
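The registration change above is what makes the uint8 writes consistent: the brace body of PD_REGISTER_KERNEL runs at kernel-registration time, and pinning output index 1 (the mask) to DataType::UINT8 tells the framework to allocate that output as bytes for every registered compute type instead of defaulting it to T. That keeps Alloc<uint8_t>(mask) in the forward kernel and mask.data<uint8_t>() in the backward kernel in agreement with the tensor's declared dtype.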