未验证 提交 604b6fc0 编写于 作者: L Li Min 提交者: GitHub

fix bug to support dropout eval grad computing. (#37305) (#37331)

fix bug to support dropout eval grad computing. cherry-pick #37305.
上级 44db219a
...@@ -244,12 +244,19 @@ void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx, ...@@ -244,12 +244,19 @@ void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
const std::string dropout_implementation, const std::string dropout_implementation,
float dropout_prob, const Tensor& grad_y, float dropout_prob, const Tensor& grad_y,
const Tensor& mask, int64_t size, const Tensor& mask, int64_t size,
Tensor* grad_x) { Tensor* grad_x, bool is_test = false) {
auto M = EigenVector<uint8_t>::Flatten(mask);
auto dX = EigenVector<T>::Flatten(*grad_x); auto dX = EigenVector<T>::Flatten(*grad_x);
auto dY = EigenVector<T>::Flatten(grad_y); auto dY = EigenVector<T>::Flatten(grad_y);
auto& place = *dev_ctx.eigen_device(); auto& place = *dev_ctx.eigen_device();
if (is_test) {
if (dropout_implementation == "upscale_in_train") {
dX.device(place) = static_cast<T>(1) * dY;
} else {
dX.device(place) = dY * static_cast<T>(1.0f - dropout_prob);
}
} else {
auto M = EigenVector<uint8_t>::Flatten(mask);
if (dropout_implementation == "upscale_in_train") { if (dropout_implementation == "upscale_in_train") {
if (dropout_prob == 1.0f) { if (dropout_prob == 1.0f) {
dX.device(place) = static_cast<T>(0) * dY; dX.device(place) = static_cast<T>(0) * dY;
...@@ -273,6 +280,7 @@ void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx, ...@@ -273,6 +280,7 @@ void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
} else { } else {
dX.device(place) = dY * M.cast<T>(); dX.device(place) = dY * M.cast<T>();
} }
}
} }
} // namespace operators } // namespace operators
......
...@@ -34,9 +34,6 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx, ...@@ -34,9 +34,6 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor); TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
*seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]); *seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
*increment = offset; *increment = offset;
} else if (seed && platform::is_cpu_place(seed->place())) {
*seed_data = *(seed->data<int>());
*increment = offset;
} else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) { } else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) {
auto seed_offset = gen_cuda->IncrementOffset(offset); auto seed_offset = gen_cuda->IncrementOffset(offset);
*seed_data = seed_offset.first; *seed_data = seed_offset.first;
......
...@@ -58,10 +58,6 @@ template <typename DeviceContext, typename T> ...@@ -58,10 +58,6 @@ template <typename DeviceContext, typename T>
class GPUDropoutGradKernel : public framework::OpKernel<T> { class GPUDropoutGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(!context.Attr<bool>("is_test"), true,
platform::errors::PreconditionNotMet(
"GradOp is only callable when is_test is false"));
auto* grad_x = context.Output<Tensor>(framework::GradVarName("X")); auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out")); auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
auto* mask = context.Input<Tensor>("Mask"); auto* mask = context.Input<Tensor>("Mask");
...@@ -71,10 +67,12 @@ class GPUDropoutGradKernel : public framework::OpKernel<T> { ...@@ -71,10 +67,12 @@ class GPUDropoutGradKernel : public framework::OpKernel<T> {
context.Attr<std::string>("dropout_implementation"); context.Attr<std::string>("dropout_implementation");
float dropout_prob = context.Attr<float>("dropout_prob"); float dropout_prob = context.Attr<float>("dropout_prob");
bool is_test = context.Attr<bool>("is_test");
auto& dev_ctx = auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>(); context.template device_context<platform::CUDADeviceContext>();
DropoutGradGPUKernelDriver<T>(dev_ctx, dropout_implementation, dropout_prob, DropoutGradGPUKernelDriver<T>(dev_ctx, dropout_implementation, dropout_prob,
*grad_y, *mask, size, grad_x); *grad_y, *mask, size, grad_x, is_test);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册