diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h
index 695d29b294a51a79986306ab6a496210c2378fd5..bd4d690577a6fac5be79fb5086d3b5203c5616fb 100644
--- a/paddle/fluid/operators/dropout_impl.cu.h
+++ b/paddle/fluid/operators/dropout_impl.cu.h
@@ -244,34 +244,42 @@ void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                                 const std::string dropout_implementation,
                                 float dropout_prob, const Tensor& grad_y,
                                 const Tensor& mask, int64_t size,
-                                Tensor* grad_x) {
-  auto M = EigenVector<uint8_t>::Flatten(mask);
+                                Tensor* grad_x, bool is_test = false) {
   auto dX = EigenVector<T>::Flatten(*grad_x);
   auto dY = EigenVector<T>::Flatten(grad_y);
 
   auto& place = *dev_ctx.eigen_device();
-  if (dropout_implementation == "upscale_in_train") {
-    if (dropout_prob == 1.0f) {
-      dX.device(place) = static_cast<T>(0) * dY;
+  if (is_test) {
+    if (dropout_implementation == "upscale_in_train") {
+      dX.device(place) = static_cast<T>(1) * dY;
     } else {
-      int vec_size = platform::GetVectorizedSize<T>(grad_y.data<T>());
-      if (vec_size == 4 && size % 4 == 0) {
-        auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
-        auto stream = dev_ctx.stream();
-        platform::GpuLaunchConfig config =
-            platform::GetGpuLaunchConfig1D(dev_ctx, size);
-        DropoutGradCUDAKernel<
-            T, uint8_t,
-            4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
-            grad_y.data<T>(), mask.data<uint8_t>(), factor, size,
-            grad_x->data<T>());
+      dX.device(place) = dY * static_cast<T>(1.0f - dropout_prob);
+    }
+  } else {
+    auto M = EigenVector<uint8_t>::Flatten(mask);
+    if (dropout_implementation == "upscale_in_train") {
+      if (dropout_prob == 1.0f) {
+        dX.device(place) = static_cast<T>(0) * dY;
       } else {
-        dX.device(place) =
-            dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
+        int vec_size = platform::GetVectorizedSize<T>(grad_y.data<T>());
+        if (vec_size == 4 && size % 4 == 0) {
+          auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
+          auto stream = dev_ctx.stream();
+          platform::GpuLaunchConfig config =
+              platform::GetGpuLaunchConfig1D(dev_ctx, size);
+          DropoutGradCUDAKernel<
+              T, uint8_t,
+              4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+              grad_y.data<T>(), mask.data<uint8_t>(), factor, size,
+              grad_x->data<T>());
+        } else {
+          dX.device(place) =
+              dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
+        }
       }
+    } else {
+      dX.device(place) = dY * M.cast<T>();
     }
-  } else {
-    dX.device(place) = dY * M.cast<T>();
   }
 }
 
diff --git a/paddle/fluid/operators/dropout_impl_util.h b/paddle/fluid/operators/dropout_impl_util.h
index e11640d070625e83bcac61c6e174afd3578143a2..f2038d12528c49669e2a12dd13853fd56a17ca49 100644
--- a/paddle/fluid/operators/dropout_impl_util.h
+++ b/paddle/fluid/operators/dropout_impl_util.h
@@ -34,9 +34,6 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
     TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
     *seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
     *increment = offset;
-  } else if (seed && platform::is_cpu_place(seed->place())) {
-    *seed_data = *(seed->data<int>());
-    *increment = offset;
   } else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) {
     auto seed_offset = gen_cuda->IncrementOffset(offset);
     *seed_data = seed_offset.first;
diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu
index 447184b948315ba7fb503cb8a13324c3eee2852c..0d5ee41c5c3a262e9ee606f0a7dadab499cb0bf2 100644
--- a/paddle/fluid/operators/dropout_op.cu
+++ b/paddle/fluid/operators/dropout_op.cu
@@ -58,10 +58,6 @@ template <typename Place, typename T>
 class GPUDropoutGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    PADDLE_ENFORCE_EQ(!context.Attr<bool>("is_test"), true,
-                      platform::errors::PreconditionNotMet(
-                          "GradOp is only callable when is_test is false"));
-
     auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
     auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* mask = context.Input<Tensor>("Mask");
@@ -71,10 +67,12 @@ class GPUDropoutGradKernel : public framework::OpKernel<T> {
         context.Attr<std::string>("dropout_implementation");
     float dropout_prob = context.Attr<float>("dropout_prob");
 
+    bool is_test = context.Attr<bool>("is_test");
+
     auto& dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
     DropoutGradGPUKernelDriver<T>(dev_ctx, dropout_implementation, dropout_prob,
-                                  *grad_y, *mask, size, grad_x, is_test);
+                                  *grad_y, *mask, size, grad_x, is_test);
   }
 };
 