未验证 提交 f1275fb6 编写于 作者: S smallv0221 提交者: GitHub

Support dropout backward in eval mode (#35122)

* Support dropout backward in eval mode

* add downscale case

* minor fix

* minor fix
上级 e7df47ec
...@@ -117,10 +117,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel { ...@@ -117,10 +117,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_test"), false,
platform::errors::InvalidArgument(
"GradOp is only callable when is_test is false"));
OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "DropoutGrad"); OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "DropoutGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "DropoutGrad"); framework::GradVarName("Out"), "DropoutGrad");
......
...@@ -160,17 +160,12 @@ template <typename DeviceContext, typename T> ...@@ -160,17 +160,12 @@ template <typename DeviceContext, typename T>
class DropoutGradKernel : public framework::OpKernel<T> { class DropoutGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(!context.Attr<bool>("is_test"), true,
platform::errors::PreconditionNotMet(
"GradOp is only callable when is_test is false"));
auto* grad_x = context.Output<Tensor>(framework::GradVarName("X")); auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out")); auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
auto* mask = context.Input<Tensor>("Mask"); auto* mask = context.Input<Tensor>("Mask");
grad_x->mutable_data<T>(context.GetPlace()); grad_x->mutable_data<T>(context.GetPlace());
auto size = grad_x->numel(); auto size = grad_x->numel();
auto M = EigenVector<uint8_t>::Flatten(*mask);
auto dX = EigenVector<T>::Flatten(*grad_x); auto dX = EigenVector<T>::Flatten(*grad_x);
auto dY = EigenVector<T>::Flatten(*grad_y); auto dY = EigenVector<T>::Flatten(*grad_y);
...@@ -178,32 +173,41 @@ class DropoutGradKernel : public framework::OpKernel<T> { ...@@ -178,32 +173,41 @@ class DropoutGradKernel : public framework::OpKernel<T> {
*context.template device_context<DeviceContext>().eigen_device(); *context.template device_context<DeviceContext>().eigen_device();
auto& dropout_implementation = auto& dropout_implementation =
context.Attr<std::string>("dropout_implementation"); context.Attr<std::string>("dropout_implementation");
if (dropout_implementation == "upscale_in_train") { if (context.Attr<bool>("is_test") == true) {
float dropout_prob = context.Attr<float>("dropout_prob"); if (dropout_implementation == "upscale_in_train") {
if (dropout_prob == 1.0f) { dX.device(place) = static_cast<T>(1) * dY;
dX.device(place) = static_cast<T>(0) * dY;
} else { } else {
int vec_size = VectorizedSize<T>(grad_y->data<T>()); float dropout_prob = context.Attr<float>("dropout_prob");
if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 && dX.device(place) = dY * static_cast<T>(1.0f - dropout_prob);
size % 4 == 0) { }
} else {
auto M = EigenVector<uint8_t>::Flatten(*mask);
if (dropout_implementation == "upscale_in_train") {
float dropout_prob = context.Attr<float>("dropout_prob");
if (dropout_prob == 1.0f) {
dX.device(place) = static_cast<T>(0) * dY;
} else {
int vec_size = VectorizedSize<T>(grad_y->data<T>());
if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 &&
size % 4 == 0) {
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob)); auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D( platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
context.cuda_device_context(), size); context.cuda_device_context(), size);
DropoutGradCUDAKernel< DropoutGradCUDAKernel<T, uint8_t, 4><<<
T, uint8_t, config.block_per_grid, config.thread_per_block, 0, stream>>>(
4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>( grad_y->data<T>(), mask->data<uint8_t>(), factor, size,
grad_y->data<T>(), mask->data<uint8_t>(), factor, size, grad_x->data<T>());
grad_x->data<T>());
#endif #endif
} else { } else {
dX.device(place) = dX.device(place) =
dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob); dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
}
} }
} else {
dX.device(place) = dY * M.cast<T>();
} }
} else {
dX.device(place) = dY * M.cast<T>();
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册