diff --git a/paddle/phi/kernels/xpu/dropout_grad_kernel.cc b/paddle/phi/kernels/xpu/dropout_grad_kernel.cc
index c6803ca5cfcbd86f7d96b948ac02ec055e6730fd..70ff7794b649dc56c1a872695067b5d8df2ce7f0 100644
--- a/paddle/phi/kernels/xpu/dropout_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/dropout_grad_kernel.cc
@@ -39,14 +39,20 @@ void DropoutGradRawKernel(const Context& dev_ctx,
   auto* grad_y = &out_grad;
   dev_ctx.template Alloc<T>(grad_x);
   float dropout_prob = p.to<float>();
-  const T* mask_data = mask.data<T>();
+  const uint8_t* mask_data = mask.data<uint8_t>();
+
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  XPUType* mask_tmp_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(mask.numel());
+  int r = xpu::cast<uint8_t, XPUType>(
+      dev_ctx.x_context(), mask_data, mask_tmp_data, mask.numel());
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
   if (mode != "upscale_in_train") {
-    int r = xpu::mul(dev_ctx.x_context(),
-                     reinterpret_cast<const XPUType*>(grad_y->data<T>()),
-                     reinterpret_cast<const XPUType*>(mask_data),
-                     reinterpret_cast<XPUType*>(grad_x->data<T>()),
-                     grad_y->numel());
+    r = xpu::mul(dev_ctx.x_context(),
+                 reinterpret_cast<const XPUType*>(grad_y->data<T>()),
+                 reinterpret_cast<const XPUType*>(mask_tmp_data),
+                 reinterpret_cast<XPUType*>(grad_x->data<T>()),
+                 grad_y->numel());
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
     return;
   }
@@ -54,28 +60,25 @@ void DropoutGradRawKernel(const Context& dev_ctx,
   auto version =
       phi::backends::xpu::get_xpu_version(dev_ctx.GetPlace().GetDeviceId());
   if (version == phi::backends::xpu::XPUVersion::XPU1) {
-    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-    XPUType* mask_new = RAII_GUARD.alloc_l3_or_gm<XPUType>(mask.numel());
     float scale =
         (dropout_prob == 1.0f) ? (1.0f) : (1.0f / (1.0f - dropout_prob));
-    int r = xpu::scale(dev_ctx.x_context(),
-                       reinterpret_cast<const XPUType*>(mask.data<T>()),
-                       reinterpret_cast<XPUType*>(mask_new),
-                       mask.numel(),
-                       false,
-                       scale,
-                       0.0f);
+    r = xpu::scale(dev_ctx.x_context(),
+                   reinterpret_cast<const XPUType*>(mask_tmp_data),
+                   reinterpret_cast<XPUType*>(mask_tmp_data),
+                   mask.numel(),
+                   false,
+                   scale,
+                   0.0f);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
     r = xpu::mul(dev_ctx.x_context(),
                  reinterpret_cast<const XPUType*>(grad_y->data<T>()),
-                 reinterpret_cast<const XPUType*>(mask_new),
+                 reinterpret_cast<const XPUType*>(mask_tmp_data),
                  reinterpret_cast<XPUType*>(grad_x->data<T>()),
                  grad_y->numel());
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
   } else {
-    int r =
-        xpu::dropout_grad(dev_ctx.x_context(),
-                          reinterpret_cast<const XPUType*>(mask.data<T>()),
+    r = xpu::dropout_grad(dev_ctx.x_context(),
+                          reinterpret_cast<const XPUType*>(mask_tmp_data),
                           reinterpret_cast<const XPUType*>(grad_y->data<T>()),
                           reinterpret_cast<XPUType*>(grad_x->data<T>()),
                           dropout_prob,
diff --git a/paddle/phi/kernels/xpu/dropout_kernel.cc b/paddle/phi/kernels/xpu/dropout_kernel.cc
index 2f8a09afde8073ac589599fec313c8a1bd45d16f..54dcc1b3c637951e840a98620bf693379adb5801 100644
--- a/paddle/phi/kernels/xpu/dropout_kernel.cc
+++ b/paddle/phi/kernels/xpu/dropout_kernel.cc
@@ -62,7 +62,10 @@ void DropoutRawKernel(const Context& dev_ctx,
     seed_data = dev_ctx.GetGenerator()->Random64();
   }
 
-  auto* mask_data = dev_ctx.template Alloc<T>(mask);
+  auto* mask_data = dev_ctx.template Alloc<uint8_t>(mask);
+
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  XPUType* mask_tmp_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(mask->numel());
   // Special case when dropout_prob is 1.0
   if (dropout_prob == 1.0f) {
     int r = xpu::constant(dev_ctx.x_context(),
@@ -70,22 +73,26 @@ void DropoutRawKernel(const Context& dev_ctx,
                           y->numel(),
                           XPUType(0));
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
-    r = xpu::constant(dev_ctx.x_context(),
-                      reinterpret_cast<XPUType*>(mask_data),
-                      mask->numel(),
-                      XPUType(0));
+    r = xpu::constant(
+        dev_ctx.x_context(), mask_tmp_data, mask->numel(), XPUType(0));
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+    r = xpu::cast<XPUType, uint8_t>(
+        dev_ctx.x_context(), mask_tmp_data, mask_data, mask->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
     return;
   }
   int r = xpu::dropout(dev_ctx.x_context(),
                        reinterpret_cast<const XPUType*>(x.data<T>()),
                        reinterpret_cast<XPUType*>(y->data<T>()),
-                       reinterpret_cast<XPUType*>(mask_data),
+                       mask_tmp_data,
                        seed_data,
                        mask->numel(),
                        is_upscale,
                        dropout_prob);
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout");
+  r = xpu::cast<XPUType, uint8_t>(
+      dev_ctx.x_context(), mask_tmp_data, mask_data, mask->numel());
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
 } else {
   float scale = (is_upscale) ? (1.0) : (static_cast<float>(1.0f - dropout_prob));
@@ -107,4 +114,6 @@ PD_REGISTER_KERNEL(dropout,
                    ALL_LAYOUT,
                    phi::DropoutRawKernel,
                    float,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16) {
+  kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
+}