未验证 提交 b81358d1 编写于 作者: S sneaxiy 提交者: GitHub

add dropout fp32 (#39501)

上级 8cedcd3e
......@@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/dropout_impl_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
......@@ -45,6 +46,7 @@ __global__ void RandomGenerator(const size_t n, uint64_t seed,
const float dropout_prob, const T* src,
MaskType* mask, T* dst,
bool is_upscale_in_train, uint64_t increment) {
using MT = typename details::MPTypeTrait<T>::Type;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
#ifdef PADDLE_WITH_HIP
hiprandStatePhilox4_32_10_t state;
......@@ -56,7 +58,7 @@ __global__ void RandomGenerator(const size_t n, uint64_t seed,
MaskType mask_val;
T dst_val;
T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
for (; idx < n; idx += blockDim.x * gridDim.x) {
T src_val = src[idx];
#ifdef PADDLE_WITH_HIP
......@@ -68,7 +70,9 @@ __global__ void RandomGenerator(const size_t n, uint64_t seed,
dst_val = 0;
} else {
mask_val = 1;
dst_val = is_upscale_in_train ? src_val * factor : src_val;
dst_val = is_upscale_in_train
? static_cast<T>(static_cast<MT>(src_val) * factor)
: src_val;
}
mask[idx] = mask_val;
dst[idx] = dst_val;
......@@ -81,6 +85,7 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
const T* src, MaskType* mask, T* dst,
bool is_upscale_in_train,
uint64_t increment) {
using MT = typename details::MPTypeTrait<T>::Type;
using LoadT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
......@@ -94,7 +99,7 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
curand_init(seed, idx, increment, &state);
#endif
T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
LoadT src_val;
platform::Load<T, VecSize>(&src[i], &src_val);
......@@ -114,7 +119,9 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
dst_val[j] = 0;
mask_val[j] = 0;
} else {
dst_val[j] = is_upscale_in_train ? src_val[j] * factor : src_val[j];
dst_val[j] = is_upscale_in_train
? static_cast<T>(static_cast<MT>(src_val[j]) * factor)
: src_val[j];
mask_val[j] = 1;
}
}
......@@ -126,21 +133,26 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
template <typename T, typename MaskType>
struct CudaDropoutGradFunctor {
explicit CudaDropoutGradFunctor(const T factor) : factor_(factor) {}
using MT = typename details::MPTypeTrait<T>::Type;
explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}
__device__ __forceinline__ T operator()(const T dout,
const MaskType mask) const {
return dout * static_cast<T>(mask) * factor_;
return static_cast<T>(static_cast<MT>(dout) * static_cast<MT>(mask) *
factor_);
}
private:
T factor_;
MT factor_;
};
template <typename T, typename MaskType, int VecSize>
__global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
const T factor, const int64_t size,
T* dx) {
__global__ void DropoutGradCUDAKernel(
const T* dout, const MaskType* mask,
const typename details::MPTypeTrait<T>::Type factor, const int64_t size,
T* dx) {
using MT = typename details::MPTypeTrait<T>::Type;
using LoadT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
......@@ -156,7 +168,8 @@ __global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
#pragma unroll
for (int j = 0; j < VecSize; j++) {
dx_val[j] = dout_val[j] * static_cast<T>(mask_val[j]) * factor;
dx_val[j] = static_cast<T>(static_cast<MT>(dout_val[j]) *
static_cast<MT>(mask_val[j]) * factor);
}
platform::Store<T, VecSize>(dx_val, &dx[i]);
......@@ -257,6 +270,7 @@ void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
float dropout_prob, const Tensor& grad_y,
const Tensor& mask, int64_t size,
Tensor* grad_x, bool is_test = false) {
using MT = typename details::MPTypeTrait<T>::Type;
auto dX = EigenVector<T>::Flatten(*grad_x);
auto dY = EigenVector<T>::Flatten(grad_y);
......@@ -273,7 +287,7 @@ void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
if (dropout_prob == 1.0f) {
dX.device(place) = static_cast<T>(0) * dY;
} else {
auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
auto factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
auto stream = dev_ctx.stream();
std::vector<const framework::Tensor*> ins = {&grad_y, &mask};
std::vector<framework::Tensor*> outs = {grad_x};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册