diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h
index 94db4c62e391229416ea6b3763177c56b65f4252..83ca9ace20d0541577cab52befa9d359c3f89d21 100644
--- a/paddle/fluid/operators/dropout_impl.cu.h
+++ b/paddle/fluid/operators/dropout_impl.cu.h
@@ -37,8 +37,12 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
 #include "paddle/phi/kernels/funcs/functors.h"
+
+DECLARE_bool(use_curand);
+
 namespace paddle {
 namespace operators {
+
 template <typename T1, typename T2 = T1, typename OutT = T1>
 struct DstMaskGenerator {
   const float dropout_prob_;
@@ -71,13 +75,45 @@ struct DstMaskGenerator {
   }
 };
 
+template <typename T1, typename T2 = T1, typename OutT = T1>
+struct DstMaskFunctor {
+  const float retain_prob_;
+  const bool is_upscale_in_train_;
+  using MT = typename details::MPTypeTrait<T1>::Type;
+  MT factor;
+  HOSTDEVICE inline DstMaskFunctor(const float retain_prob,
+                                   const bool is_upscale_in_train)
+      : retain_prob_(retain_prob), is_upscale_in_train_(is_upscale_in_train) {
+    factor = static_cast<MT>(1.0f / retain_prob_);
+  }
+
+  HOSTDEVICE inline void operator()(OutT* dst, const T1* src_val,
+                                    const T2* rand, int num) const {
+    static constexpr int kCount =
+        phi::funcs::uniform_distribution<T2>::kReturnsCount;
+// 0 ~ kCount - 1 is dst, kCount ~ 2 * kCount - 1 is mask
+#pragma unroll
+    for (int i = 0; i < kCount; i++) {
+      if (rand[i] < retain_prob_) {
+        dst[i] = is_upscale_in_train_
+                     ? static_cast<T1>(static_cast<MT>(src_val[i]) * factor)
+                     : static_cast<T1>(src_val[i]);
+        dst[i + kCount] = static_cast<T1>(1);
+      } else {
+        dst[i] = static_cast<T1>(0);
+        dst[i + kCount] = dst[i];
+      }
+    }
+  }
+};
+
 template <typename T, typename MaskType>
 __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                           const float dropout_prob,
                                           const T* src, MaskType* mask, T* dst,
                                           bool is_upscale_in_train,
                                           uint64_t increment,
-                                          size_t main_offset) {
+                                          size_t main_offset, bool use_curand) {
   size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
   static constexpr int kCount =
       phi::funcs::uniform_distribution<float>::kReturnsCount;
@@ -97,37 +133,78 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
   using Rand = phi::funcs::uniform_distribution<float>;
   using Cast = kps::IdentityFunctor<T>;
   int deal_size = BLOCK_NUM_X * kCount;
-  auto dst_functor =
-      DstMaskGenerator<T, float>(dropout_prob, is_upscale_in_train);
+
   size_t fix = idx * kCount;
-  for (; fix < main_offset; fix += stride) {
-    kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
-    kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
-                                                          &state);
-    // dst
-    kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
-        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
-    kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0], deal_size);
-    // mask
-    kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
-        &mask_result[0], &dst_mask[kCount], Cast());
-    kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
-                                                  deal_size);
-  }
-  int remainder = n - fix;
-  if (remainder > 0) {
-    kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
-    kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
-                                                          &state);
-    // dst
-    kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
-        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
-    kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
-    // mask
-    kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
-        &mask_result[0], &dst_mask[kCount], Cast());
-    kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
-                                                 remainder);
+  if (use_curand) {
+    auto dst_functor =
+        DstMaskFunctor<T, float>(1.0f - dropout_prob, is_upscale_in_train);
+    for (; fix < main_offset; fix += stride) {
+      kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix,
+                                            deal_size);
+      kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
+                                                            &state);
+      // dst
+      kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
+          &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
+      kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0],
+                                             deal_size);
+      // mask
+      kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
+          &mask_result[0], &dst_mask[kCount], Cast());
+      kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
+                                                    deal_size);
+      if (fix > idx * kCount + 1) {
+        __syncthreads();
+      }
+    }
+    int remainder = n - fix;
+    if (remainder > 0) {
+      kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
+      kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
+                                                            &state);
+      // dst
+      kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
+          &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
+      kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
+      // mask
+      kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
+          &mask_result[0], &dst_mask[kCount], Cast());
+      kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
+                                                   remainder);
+      __syncthreads();
+    }
+  } else {
+    auto dst_functor =
+        DstMaskGenerator<T, float>(dropout_prob, is_upscale_in_train);
+    for (; fix < main_offset; fix += stride) {
+      kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix,
+                                            deal_size);
+      kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
+                                                            &state);
+      // dst
+      kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
+          &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
+      kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0],
+                                             deal_size);
+      // mask
+      kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
+          &mask_result[0], &dst_mask[kCount], Cast());
+      kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
+                                                    deal_size);
+    }
+    int remainder = n - fix;
+    if (remainder > 0) {
+      kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
+      kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
+                                                            &state);
+      // dst
+      kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
+          &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
+      kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
+      // mask
+      kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
+          &mask_result[0], &dst_mask[kCount], Cast());
+      kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
+                                                   remainder);
+    }
   }
 }
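A note on the kernel rewrite above, since the two functors differ in more than naming: DstMaskFunctor is constructed with the retain probability (1.0f - dropout_prob) and tests rand[i] < retain_prob_, so the keep branch comes first, whereas the old DstMaskGenerator compared against dropout_prob_ and dropped first. Below is a minimal host-side sketch of the arithmetic one functor call performs on a 4-element chunk; it is illustrative only (a plain C++ stand-in with names invented here, not code from the patch):

    // Stand-in for DstMaskFunctor::operator(): dst[0..kCount) receives the
    // dropout output, dst[kCount..2*kCount) the mask, as in dst_mask[].
    #include <cstdio>

    template <typename T>
    void DstMask(T* dst, const T* src, const float* rand, float retain_prob,
                 bool upscale_in_train, int kCount = 4) {
      const float factor = 1.0f / retain_prob;  // matches the ctor above
      for (int i = 0; i < kCount; i++) {
        if (rand[i] < retain_prob) {  // keep the element
          dst[i] = upscale_in_train ? static_cast<T>(src[i] * factor) : src[i];
          dst[i + kCount] = static_cast<T>(1);
        } else {  // drop the element; the mask entry is the same zero
          dst[i] = static_cast<T>(0);
          dst[i + kCount] = dst[i];
        }
      }
    }

    int main() {
      float src[4] = {1.f, 2.f, 3.f, 4.f};
      float rand[4] = {0.1f, 0.9f, 0.5f, 0.7f};  // pretend curand_uniform4 output
      float dst[8];  // kCount * 2, like dst_mask in the kernel
      DstMask(dst, src, rand, /*retain_prob=*/0.75f, /*upscale_in_train=*/true);
      for (float v : dst) std::printf("%g ", v);
      std::printf("\n");  // prints: 1.33333 0 4 5.33333 1 0 1 1
    }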
@@ -164,31 +241,34 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
     return;
   }
 
-  // increment is used to set the args(offset) of curand_init, which defines
-  // offset in subsequence.
-  // The detail:
-  // https://docs.nvidia.com/cuda/curand/device-api-overview.html
-  // Increment should be at least the number of curand() random numbers used
-  // in each thread to avoid the random number generated this time being the
-  // same as the previous calls.
   uint64_t seed_data;
   uint64_t increment;
-  // VectorizedRandomGenerator use curand_uniform4, so we only support
-  // kVecSize is 4;
+  // VectorizedRandomGenerator uses curand_uniform4, so kVecSize must be 4.
   constexpr int kVecSize =
       phi::funcs::uniform_distribution<float>::kReturnsCount;
   auto gpu_config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, kVecSize);
+  size_t grid_size = gpu_config.GetGridSize();
+  size_t block_size = gpu_config.GetBlockSize();
+
+  if (FLAGS_use_curand) {
+    int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
+    const auto& prop = platform::GetDeviceProperties(device_id);
+    size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
+                           prop.multiProcessorCount / block_size;
+    grid_size = std::min(grid_size, max_grid_size);
+  }
+
   auto offset =
-      ((x_numel - 1) / (gpu_config.GetThreadNum() * kVecSize) + 1) * kVecSize;
+      ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
   GetSeedDataAndIncrement(dev_ctx, seed, is_fix_seed, seed_val, offset,
                           &seed_data, &increment);
-  size_t main_offset = size / (gpu_config.GetBlockSize() * kVecSize) *
-                       (gpu_config.GetBlockSize() * kVecSize);
-  VectorizedRandomGenerator<T, uint8_t><<<
-      gpu_config.GetGridSize(), gpu_config.GetBlockSize(), 0, stream>>>(
+  size_t main_offset =
+      size / (block_size * kVecSize) * (block_size * kVecSize);
+
+  VectorizedRandomGenerator<T, uint8_t><<<grid_size, block_size, 0, stream>>>(
       size, seed_data, dropout_prob, x_data, mask_data, y_data,
-      upscale_in_train, increment, main_offset);
+      upscale_in_train, increment, main_offset, FLAGS_use_curand);
   } else {
     if (upscale_in_train) {
       // todo: can y share data with x directly?
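The driver changes above exist to make per-thread random-number consumption reproducible: under FLAGS_use_curand the grid is capped at one resident wave of the device (maxThreadsPerMultiProcessor * multiProcessorCount / block_size), so the grid-stride loop always runs with the same shape on a given GPU, and offset rounds the worst-case number of draws per thread up to a multiple of kVecSize so each launch advances the Philox generator past everything the previous launch may have consumed. A small sketch of that arithmetic, under assumed values (256-thread block, V100-like 80 SMs with 2048 threads each; not taken from the patch):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t x_numel = 32LL * 1024 * 1024;  // e.g. the [32, 1024, 1024] test tensor
      const size_t block_size = 256;               // assumed launch-config result
      const int kVecSize = 4;                      // curand_uniform4 yields 4 values

      // Naive grid: one thread per kVecSize elements.
      size_t grid_size =
          (x_numel + block_size * kVecSize - 1) / (block_size * kVecSize);

      // Cap at one resident wave, mirroring the FLAGS_use_curand branch.
      const size_t max_threads_per_sm = 2048, sm_count = 80;  // V100-like
      grid_size = std::min(grid_size, max_threads_per_sm * sm_count / block_size);

      // Worst-case values drawn per thread, rounded up to a multiple of 4;
      // GetSeedDataAndIncrement turns this into the next generator increment.
      const uint64_t offset =
          ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
      std::printf("grid=%zu offset=%llu\n", grid_size,
                  (unsigned long long)offset);  // grid=640 offset=208
    }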
+ if not "V100" in paddle.device.cuda.get_device_name(): + return + + if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): + return + + print("Test Fixed Random number on V100 GPU------>") + paddle.disable_static() + paddle.set_device('gpu') + paddle.seed(100) + + x = paddle.rand([32, 1024, 1024], dtype='float32') + out = paddle.nn.functional.dropout(x, 0.25).numpy() + index0, index1, index2 = np.nonzero(out) + self.assertEqual(np.sum(index0), 390094540) + self.assertEqual(np.sum(index1), 12871475125) + self.assertEqual(np.sum(index2), 12872777397) + self.assertEqual(np.sum(out), 16778744.0) + expect = [ + 0.6914956, 0.5294584, 0.19032137, 0.6996228, 0.3338527, 0.8442094, + 0.96965003, 1.1726775, 0., 0.28037727 + ] + self.assertTrue(np.allclose(out[10, 100, 500:510], expect)) + + x = paddle.rand([32, 1024, 1024], dtype='float64') + out = paddle.nn.functional.dropout(x).numpy() + index0, index1, index2 = np.nonzero(out) + self.assertEqual(np.sum(index0), 260065137) + self.assertEqual(np.sum(index1), 8582636095) + self.assertEqual(np.sum(index2), 8582219962) + self.assertEqual(np.sum(out), 16778396.563660286) + expect = [ + 1.28587354, 0.15563703, 0., 0.28799703, 0., 0., 0., 0.54964, + 0.51355682, 0.33818988 + ] + self.assertTrue(np.allclose(out[20, 100, 500:510], expect)) + + x = paddle.ones([32, 1024, 1024], dtype='float16') + out = paddle.nn.functional.dropout(x, 0.75).numpy() + index0, index1, index2 = np.nonzero(out) + self.assertEqual(np.sum(index0), 130086900) + self.assertEqual(np.sum(index1), 4291190105) + self.assertEqual(np.sum(index2), 4292243807) + expect = [0., 0., 0., 0., 0., 0., 0., 0., 4., 4.] + self.assertTrue(np.allclose(out[0, 100, 500:510], expect)) + + paddle.enable_static() + + if __name__ == '__main__': paddle.enable_static() unittest.main()