diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc
index 6ee19c939f2613b6152a90284681422929b2727e..37e622e768d37afcc98b6404761a3508d6cbba58 100644
--- a/paddle/fluid/operators/dropout_op.cc
+++ b/paddle/fluid/operators/dropout_op.cc
@@ -117,10 +117,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_test"), false,
-                      platform::errors::InvalidArgument(
-                          "GradOp is only callable when is_test is false"));
-
     OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "DropoutGrad");
     OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                    framework::GradVarName("Out"), "DropoutGrad");
diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h
index 69c420e2c93ed37bf0733ae3c67de1242c04701a..997a7d835aa37b40c619b9521c8ca50a474423ec 100644
--- a/paddle/fluid/operators/dropout_op.h
+++ b/paddle/fluid/operators/dropout_op.h
@@ -160,17 +160,12 @@ template <typename DeviceContext, typename T>
 class DropoutGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    PADDLE_ENFORCE_EQ(!context.Attr<bool>("is_test"), true,
-                      platform::errors::PreconditionNotMet(
-                          "GradOp is only callable when is_test is false"));
-
     auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
     auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* mask = context.Input<Tensor>("Mask");
     grad_x->mutable_data<T>(context.GetPlace());
     auto size = grad_x->numel();
 
-    auto M = EigenVector<uint8_t>::Flatten(*mask);
     auto dX = EigenVector<T>::Flatten(*grad_x);
     auto dY = EigenVector<T>::Flatten(*grad_y);
 
@@ -178,32 +173,41 @@ class DropoutGradKernel : public framework::OpKernel<T> {
         *context.template device_context<DeviceContext>().eigen_device();
     auto& dropout_implementation =
         context.Attr<std::string>("dropout_implementation");
-    if (dropout_implementation == "upscale_in_train") {
-      float dropout_prob = context.Attr<float>("dropout_prob");
-      if (dropout_prob == 1.0f) {
-        dX.device(place) = static_cast<T>(0) * dY;
+    if (context.Attr<bool>("is_test") == true) {
+      if (dropout_implementation == "upscale_in_train") {
+        dX.device(place) = static_cast<T>(1) * dY;
       } else {
-        int vec_size = VectorizedSize<T>(grad_y->data<T>());
-        if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 &&
-            size % 4 == 0) {
+        float dropout_prob = context.Attr<float>("dropout_prob");
+        dX.device(place) = dY * static_cast<T>(1.0f - dropout_prob);
+      }
+    } else {
+      auto M = EigenVector<uint8_t>::Flatten(*mask);
+      if (dropout_implementation == "upscale_in_train") {
+        float dropout_prob = context.Attr<float>("dropout_prob");
+        if (dropout_prob == 1.0f) {
+          dX.device(place) = static_cast<T>(0) * dY;
+        } else {
+          int vec_size = VectorizedSize<T>(grad_y->data<T>());
+          if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 &&
+              size % 4 == 0) {
 #if defined(__NVCC__) || defined(__HIPCC__)
-          auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
-          auto stream = context.cuda_device_context().stream();
-          platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
-              context.cuda_device_context(), size);
-          DropoutGradCUDAKernel<
-              T, uint8_t,
-              4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
-              grad_y->data<T>(), mask->data<uint8_t>(), factor, size,
-              grad_x->data<T>());
+            auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
+            auto stream = context.cuda_device_context().stream();
+            platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
+                context.cuda_device_context(), size);
+            DropoutGradCUDAKernel<T, uint8_t, 4><<<
+                config.block_per_grid, config.thread_per_block, 0, stream>>>(
+                grad_y->data<T>(), mask->data<uint8_t>(), factor, size,
+                grad_x->data<T>());
 #endif
-        } else {
-          dX.device(place) =
-              dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
+          } else {
+            dX.device(place) =
+                dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
+          }
         }
+      } else {
+        dX.device(place) = dY * M.cast<T>();
       }
-    } else {
-      dX.device(place) = dY * M.cast<T>();
-    }
+    }
   }
 };
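
For review context, here is a minimal standalone sketch of the gradient rule this patch implements for `is_test == true` versus the existing masked path. This is not Paddle's API; the helper names (`dropout_grad_test_mode`, `dropout_grad_train_mode`) and the plain-vector signatures are hypothetical, chosen only to make the branch logic in `DropoutGradKernel` easy to check by hand:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Test-mode backward: no mask was saved in the forward pass, so dX is a
// deterministic rescaling of dY.
//   upscale_in_train:   inference forward is the identity, so dX = dY.
//   downgrade_in_infer: inference forward scales by (1 - p), so dX = dY * (1 - p).
std::vector<float> dropout_grad_test_mode(const std::vector<float>& d_out,
                                          float dropout_prob,
                                          const std::string& implementation) {
  const float factor =
      implementation == "upscale_in_train" ? 1.0f : 1.0f - dropout_prob;
  std::vector<float> d_x(d_out.size());
  for (std::size_t i = 0; i < d_out.size(); ++i) d_x[i] = d_out[i] * factor;
  return d_x;
}

// Train-mode backward: the saved mask gates the gradient, and
// upscale_in_train additionally divides by the keep probability (1 - p),
// with the p == 1 case short-circuited to an all-zero gradient.
std::vector<float> dropout_grad_train_mode(const std::vector<float>& d_out,
                                           const std::vector<uint8_t>& mask,
                                           float dropout_prob,
                                           const std::string& implementation) {
  assert(d_out.size() == mask.size());
  std::vector<float> d_x(d_out.size());
  for (std::size_t i = 0; i < d_out.size(); ++i) {
    float g = d_out[i] * static_cast<float>(mask[i]);
    if (implementation == "upscale_in_train")
      g = dropout_prob == 1.0f ? 0.0f : g / (1.0f - dropout_prob);
    d_x[i] = g;
  }
  return d_x;
}

int main() {
  std::vector<float> d_out{1.0f, 2.0f, 3.0f, 4.0f};
  // is_test == true: mask is unavailable, gradient is a pure rescale.
  auto dx_eval = dropout_grad_test_mode(d_out, 0.5f, "downgrade_in_infer");
  // is_test == false: gradient flows only through kept elements.
  auto dx_train =
      dropout_grad_train_mode(d_out, {1, 0, 1, 0}, 0.5f, "upscale_in_train");
  for (float v : dx_eval) std::cout << v << ' ';   // 0.5 1 1.5 2
  std::cout << '\n';
  for (float v : dx_train) std::cout << v << ' ';  // 2 0 6 0
  std::cout << '\n';
}
```

This also shows why the removed `PADDLE_ENFORCE_EQ` checks are no longer needed: with `is_test == true` the backward pass never touches `Mask`, so the gradient is well defined in eval mode and the precondition can be dropped rather than enforced.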