Unverified · Commit f1275fb6, authored by smallv0221, committed by GitHub

Support dropout backward in eval mode (#35122)

* Support dropout backward in eval mode

* add downscale case

* minor fix

* minor fix
Parent commit: e7df47ec
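
For orientation before the diff: a minimal standalone sketch of the eval-mode gradient this change introduces (hypothetical helper names, plain C++, not the PaddlePaddle API). With upscale_in_train, the eval-mode forward pass is the identity, so dX = dY; with downscale_in_infer, the forward pass scales activations by (1 - p), so dX = dY * (1 - p). Neither case needs the dropout mask, which is what lets the patch below drop the is_test checks.

// Sketch of eval-mode dropout backward; hypothetical helper, not the
// PaddlePaddle kernel. Assumes float data and the operator's two
// dropout_implementation modes.
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

std::vector<float> DropoutGradEval(const std::vector<float>& grad_out,
                                   float dropout_prob,
                                   const std::string& implementation) {
  std::vector<float> grad_x(grad_out.size());
  for (std::size_t i = 0; i < grad_out.size(); ++i) {
    if (implementation == "upscale_in_train") {
      // Eval-mode forward is the identity, so the gradient passes through.
      grad_x[i] = grad_out[i];
    } else {  // "downscale_in_infer"
      // Eval-mode forward scales by (1 - p), so the gradient does too.
      grad_x[i] = grad_out[i] * (1.0f - dropout_prob);
    }
  }
  return grad_x;
}

int main() {
  std::vector<float> dy = {1.0f, 2.0f, 4.0f};
  auto dx = DropoutGradEval(dy, 0.5f, "downscale_in_infer");
  assert(dx[2] == 2.0f);  // 4.0 * (1 - 0.5)
  return 0;
}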
@@ -117,10 +117,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_test"), false,
-                      platform::errors::InvalidArgument(
-                          "GradOp is only callable when is_test is false"));
-
     OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "DropoutGrad");
     OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                    framework::GradVarName("Out"), "DropoutGrad");
@@ -160,17 +160,12 @@ template <typename DeviceContext, typename T>
 class DropoutGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    PADDLE_ENFORCE_EQ(!context.Attr<bool>("is_test"), true,
-                      platform::errors::PreconditionNotMet(
-                          "GradOp is only callable when is_test is false"));
-
     auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
     auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* mask = context.Input<Tensor>("Mask");
     grad_x->mutable_data<T>(context.GetPlace());
     auto size = grad_x->numel();
 
-    auto M = EigenVector<uint8_t>::Flatten(*mask);
     auto dX = EigenVector<T>::Flatten(*grad_x);
     auto dY = EigenVector<T>::Flatten(*grad_y);
@@ -178,6 +173,15 @@ class DropoutGradKernel : public framework::OpKernel<T> {
         *context.template device_context<DeviceContext>().eigen_device();
     auto& dropout_implementation =
         context.Attr<std::string>("dropout_implementation");
+    if (context.Attr<bool>("is_test") == true) {
+      if (dropout_implementation == "upscale_in_train") {
+        dX.device(place) = static_cast<T>(1) * dY;
+      } else {
+        float dropout_prob = context.Attr<float>("dropout_prob");
+        dX.device(place) = dY * static_cast<T>(1.0f - dropout_prob);
+      }
+    } else {
+      auto M = EigenVector<uint8_t>::Flatten(*mask);
     if (dropout_implementation == "upscale_in_train") {
       float dropout_prob = context.Attr<float>("dropout_prob");
       if (dropout_prob == 1.0f) {
@@ -191,9 +195,8 @@ class DropoutGradKernel : public framework::OpKernel<T> {
         auto stream = context.cuda_device_context().stream();
         platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
             context.cuda_device_context(), size);
-        DropoutGradCUDAKernel<
-            T, uint8_t,
-            4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+        DropoutGradCUDAKernel<T, uint8_t, 4><<<
+            config.block_per_grid, config.thread_per_block, 0, stream>>>(
             grad_y->data<T>(), mask->data<uint8_t>(), factor, size,
             grad_x->data<T>());
 #endif
@@ -206,6 +209,7 @@ class DropoutGradKernel : public framework::OpKernel<T> {
       dX.device(place) = dY * M.cast<T>();
     }
   }
+  }
 };
 
 }  // namespace operators
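
For contrast with the eval-mode branch added above, a hedged sketch of the mask-based training-mode gradient that the surviving else branch computes (hypothetical helper, not the PaddlePaddle kernel): upscale_in_train uses dX = dY * mask / (1 - p), with an all-zero gradient when p == 1 to avoid division by zero, while downscale_in_infer uses dX = dY * mask.

// Sketch of training-mode dropout backward with a saved 0/1 mask;
// hypothetical helper, not the PaddlePaddle kernel.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

std::vector<float> DropoutGradTrain(const std::vector<float>& grad_out,
                                    const std::vector<std::uint8_t>& mask,
                                    float dropout_prob,
                                    const std::string& implementation) {
  std::vector<float> grad_x(grad_out.size(), 0.0f);
  if (implementation == "upscale_in_train") {
    if (dropout_prob == 1.0f) return grad_x;  // everything was dropped
    const float factor = 1.0f / (1.0f - dropout_prob);
    for (std::size_t i = 0; i < grad_out.size(); ++i)
      grad_x[i] = grad_out[i] * static_cast<float>(mask[i]) * factor;
  } else {  // "downscale_in_infer": dX = dY * mask
    for (std::size_t i = 0; i < grad_out.size(); ++i)
      grad_x[i] = grad_out[i] * static_cast<float>(mask[i]);
  }
  return grad_x;
}

int main() {
  std::vector<float> dy = {1.0f, 2.0f};
  std::vector<std::uint8_t> mask = {1, 0};  // second element was dropped
  auto dx = DropoutGradTrain(dy, mask, 0.5f, "upscale_in_train");
  assert(dx[0] == 2.0f && dx[1] == 0.0f);  // 1.0 * 1 / (1 - 0.5); 2.0 * 0
  return 0;
}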