未验证 提交 3474e09c 编写于 作者: B Bo Zhang 提交者: GitHub

Support different dtypes of inputs for broadcast for dropout optimization (#52093)

* change judgement for DropoutGradGPUKernelDriver

* add UnrollerWithoutVecSize and after this Loaddata to be refined

* pass unittest

* use same unroller with XPU

* BroadcastWithInt64Index

* BroadcastDataLoader template partial specialization

* fix compile errs in ROCms

* PR comment
上级 fe053396
......@@ -459,41 +459,43 @@ void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
// y = factor * x
ScaleByDropoutFactor<T, MT>(dev_ctx, grad_y, grad_x, factor);
} else {
phi::DenseTensor broadcasted_mask;
if (is_dropout_nd) {
broadcasted_mask.Resize(grad_y.dims());
dev_ctx.template Alloc<uint8_t>(&broadcasted_mask);
std::vector<const phi::DenseTensor*> broadcast_ins = {&mask};
std::vector<phi::DenseTensor*> broadcast_outs = {&broadcasted_mask};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kUnary,
uint8_t,
uint8_t>(dev_ctx,
broadcast_ins,
&broadcast_outs,
-1,
kps::IdentityFunctor<uint8_t>());
}
std::vector<const phi::DenseTensor*> ins = {
&grad_y, is_dropout_nd ? &broadcasted_mask : &mask};
std::vector<phi::DenseTensor*> outs = {grad_x};
if (upscale_in_train) {
if (dropout_prob == 1.0f) {
if (upscale_in_train && dropout_prob == 1.0f) {
#ifdef PADDLE_WITH_HIP
hipMemset(grad_x->data<T>(), 0, grad_x->numel() * sizeof(T));
hipMemset(grad_x->data<T>(), 0, grad_x->numel() * sizeof(T));
#else
cudaMemset(grad_x->data<T>(), 0, grad_x->numel() * sizeof(T));
cudaMemset(grad_x->data<T>(), 0, grad_x->numel() * sizeof(T));
#endif
} else {
MT factor = upscale_in_train
? static_cast<MT>(1.0f / (1.0f - dropout_prob))
: static_cast<MT>(1.0f);
if (is_dropout_nd) {
phi::DenseTensor broadcasted_mask;
broadcasted_mask.Resize(grad_y.dims());
dev_ctx.template Alloc<uint8_t>(&broadcasted_mask);
std::vector<const phi::DenseTensor*> broadcast_ins = {&mask};
std::vector<phi::DenseTensor*> broadcast_outs = {&broadcasted_mask};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kUnary,
uint8_t,
uint8_t>(dev_ctx,
broadcast_ins,
&broadcast_outs,
-1,
kps::IdentityFunctor<uint8_t>());
std::vector<const phi::DenseTensor*> ins = {&grad_y, &broadcasted_mask};
std::vector<phi::DenseTensor*> outs = {grad_x};
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, CudaDropoutGradFunctor<T>(factor));
} else {
MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
std::vector<const phi::DenseTensor*> ins = {&grad_y, &mask};
std::vector<phi::DenseTensor*> outs = {grad_x};
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, CudaDropoutGradFunctor<T>(factor));
}
} else {
MT factor = static_cast<MT>(1.0f);
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, CudaDropoutGradFunctor<T>(factor));
}
}
}
......
......@@ -35,7 +35,7 @@ namespace kps = phi::kps;
namespace phi {
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 };
/* Packing scalar type T(float, int etc.) into Array<T, NumOuts> type
for supporting multiple-output feature in elementwise system.*/
template <class T, int Num>
......@@ -508,15 +508,31 @@ struct Unroller<Func, VecSize, End, End> {
static HOSTDEVICE inline void step(Args &&...args) {}
};
// static unroller without VecSize for broadcast
template <template <int Index> typename Func, int End, int Begin = 0>
struct UnrollerWithoutVecSize {
template <typename... Args>
static HOSTDEVICE inline void step(Args &&...args) {
Func<Begin>::Apply(std::forward<Args>(args)...);
UnrollerWithoutVecSize<Func, End, Begin + 1>::step(args...);
}
};
template <template <int Index> typename Func, int End>
struct UnrollerWithoutVecSize<Func, End, End> {
template <typename... Args>
static HOSTDEVICE inline void step(Args &&...args) {}
};
template <int Index, int VecSize>
struct Loader {
template <typename Array, typename ArgsT>
static __device__ void Apply(const Array &in,
ArgsT *args,
kps::IndexType offset,
int num,
int read_lens,
bool is_boundary) {
static __device__ __forceinline__ void Apply(const Array &in,
ArgsT *args,
kps::IndexType offset,
int num,
int read_lens,
bool is_boundary) {
using Type = std::tuple_element_t<Index, ArgsT>;
kps::Init<Type, ArgsT, Index, VecSize>(
args, static_cast<Type>(1.0f), read_lens);
......@@ -536,7 +552,7 @@ struct Loader {
}
};
template <int Index, int VecSize>
template <int Index>
struct InputSetter {
template <typename Array>
static HOSTDEVICE void Apply(
......@@ -545,7 +561,7 @@ struct InputSetter {
}
};
template <int Index, int VecSize>
template <int Index>
struct VecSizeGetter {
template <typename ArgsT>
static HOSTDEVICE void Apply(const std::vector<const DenseTensor *> &ins,
......@@ -569,8 +585,7 @@ int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
int vec_size = 4;
uint64_t addr = static_cast<uint64_t>(0);
ArgsT arg;
// The Arg VecSize=1 is to match the Unroller template.
Unroller<VecSizeGetter, 1, Arity>::step(ins, arg, &vec_size);
UnrollerWithoutVecSize<VecSizeGetter, Arity>::step(ins, arg, &vec_size);
for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
addr = (addr | reinterpret_cast<uint64_t>((*iter)->data<OutT>()));
}
......@@ -580,73 +595,6 @@ int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
return vec_size;
}
template <typename InT,
typename OutT,
int VecSize,
typename Functor,
int Arity,
bool CallElementwiseAny = false>
struct ElementwisePrimitiveCaller {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens);
};
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseAny<InT, OutT, VecSize, 1, Arity, Functor>(
result, args, func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseConstant<InT, OutT, VecSize, 1, Functor>(result, func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseUnary<InT, OutT, VecSize, 1, Functor>(
result, args[0], func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseBinary<InT, OutT, VecSize, 1, Functor>(
result, args[0], args[1], func, read_lens);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseTernary<InT, OutT, VecSize, 1, Functor>(
result, args[0], args[1], args[2], func);
}
};
namespace detail {
template <class F, class Tuple, std::size_t... Index>
// GCC/Clang need the decltype() return type
......@@ -802,7 +750,7 @@ void LaunchElementwiseCudaKernel(const KPDevice &ctx,
phi::Array<const _ptr_ char *__restrict__, Arity> ins_data;
phi::Array<_ptr_ OutT *, NumOuts> outs_data;
Unroller<InputSetter, VecSize, Arity>::step(ins, &ins_data);
UnrollerWithoutVecSize<InputSetter, Arity>::step(ins, &ins_data);
for (int i = 0; i < NumOuts; ++i) {
outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
}
......
......@@ -255,6 +255,18 @@ __device__ __forceinline__ void Init(ArgsT* dst, T init_data, int read_lens) {
}
}
/**
* The difference from the above function is that
* it supports different data types of inputs.
*/
template <typename T, typename ArgsT, int Index, int NX>
__device__ __forceinline__ void Init(ArgsT* dst, T init_data) {
#pragma unroll
for (int i = 0; i < NX; i++) {
std::get<Index>(dst[i]) = init_data;
}
}
/**
* @brief Read 1D data from global memory to register. When IsBoundary = true
* and (NX % 4 == 0 or Nx % 2 == 0), vectorized load data will be used to
......@@ -307,6 +319,23 @@ __device__ __forceinline__ void ReadData(T* dst,
}
}
/**
* @brief Read 1D data from global memory to register.
* @template paraments
* T: The type of data.
* NX: Each thread load NX data from global memory continuously.
* NY: Each thread need to load NY rows, only NY = 1 was supported.
* IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
* When the number of data processed by this block is less than
* NX x NY x blockDim.x, boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* dst: The register pointer of the thread, the size is NX * NY.
* src: The data pointer of the current block.
* size: The current block needs to load size data continuously.
*/
template <typename T, int NX, int NY, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst,
const T* __restrict__ src,
......@@ -347,9 +376,8 @@ __device__ __forceinline__ void ReadData(T* dst,
* T: The type of data.
* NX: Each thread load NX data from global memory continuously.
* NY: Each thread need to load NY rows, only NY = 1 was supported.
* ArgsT: The Type if dst, ArgsT can be std::tuple<T> or std::tuple<Args>
* ArgsT: The Type of dst, ArgsT can be std::tuple<T> or std::tuple<Args>
* Index: The index of data stored in dst.
* threadIdx.x is used as the thread index. Currently only GPU was supported.
* IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
* When the number of data processed by this block is less than
* NX x NY x blockDim.x, boundary judgment is required to avoid memory access
......@@ -369,7 +397,7 @@ template <typename T,
__device__ __forceinline__ void ReadData(ArgsT* dst,
const T* __restrict__ src,
int num,
int read_lens) {
int read_lens = 0) {
if (IsBoundary) { // blockDim.x * NX > num
int thread_offset = threadIdx.x * NX;
#pragma unroll
......@@ -743,7 +771,6 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
* NX: The number of data continuously loaded by each thread.
* NY: The number of data rows loaded by each thread, only NY = 1 was supported.
* threadIdx.x is used as the thread index. Currently only GPU was supported.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x blockDim.x, boundary judgment is required to avoid memory access
......@@ -788,6 +815,67 @@ __device__ __forceinline__ void ReadDataBc(
}
}
/**
* @brief Read 1D data from global memory to register with broadcast form.
* The difference from the above function is that it supports different data
* types of inputs.
*
* @template paraments
* T: The type of data stored in the global memory.
* NX: The number of data continuously loaded by each thread.
* NY: The number of data rows loaded by each thread, only NY = 1 was supported.
* ArgsT: The Type of dst, ArgsT can be std::tuple<T> or std::tuple<Args>
* Index: The index of data stored in dst.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x blockDim.x, boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* dst: The register pointer of the thread, the size is NX * NY.
* src: The original input data pointer of kernel.
* block_offset: The data offset of this block, blockDim.x * blockIdx.x * NX;
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data.
* total_num_output: Total number of original output.
*/
template <typename T,
int NX,
int NY,
typename ArgsT,
int Index,
bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
ArgsT* dst,
const T* __restrict__ src,
uint32_t block_offset,
const details::BroadcastConfig& config,
int total_num_output,
int read_lens = NX) {
uint32_t thread_offset = block_offset + threadIdx.x * NX;
uint32_t index_src = 0;
#pragma unroll
for (uint32_t nx = 0; nx < NX; ++nx) {
uint32_t index_output = thread_offset + nx;
index_src = 0;
if (IsBoundary) {
if (index_output >= total_num_output) {
break;
}
}
#pragma unroll
for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
if (i >= config.rank) break;
auto fast_divmoder = config.divmoders[i].Divmod(index_output);
index_output = fast_divmoder.val[0];
index_src += fast_divmoder.val[1] * config.strides[i];
}
std::get<Index>(dst[nx]) = src[index_src];
}
}
/**
* @brief Initialize register with data index.
*
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册