// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/fusion/fused_dropout_add_kernel.h"

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/dropout_impl_util.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"

#include "paddle/phi/kernels/funcs/dropout_impl.cu.h"

namespace phi {

// Training-time functor: computes dropout(x) + y elementwise without
// materializing a mask tensor. Each element is kept with probability
// retain_prob_ and, in "upscale_in_train" mode, rescaled by 1 / retain_prob_.
template <typename T1, typename T2 = T1, typename OutT = T1>
struct NoMaskFwFunctor {
  const float retain_prob_;
  const bool is_upscale_in_train_;
  using MT = typename phi::kps::details::MPTypeTrait<T1>::Type;
  MT factor;

  HOSTDEVICE inline NoMaskFwFunctor(const float retain_prob,
                                    const bool is_upscale_in_train)
      : retain_prob_(retain_prob), is_upscale_in_train_(is_upscale_in_train) {
    factor = static_cast<MT>(1.0f / retain_prob_);
  }

  HOSTDEVICE inline void operator()(OutT* dst,
                                    const T1* src_val,
                                    const T2* rand,
                                    int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<T2>::kReturnsCount;
    // src_val[0 .. kCount) holds x, src_val[kCount .. 2 * kCount) holds y.
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      if (rand[i] < retain_prob_) {
        dst[i] = is_upscale_in_train_
                     ? static_cast<T1>(static_cast<MT>(src_val[i]) * factor)
                     : static_cast<T1>(src_val[i]);
        dst[i] += src_val[i + kCount];
      } else {
        dst[i] = src_val[i + kCount];
      }
    }
  }
};
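// Worked example of NoMaskFwFunctor in "upscale_in_train" mode with p = 0.5
// (retain_prob_ = 0.5, factor = 2): an element with rand = 0.3 is kept and
// produces dst = x * 2 + y, while rand = 0.7 drops it and produces dst = y.
// At inference no mask is sampled at all; ScaleAddFuctor below computes
// x + y ("upscale_in_train") or x * (1 - p) + y ("downscale_in_infer").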
// Inference-time functor used on the is_test / eval path.
template <typename T>
struct ScaleAddFuctor {
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;

  explicit ScaleAddFuctor(const MT factor, bool upscale_in_train)
      : factor_(factor), upscale_in_train_(upscale_in_train) {}

  __device__ __forceinline__ T operator()(const T src, const T res) const {
    return upscale_in_train_
               ? src + res
               : static_cast<T>(static_cast<MT>(src) * factor_) + res;
  }

 private:
  MT factor_;
  bool upscale_in_train_;
};

template <typename T, typename Functor>
__global__ void VectorizedDropoutForward(const size_t n,
                                         uint64_t seed,
                                         const T* src,
                                         const T* res,
                                         T* dst,
                                         uint64_t increment,
                                         size_t main_offset,
                                         Functor functor) {
  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
  static constexpr int kCount =
      phi::funcs::uniform_distribution<float>::kReturnsCount;
  size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = curandStatePhilox4_32_10_t;
#endif
  T dst_res[kCount * 2];
  float rands[kCount];

  using Rand = phi::funcs::uniform_distribution<float>;

  int deal_size = BLOCK_NUM_X * kCount;
  size_t fix = idx * kCount;
  // Main loop: full tiles of kCount elements per thread.
  for (; fix < main_offset; fix += stride) {
    kps::ReadData<T, kCount, 1, false>(&dst_res[0], src + fix, deal_size);
    kps::ReadData<T, kCount, 1, false>(&dst_res[kCount], res + fix, deal_size);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorTernary<T, float, T, Functor>(
        &dst_res[0], &dst_res[0], &rands[0], functor, kCount);
    kps::WriteData<T, kCount, 1, false>(dst + fix, &dst_res[0], deal_size);
    if (fix > idx * kCount + 1) {
      __syncthreads();
    }
  }
  // Tail: handle the remaining elements with boundary checks.
  int remainder = n - fix;
  if (remainder > 0) {
    kps::ReadData<T, kCount, 1, true>(&dst_res[0], src + fix, remainder);
    kps::ReadData<T, kCount, 1, true>(&dst_res[kCount], res + fix, remainder);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorTernary<T, float, T, Functor>(
        &dst_res[0], &dst_res[0], &rands[0], functor, kCount);
    kps::WriteData<T, kCount, 1, true>(dst + fix, &dst_res[0], remainder);
    __syncthreads();
  }
}

template <typename T, typename Context>
void FusedDropoutAddKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& y,
                           const Scalar& p,
                           bool is_test,
                           const std::string& mode,
                           int seed,
                           bool fix_seed,
                           DenseTensor* out,
                           DenseTensor* seed_offset) {
  auto* out_data = dev_ctx.template Alloc<T>(out);
  auto* seed_offset_data = dev_ctx.template HostAlloc<int64_t>(seed_offset);
  int64_t numel = x.numel();
  auto stream = dev_ctx.stream();

  bool upscale_in_train = (mode == "upscale_in_train");
  const auto* x_data = x.data<T>();
  const auto* y_data = y.data<T>();
  float dropout_rate = p.to<float>();
  if (!is_test) {
    // With p == 1 every element of x is dropped, so the result is just y.
    if (dropout_rate == 1.0f) {
      phi::Copy(dev_ctx, y, dev_ctx.GetPlace(), false, out);
      return;
    }
    uint64_t seed_data;
    uint64_t increment;

    auto random_prop = GetRandomCudaProp(numel, dev_ctx);
    size_t grid_size = random_prop[0];
    size_t block_size = random_prop[1];
    size_t offset = random_prop[2];
    size_t main_offset = random_prop[3];
    funcs::GetSeedDataAndIncrement(
        dev_ctx, nullptr, fix_seed, seed, offset, &seed_data, &increment);
    seed_offset_data[0] = static_cast<int64_t>(seed_data);
    seed_offset_data[1] = static_cast<int64_t>(increment);

    auto dst_functor =
        NoMaskFwFunctor<T, float>(1.0f - dropout_rate, upscale_in_train);

#define PD_DROPOUT_KERNEL_NAME \
  VectorizedDropoutForward<T, NoMaskFwFunctor<T, float>>
    PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!fix_seed,
                                       PD_DROPOUT_KERNEL_NAME,
                                       grid_size,
                                       block_size,
                                       0,
                                       stream,
                                       offset,
                                       KERNEL_PARAMS.As<uint64_t>(1),
                                       KERNEL_PARAMS.As<uint64_t>(5),
                                       numel,
                                       seed_data,  // need save
                                       x_data,
                                       y_data,
                                       out_data,
                                       increment,  // need save
                                       main_offset,
                                       dst_functor);
#undef PD_DROPOUT_KERNEL_NAME
  } else {
    using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
    MT factor = static_cast<MT>(1.0f - dropout_rate);
    std::vector<phi::DenseTensor*> outs = {out};
    std::vector<const phi::DenseTensor*> ins = {&x, &y};
    phi::funcs::ElementwiseKernel<T>(
        dev_ctx, ins, &outs, ScaleAddFuctor<T>(factor, upscale_in_train));
  }
}

}  // namespace phi
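// seed_offset (output 1) holds the {seed, increment} pair used to drive the
// Philox generator, stored as int64 regardless of the compute dtype, so that
// a backward pass can regenerate the same dropout pattern without a stored
// mask.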
PD_REGISTER_KERNEL(fused_dropout_add,
                   GPU,
                   ALL_LAYOUT,
                   phi::FusedDropoutAddKernel,
                   float,
                   double,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {
  kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}