未验证 提交 604b6fc0 编写于 作者: L Li Min 提交者: GitHub

fix bug to support dropout eval grad computing. (#37305) (#37331)

fix bug to support dropout eval grad computing. cherry-pick #37305.
上级 44db219a
......@@ -244,34 +244,42 @@ void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
const std::string dropout_implementation,
float dropout_prob, const Tensor& grad_y,
const Tensor& mask, int64_t size,
Tensor* grad_x) {
auto M = EigenVector<uint8_t>::Flatten(mask);
Tensor* grad_x, bool is_test = false) {
auto dX = EigenVector<T>::Flatten(*grad_x);
auto dY = EigenVector<T>::Flatten(grad_y);
auto& place = *dev_ctx.eigen_device();
if (dropout_implementation == "upscale_in_train") {
if (dropout_prob == 1.0f) {
dX.device(place) = static_cast<T>(0) * dY;
if (is_test) {
if (dropout_implementation == "upscale_in_train") {
dX.device(place) = static_cast<T>(1) * dY;
} else {
int vec_size = platform::GetVectorizedSize<T>(grad_y.data<T>());
if (vec_size == 4 && size % 4 == 0) {
auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
auto stream = dev_ctx.stream();
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(dev_ctx, size);
DropoutGradCUDAKernel<
T, uint8_t,
4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
grad_y.data<T>(), mask.data<uint8_t>(), factor, size,
grad_x->data<T>());
dX.device(place) = dY * static_cast<T>(1.0f - dropout_prob);
}
} else {
auto M = EigenVector<uint8_t>::Flatten(mask);
if (dropout_implementation == "upscale_in_train") {
if (dropout_prob == 1.0f) {
dX.device(place) = static_cast<T>(0) * dY;
} else {
dX.device(place) =
dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
int vec_size = platform::GetVectorizedSize<T>(grad_y.data<T>());
if (vec_size == 4 && size % 4 == 0) {
auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
auto stream = dev_ctx.stream();
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(dev_ctx, size);
DropoutGradCUDAKernel<
T, uint8_t,
4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
grad_y.data<T>(), mask.data<uint8_t>(), factor, size,
grad_x->data<T>());
} else {
dX.device(place) =
dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
}
}
} else {
dX.device(place) = dY * M.cast<T>();
}
} else {
dX.device(place) = dY * M.cast<T>();
}
}
......
......@@ -34,9 +34,6 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
*seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
*increment = offset;
} else if (seed && platform::is_cpu_place(seed->place())) {
*seed_data = *(seed->data<int>());
*increment = offset;
} else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) {
auto seed_offset = gen_cuda->IncrementOffset(offset);
*seed_data = seed_offset.first;
......
......@@ -58,10 +58,6 @@ template <typename DeviceContext, typename T>
class GPUDropoutGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(!context.Attr<bool>("is_test"), true,
platform::errors::PreconditionNotMet(
"GradOp is only callable when is_test is false"));
auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
auto* mask = context.Input<Tensor>("Mask");
......@@ -71,10 +67,12 @@ class GPUDropoutGradKernel : public framework::OpKernel<T> {
context.Attr<std::string>("dropout_implementation");
float dropout_prob = context.Attr<float>("dropout_prob");
bool is_test = context.Attr<bool>("is_test");
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
DropoutGradGPUKernelDriver<T>(dev_ctx, dropout_implementation, dropout_prob,
*grad_y, *mask, size, grad_x);
*grad_y, *mask, size, grad_x, is_test);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册