Unverified commit 3474e09c, authored by Bo Zhang, committed by GitHub

Support different dtypes of inputs for broadcast for dropout optimization (#52093)

* change judgement for DropoutGradGPUKernelDriver

* add UnrollerWithoutVecSize; after this, LoadData is to be refined

* pass unittest

* use same unroller with XPU

* BroadcastWithInt64Index

* BroadcastDataLoader template partial specialization

* fix compile errors on ROCm

* PR comment
Parent fe053396
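The gist of the change, before the diff itself: instead of one `InT` template parameter shared by every input, each broadcast input is addressed by a compile-time `Index`, its element type is recovered from the functor's argument tuple, and a static unroller visits the inputs one by one. The standalone host-side sketch below illustrates that pattern; `UnrollerWithoutVecSize` and `Loader` here are simplified stand-ins that only borrow the names used in the diff, not the Paddle implementations.

#include <cstdio>
#include <tuple>
#include <utility>
#include <vector>

// Static unroller: calls Func<0>::Apply, Func<1>::Apply, ... Func<End-1>::Apply.
template <template <int> class Func, int End, int Begin = 0>
struct UnrollerWithoutVecSize {
  template <typename... Args>
  static void step(Args &&...args) {
    Func<Begin>::Apply(std::forward<Args>(args)...);
    UnrollerWithoutVecSize<Func, End, Begin + 1>::step(args...);
  }
};
template <template <int> class Func, int End>
struct UnrollerWithoutVecSize<Func, End, End> {
  template <typename... Args>
  static void step(Args &&...) {}
};

// Each input is stored as an untyped pointer; the per-Index loader recovers
// the real element type from the functor's argument tuple.
template <int Index>
struct Loader {
  template <typename ArgsT>
  static void Apply(const std::vector<const void *> &ins, ArgsT *arg, int i) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    std::get<Index>(*arg) = static_cast<const Type *>(ins[Index])[i];
  }
};

int main() {
  // Mixed dtypes: float data and an unsigned char mask, as in dropout_grad.
  std::vector<float> x = {1.f, 2.f, 3.f};
  std::vector<unsigned char> mask = {1, 0, 1};
  std::vector<const void *> ins = {x.data(), mask.data()};

  auto functor = [](float a, unsigned char m) { return a * m; };
  using ArgsT = std::tuple<float, unsigned char>;

  for (int i = 0; i < 3; ++i) {
    ArgsT arg;
    UnrollerWithoutVecSize<Loader, 2>::step(ins, &arg, i);
    std::printf("%f\n", std::apply(functor, arg));
  }
  return 0;
}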
@@ -31,20 +31,49 @@ namespace funcs {

enum BroadcastLoadType { kMixed = 1, kBroadcast = 2, kElementwise = 3 };

template <int Index>
struct UseBroadcast {
template <typename ArgsT, typename Array1, typename Array2>
static HOSTDEVICE void Apply(
const std::vector<const DenseTensor *> &ins_tensor,
const ArgsT &args,
int64_t numel,
Array1 *ins_data,
Array2 *use_broadcast,
int *broadcast_num,
bool *all_elementwise) {
(*ins_data)[Index] = (const _ptr_ char *)(ins_tensor[Index]->data());
bool is_same_dim = ins_tensor[Index]->numel() == numel;
if (is_same_dim) {
(*use_broadcast)[Index] = false;
} else {
(*use_broadcast)[Index] = true;
(*broadcast_num)++;
}
*all_elementwise &= is_same_dim;
}
};
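`UseBroadcast` above records, per input, whether broadcasting is needed and counts `broadcast_num`; further down, `LaunchBroadcastKernel` turns `broadcast_num` and `all_elementwise` into one of the `BroadcastLoadType` values that select the loader specializations. A small host-side sketch of that selection rule follows; it is an illustration of the branching shown later in this diff, not the Paddle code itself.

#include <cstdio>

enum BroadcastLoadType { kMixed = 1, kBroadcast = 2, kElementwise = 3 };

// Mirrors the branch structure of LaunchBroadcastKernel below:
// all-elementwise inputs, mostly-broadcast inputs, or a mix.
BroadcastLoadType SelectLoadType(int arity, int broadcast_num,
                                 bool all_elementwise) {
  if (all_elementwise) return kElementwise;
  if (broadcast_num > (arity >> 1)) return arity > 1 ? kBroadcast : kMixed;
  return kMixed;
}

int main() {
  // Two inputs, one of which (e.g. a [1, C] bias) needs broadcasting.
  std::printf("%d\n", SelectLoadType(/*arity=*/2, /*broadcast_num=*/1,
                                     /*all_elementwise=*/false));
  return 0;
}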
template <typename OutT, int Arity, typename Functor>
struct LoaderTypeClassifier {
 public:
  int64_t numel{0};
  int vec_size{4};
  int broadcast_num{0};
  bool all_elementwise{true};
  phi::Array<bool, Arity> use_broadcast;
  phi::Array<const _ptr_ char *__restrict__, Arity> ins_data;

  LoaderTypeClassifier() {}
  LoaderTypeClassifier(const std::vector<const DenseTensor *> &ins,
                       std::vector<DenseTensor *> *outs) {
    using Traits = phi::funcs::FunctionTraits<Functor>;
    using ArgsT = typename Traits::ArgsTuple;
    ArgsT arg;
    uint64_t out_addr = reinterpret_cast<uint64_t>((*outs)[0]->data<OutT>());

    UnrollerWithoutVecSize<VecSizeGetter, Arity>::step(ins, arg, &vec_size);

    for (auto i = 1; i < outs->size(); ++i) {
      PADDLE_ENFORCE_EQ(
          (*outs)[i]->dims(),
@@ -56,165 +85,185 @@ struct LoaderTypeClassifier {
      out_addr =
          (out_addr | reinterpret_cast<uint64_t>((*outs)[i]->data<OutT>()));
    }

    vec_size = std::min(
        vec_size,
        phi::GetVectorizedSize<OutT>(reinterpret_cast<OutT *>(out_addr)));
    numel = (*outs)[0]->numel();
    UnrollerWithoutVecSize<UseBroadcast, Arity>::step(ins,
                                                      arg,
                                                      numel,
                                                      &ins_data,
                                                      &use_broadcast,
                                                      &broadcast_num,
                                                      &all_elementwise);
  }
};

// Common broadcast/elementwise Loader.
template <int Index, int VecSize, bool IsBoundary, int LoadType>
struct BroadcastDataLoader {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    kps::Init<Type, ArgsT, Index, VecSize>(args, static_cast<Type>(1.0f));
    if (use_broadcast[Index]) {
      kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
          args,
          reinterpret_cast<const _ptr_ Type *>(ins[Index]),
          block_offset,
          configs[Index],
          numel,
          VecSize);
    }
    // NOTE: If use if...else... with condition `use_broadcast[Index]` here,
    // there will be some errs with clang12 while compiling in ROCm.
    // When the compiler is upgraded, if...else... may be used.
    if (!use_broadcast[Index]) {
      kps::ReadData<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
          args,
          reinterpret_cast<const _ptr_ Type *>(ins[Index]) + block_offset,
          num,
          VecSize);
    }
  }
};
/* BroadcastDataLoaders Partial specialization */
#ifndef PADDLE_WITH_XPU_KP
// Scalar elementwise Loader with consideration of IsBoundary.
template <int Index, int VecSize>
struct BroadcastDataLoader<Index, VecSize, true, kElementwise> {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    int thread_offset = threadIdx.x * VecSize + block_offset;
#pragma unroll
    for (int idx = 0; idx < VecSize; ++idx) {
      std::get<Index>(args[idx]) = static_cast<Type>(1);
      int index = thread_offset + idx;
      if (index < numel) {
        std::get<Index>(args[idx]) =
            reinterpret_cast<const _ptr_ Type *>(ins[Index])[index];
      }
    }
  }
};

// Vectorized elementwise Loader without consideration of IsBoundary.
template <int Index, int VecSize>
struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    using VecType = phi::kps::details::VectorType<Type, VecSize>;
    VecType vec_temp;

    int thread_offset = threadIdx.x + blockIdx.x * blockDim.x;
    const VecType *__restrict__ vec_input =
        reinterpret_cast<const VecType *__restrict__>(ins[Index]);
    vec_temp = vec_input[thread_offset];
#pragma unroll
    for (int idx = 0; idx < VecSize; ++idx) {
      std::get<Index>(args[idx]) = vec_temp.val[idx];
    }
  }
};

// Common broadcast data loader.
template <int Index, int VecSize, bool IsBoundary>
struct BroadcastDataLoader<Index, VecSize, IsBoundary, kBroadcast> {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    uint32_t index_bc[VecSize];
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      index_bc[k] = 0;
      std::get<Index>(args[k]) = static_cast<Type>(1);
    }

    uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      uint32_t idx = thread_offset + k;
      if (IsBoundary && idx == numel) {
        break;
      }
#pragma unroll
      for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
        if (i == configs[0].rank) break;
        auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
        idx = fast_divmoder.val[0];
        index_bc[k] += fast_divmoder.val[1] * configs[Index].strides[i];
      }
    }

#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      std::get<Index>(args[k]) =
          reinterpret_cast<const _ptr_ Type *>(ins[Index])[index_bc[k]];
    }
  }
};
#endif

// static broadcast unroller
template <template <int Index, int VecSize, bool IsBoundary, int LoadType>
          typename Func,
          bool IsBoundary,
          int LoadType,
          int VecSize,
          int End,
          int Begin = 0>
struct BcUnroller {
  template <typename... Args>
  static HOSTDEVICE inline void step(Args &&...args) {
    Func<Begin, VecSize, IsBoundary, LoadType>::Apply(
        std::forward<Args>(args)...);
    BcUnroller<Func, IsBoundary, LoadType, VecSize, End, Begin + 1>::step(
        args...);
  }
};

template <template <int Index, int VecSize, bool IsBoundary, int LoadType>
          typename Func,
          bool IsBoundary,
          int LoadType,
          int VecSize,
          int End>
struct BcUnroller<Func, IsBoundary, LoadType, VecSize, End, End> {
  template <typename... Args>
  static HOSTDEVICE inline void step(Args &&...args) {}
};

template <typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
@@ -222,59 +271,43 @@ template <typename InT,
          bool IsBoundary,
          int LoadType>
__device__ void VectorizedBroadcastKernelImpl(
    const phi::Array<const _ptr_ char *__restrict__, Arity> &ins,
    phi::Array<_ptr_ OutT *, NumOuts> outs,
    const phi::Array<bool, Arity> &use_broadcast,
    const uint32_t numel,
    const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
    int num,
    int block_offset,
    int read_lens,
    Functor func) {
  using Traits = phi::funcs::FunctionTraits<Functor>;
  using ArgsT = typename Traits::ArgsTuple;
  __simd__ ArgsT args[VecSize];
  __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
#ifdef PADDLE_WITH_XPU_KP
#pragma unroll
for (int i = 0; i < Arity; ++i) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f), read_lens);
if (use_broadcast[i]) {
kps::ReadDataBc<InT, VecSize, 1, IsBoundary>(
args[i], ins[i], block_offset, configs[i], numel, read_lens);
} else {
kps::ReadData<InT, VecSize, 1, IsBoundary>(
args[i], ins[i] + block_offset, num, read_lens);
}
}
#else
BroadcastDataLoader<InT, VecSize, Arity, IsBoundary, LoadType>()(
args, ins, configs, use_broadcast, block_offset, num, numel);
#endif
  BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
      ins, args, configs, use_broadcast, block_offset, num, numel);

  SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
                                     VecSize,
                                     Functor,
                                     ArgsT,
                                     Arity>()(func, args, result, read_lens);
  phi::funcs::
      ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, NumOuts>()(
          outs, result, block_offset, num, read_lens);
}

template <typename Functor,
          typename OutT,
          int Arity,
          int NumOuts,
          int VecSize,
          int LoadType>
__global__ void VectorizedBroadcastKernel(
    phi::Array<const _ptr_ char *__restrict__, Arity> ins,
    phi::Array<_ptr_ OutT *, NumOuts> outs,
    phi::Array<bool, Arity> use_broadcast,
    uint32_t numel,
    phi::Array<kps::details::BroadcastConfig, Arity> configs,
    int main_offset,
@@ -285,8 +318,7 @@ __global__ void VectorizedBroadcastKernel(
  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
  int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
  for (; block_offset < main_offset; block_offset += stride) {
    VectorizedBroadcastKernelImpl<OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
@@ -304,8 +336,7 @@ __global__ void VectorizedBroadcastKernel(
  }
  int num = numel - block_offset;
  if (num > 0) {
    VectorizedBroadcastKernelImpl<OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
@@ -324,8 +355,7 @@ __global__ void VectorizedBroadcastKernel(
#else
  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
  if (block_offset < main_offset) {
    VectorizedBroadcastKernelImpl<OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
@@ -341,8 +371,7 @@ __global__ void VectorizedBroadcastKernel(
                                  read_lens,
                                  func);
  } else {
    VectorizedBroadcastKernelImpl<OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
@@ -361,19 +390,14 @@ __global__ void VectorizedBroadcastKernel(
#endif
}

template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
void LaunchBroadcastKernel(
    const KPDevice &ctx,
    const std::vector<const DenseTensor *> &ins,
    std::vector<DenseTensor *> *outs,
    Functor func,
    const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
    const LoaderTypeClassifier<OutT, Arity, Functor> &loader_classifier) {
  phi::Array<_ptr_ OutT *, NumOuts> outs_data;
  for (int i = 0; i < NumOuts; ++i) {
    outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
@@ -388,7 +412,7 @@ void LaunchBroadcastKernel(
  int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
  int tail_tid = numel % (read_lens * threads);

  VectorizedBroadcastKernel<Functor, OutT, Arity, NumOuts, VecSize, false>
      <<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
                                       outs_data,
                                       loader_classifier.use_broadcast,
@@ -409,8 +433,7 @@ void LaunchBroadcastKernel(
  int tail_tid = numel % (VecSize * threads);

  if (loader_classifier.all_elementwise) {
    VectorizedBroadcastKernel<Functor,
                              OutT,
                              Arity,
                              NumOuts,
@@ -427,7 +450,7 @@ void LaunchBroadcastKernel(
                              func);
  } else if (loader_classifier.broadcast_num > (Arity >> 1)) {
    constexpr BroadcastLoadType type_ = (Arity > 1) ? kBroadcast : kMixed;
    VectorizedBroadcastKernel<Functor, OutT, Arity, NumOuts, VecSize, type_>
        <<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
                                         outs_data,
                                         loader_classifier.use_broadcast,
@@ -438,7 +461,7 @@ void LaunchBroadcastKernel(
                                         VecSize,
                                         func);
  } else {
    VectorizedBroadcastKernel<Functor, OutT, Arity, NumOuts, VecSize, kMixed>
        <<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
                                         outs_data,
                                         loader_classifier.use_broadcast,
@@ -471,94 +494,49 @@ HOSTDEVICE static int64_t ConvertSrcIdxToDstIdx(
  return dst_idx;
}

template <int N>
struct MaxWithOne {
  static constexpr auto kValue = (N >= 1 ? N : 1);
};

template <int Index, int VecSize>
struct ReadVecDataWithInt64Index {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(
      const Array1 &in,
      ArgsT *args,
      int64_t idx,
      const Array2 &need_broadcast,
      const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &src_strides,
      const Array3 &dst_strides,
      int rank,
      bool is_boundary) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    if (is_boundary) {
#pragma unroll
      for (int i = 0; i < VecSize; ++i) {
        std::get<Index>(args[i]) = in[Index][ConvertSrcIdxToDstIdx(
            idx + i, src_strides, dst_strides[Index], rank)];
}
} else {
if (!need_broadcast[Index]) {
kps::ReadData<Type, VecSize, 1, ArgsT, Index, false>(
args, reinterpret_cast<const _ptr_ Type *>(in[Index]) + idx, 1);
} else {
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
std::get<Index>(args[i]) = in[Index][ConvertSrcIdxToDstIdx(
idx + i, src_strides, dst_strides[Index], rank)];
}
      }
    }
  }
};
template <typename InT,
typename OutT,
typename Functor,
int VecSize,
int NumIns>
struct ApplyFunctorWithInt64IndexHelper {
HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
Functor functor,
int i);
};
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 0> {
HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
Functor functor,
int i) {
return static_cast<OutT>(functor());
}
};
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 1> {
HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
Functor functor,
int i) {
return static_cast<OutT>(functor(ins_vec[0][i]));
}
};
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 2> {
HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
Functor functor,
int i) {
return static_cast<OutT>(functor(ins_vec[0][i], ins_vec[1][i]));
}
}; };
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 3> {
HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
Functor functor,
int i) {
return static_cast<OutT>(
functor(ins_vec[0][i], ins_vec[1][i], ins_vec[2][i]));
}
};
template <int N>
struct MaxWithOne {
static constexpr auto kValue = (N >= 1 ? N : 1);
};
template <typename InT,
typename OutT,
typename Functor,
int VecSize,
int NumIns>
template <typename OutT, typename Functor, int VecSize, int NumIns>
__global__ void BroadcastKernelWithInt64Index(
    const phi::Array<const _ptr_ char *__restrict__,
                     MaxWithOne<NumIns>::kValue> &ins,
    OutT *out,
    phi::Array<phi::Array<int64_t, phi::DDim::kMaxRank + 1>,
               MaxWithOne<NumIns>::kValue> ins_strides,
@@ -572,70 +550,34 @@ __global__ void BroadcastKernelWithInt64Index(
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
  int64_t limit = numel - VecSize;

  using Traits = phi::funcs::FunctionTraits<Functor>;
  using ArgsT = typename Traits::ArgsTuple;
  ArgsT args[VecSize];
  phi::AlignedVector<OutT, VecSize> out_vec;
  for (; idx <= limit; idx += stride) {
    Unroller<ReadVecDataWithInt64Index, VecSize, NumIns>::step(
        ins, args, idx, need_broadcasts, out_strides, ins_strides, rank, false);

#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      out_vec[i] = static_cast<OutT>(Apply(functor, args[i]));
    }
    phi::Store<OutT, VecSize>(out_vec, out + idx);
  }

  if (idx < numel) {
    int remain = numel - idx;  // remain is always less than VecSize, therefore
                               // `int` is enough here
    Unroller<ReadVecDataWithInt64Index, VecSize, NumIns>::step(
        ins, args, idx, need_broadcasts, out_strides, ins_strides, rank, true);
    for (int i = 0; i < remain; ++i) {
      out[idx + i] = static_cast<OutT>(Apply(functor, args[i]));
    }
  }
}
template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper {
  static void Run(const KPDevice &ctx,
                  const std::vector<const DenseTensor *> &ins,
@@ -647,9 +589,8 @@ struct LaunchBroadcastKernelWithInt64IndexHelper {
  }
};

template <typename OutT, typename Functor, int Arity, int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper<OutT,
                                                 Functor,
                                                 Arity,
                                                 /*NumOuts=*/1,
@@ -659,10 +600,9 @@ struct LaunchBroadcastKernelWithInt64IndexHelper<InT,
                  std::vector<DenseTensor *> *outs,
                  int axis,
                  Functor functor) {
    phi::Array<const _ptr_ char *__restrict__, MaxWithOne<Arity>::kValue>
        ins_ptrs;
    UnrollerWithoutVecSize<InputSetter, Arity>::step(ins, &ins_ptrs);

    auto *out_tensor = (*outs)[0];
    auto *out_ptr = ctx.Alloc<OutT>(out_tensor);
@@ -734,7 +674,7 @@ struct LaunchBroadcastKernelWithInt64IndexHelper<InT,
    auto gpu_config =
        phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
    BroadcastKernelWithInt64Index<OutT, Functor, VecSize, Arity>
        <<<gpu_config.block_per_grid,
           gpu_config.thread_per_block,
           0,
@@ -844,9 +784,9 @@ struct LaunchBroadcastKernelWithInt64IndexHelper<InT,
#endif

template <ElementwiseType ET,
          typename OutT,
          typename Functor,
          int kArity,
          int NumOuts = 1>
void BroadcastKernelForDifferentVecSize(
    const KPDevice &ctx,
@@ -854,47 +794,17 @@ void BroadcastKernelForDifferentVecSize(
    std::vector<DenseTensor *> *outs,
    int axis,
    Functor func) {
using Traits = phi::funcs::FunctionTraits<Functor>;
const int kArity =
Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
PADDLE_ENFORCE_EQ(
ins.size(),
kArity,
phi::errors::InvalidArgument("The number of inputs is expected to be "
"equal to the "
"arity of functor. But received: the "
"number of inputs "
"is %d, the arity of functor is %d.",
ins.size(),
kArity));
PADDLE_ENFORCE_LE(
kArity,
3,
phi::errors::InvalidArgument("Currently only broadcast of ternary is "
"supported "
"and verified, but received %d.",
kArity));
PADDLE_ENFORCE_EQ(
outs->size(),
NumOuts,
phi::errors::InvalidArgument("Number of outputs shall equal to number "
"of functions, "
"but number of outputs is %d, of "
"functions is %d.",
outs->size(),
NumOuts));
#ifndef PADDLE_WITH_XPU_KP
  constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && kArity <= 3);
  bool use_int64_index_kernel =
      kEnabledInt64IndexKernel &&
      (*outs)[0]->numel() >= std::numeric_limits<int32_t>::max();
  if (use_int64_index_kernel) {
    auto loader_classifier =
        LoaderTypeClassifier<OutT, kArity, Functor>(ins, outs);
    switch (loader_classifier.vec_size) {
      case VecSizeL: {
        LaunchBroadcastKernelWithInt64IndexHelper<OutT,
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
@@ -906,8 +816,7 @@ void BroadcastKernelForDifferentVecSize(
        break;
      }
      case VecSizeM: {
        LaunchBroadcastKernelWithInt64IndexHelper<OutT,
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
@@ -919,8 +828,7 @@ void BroadcastKernelForDifferentVecSize(
        break;
      }
      case VecSizeS: {
        LaunchBroadcastKernelWithInt64IndexHelper<OutT,
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
@@ -949,7 +857,7 @@ void BroadcastKernelForDifferentVecSize(
      phi::errors::InvalidArgument(
          "XPU only support inputs is 2, but received %d", ins.size()));
  auto loader_classifier = LoaderTypeClassifier<OutT, kArity, Functor>();
  const auto dims_simplifier =
      BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
  if (VLOG_IS_ON(6)) {
@@ -968,7 +876,8 @@ void BroadcastKernelForDifferentVecSize(
  bool is_optimize = configs[0].cmp_type != type;
  int vec_size = is_optimize ? VecSizeL : VecSizeM;
#else
  auto loader_classifier =
      LoaderTypeClassifier<OutT, kArity, Functor>(ins, outs);
  if (!loader_classifier.all_elementwise) {
    const auto dims_simplifier =
        BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
@@ -991,17 +900,17 @@ void BroadcastKernelForDifferentVecSize(
#endif
  switch (loader_classifier.vec_size) {
    case VecSizeL: {
      LaunchBroadcastKernel<OutT, Functor, kArity, NumOuts, VecSizeL>(
          ctx, ins, outs, func, configs, loader_classifier);
      break;
    }
    case VecSizeM: {
      LaunchBroadcastKernel<OutT, Functor, kArity, NumOuts, VecSizeM>(
          ctx, ins, outs, func, configs, loader_classifier);
      break;
    }
    case VecSizeS: {
      LaunchBroadcastKernel<OutT, Functor, kArity, NumOuts, VecSizeS>(
          ctx, ins, outs, func, configs, loader_classifier);
      break;
    }
@@ -1025,6 +934,28 @@ void BroadcastKernel(const KPDevice &ctx,
                     Functor func) {
  // When there are multiple inputs, the outputs's rank should be equal the
  // maximum rank of all inputs.
using Traits = phi::funcs::FunctionTraits<Functor>;
const int kArity = Traits::arity;
PADDLE_ENFORCE_EQ(
ins.size(),
kArity,
phi::errors::InvalidArgument("The number of inputs is expected to be "
"equal to the "
"arity of functor. But received: the "
"number of inputs "
"is %d, the arity of functor is %d.",
ins.size(),
kArity));
PADDLE_ENFORCE_EQ(
outs->size(),
NumOuts,
phi::errors::InvalidArgument("Number of outputs shall equal to number "
"of functions, "
"but number of outputs is %d, of "
"functions is %d.",
outs->size(),
NumOuts));
  int max_rank = 0;
  int min_rank = phi::DDim::kMaxRank;
  for (auto *in : ins) {
@@ -1037,7 +968,7 @@ void BroadcastKernel(const KPDevice &ctx,
    max_rank = std::max(max_rank, (*outs)[0]->dims().size());
  }
  axis = axis == -1 ? max_rank - min_rank : axis;
  BroadcastKernelForDifferentVecSize<ET, OutT, Functor, kArity, NumOuts>(
      ctx, ins, outs, axis, func);
}
@@ -1051,6 +982,7 @@ void ElementwiseCompute(const GPUContext &dev_ctx,
  std::vector<const DenseTensor *> ins = {&x, &y};
  std::vector<DenseTensor *> outs = {z};
  dev_ctx.template Alloc<OutType>(z);

  BroadcastKernel<ElementwiseType::kBinary, T, OutType, Functor, 1>(
      dev_ctx, ins, &outs, axis, func);
}
......
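For reference, this is how the functor itself can carry both the arity and the per-input element types that `LoaderTypeClassifier` and `BroadcastKernel` rely on above. The `FunctionTraits` below is a simplified stand-in written for this note, not the `phi::funcs::FunctionTraits` used in the diff:

#include <cstddef>
#include <tuple>
#include <type_traits>

template <typename T>
struct FunctionTraits : FunctionTraits<decltype(&T::operator())> {};

template <typename C, typename R, typename... Args>
struct FunctionTraits<R (C::*)(Args...) const> {
  static constexpr std::size_t arity = sizeof...(Args);
  using ArgsTuple = std::tuple<Args...>;
};

struct DropoutGradLikeFunctor {
  float factor;
  // grad_y is float, mask is unsigned char: two different input dtypes.
  float operator()(float dout, unsigned char mask) const {
    return dout * static_cast<float>(mask) * factor;
  }
};

using Traits = FunctionTraits<DropoutGradLikeFunctor>;
static_assert(Traits::arity == 2, "kArity is taken from the functor");
static_assert(
    std::is_same_v<std::tuple_element_t<1, Traits::ArgsTuple>, unsigned char>,
    "each input's element type comes from the functor's argument tuple");

int main() { return 0; }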
@@ -459,41 +459,43 @@ void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
    // y = factor * x
    ScaleByDropoutFactor<T, MT>(dev_ctx, grad_y, grad_x, factor);
  } else {
    if (upscale_in_train && dropout_prob == 1.0f) {
#ifdef PADDLE_WITH_HIP
      hipMemset(grad_x->data<T>(), 0, grad_x->numel() * sizeof(T));
#else
      cudaMemset(grad_x->data<T>(), 0, grad_x->numel() * sizeof(T));
#endif
    } else {
      MT factor = upscale_in_train
                      ? static_cast<MT>(1.0f / (1.0f - dropout_prob))
                      : static_cast<MT>(1.0f);
      if (is_dropout_nd) {
        phi::DenseTensor broadcasted_mask;
        broadcasted_mask.Resize(grad_y.dims());
        dev_ctx.template Alloc<uint8_t>(&broadcasted_mask);

        std::vector<const phi::DenseTensor*> broadcast_ins = {&mask};
        std::vector<phi::DenseTensor*> broadcast_outs = {&broadcasted_mask};
        phi::funcs::BroadcastKernel<phi::ElementwiseType::kUnary,
                                    uint8_t,
                                    uint8_t>(dev_ctx,
                                             broadcast_ins,
                                             &broadcast_outs,
                                             -1,
                                             kps::IdentityFunctor<uint8_t>());

        std::vector<const phi::DenseTensor*> ins = {&grad_y, &broadcasted_mask};
        std::vector<phi::DenseTensor*> outs = {grad_x};
        phi::funcs::ElementwiseKernel<T>(
            dev_ctx, ins, &outs, CudaDropoutGradFunctor<T>(factor));
      } else {
        std::vector<const phi::DenseTensor*> ins = {&grad_y, &mask};
        std::vector<phi::DenseTensor*> outs = {grad_x};
        phi::funcs::ElementwiseKernel<T>(
            dev_ctx, ins, &outs, CudaDropoutGradFunctor<T>(factor));
      }
    }
  }
}
......
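A plain host-side reference of what the rewritten dropout-grad branch computes, mirroring `CudaDropoutGradFunctor`'s `dout * mask * factor` and the early-out for `upscale_in_train && dropout_prob == 1.0f`. This is an illustration under those assumptions, not the Paddle kernel:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> DropoutGradReference(const std::vector<float> &grad_y,
                                        const std::vector<uint8_t> &mask,
                                        float dropout_prob,
                                        bool upscale_in_train) {
  std::vector<float> grad_x(grad_y.size(), 0.0f);
  if (upscale_in_train && dropout_prob == 1.0f) {
    return grad_x;  // mirrors the cudaMemset(grad_x, 0, ...) fast path
  }
  const float factor =
      upscale_in_train ? 1.0f / (1.0f - dropout_prob) : 1.0f;
  for (std::size_t i = 0; i < grad_y.size(); ++i) {
    grad_x[i] = grad_y[i] * static_cast<float>(mask[i]) * factor;
  }
  return grad_x;
}

int main() {
  auto gx = DropoutGradReference({2.0f, 4.0f}, {1, 0}, 0.5f, true);
  assert(gx[0] == 4.0f && gx[1] == 0.0f);
  return 0;
}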
@@ -35,7 +35,7 @@ namespace kps = phi::kps;
namespace phi {

enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 };

/* Packing scalar type T(float, int etc.) into Array<T, NumOuts> type
   for supporting multiple-output feature in elementwise system.*/
template <class T, int Num>
@@ -508,15 +508,31 @@ struct Unroller<Func, VecSize, End, End> {
  static HOSTDEVICE inline void step(Args &&...args) {}
};
// static unroller without VecSize for broadcast
template <template <int Index> typename Func, int End, int Begin = 0>
struct UnrollerWithoutVecSize {
template <typename... Args>
static HOSTDEVICE inline void step(Args &&...args) {
Func<Begin>::Apply(std::forward<Args>(args)...);
UnrollerWithoutVecSize<Func, End, Begin + 1>::step(args...);
}
};
template <template <int Index> typename Func, int End>
struct UnrollerWithoutVecSize<Func, End, End> {
template <typename... Args>
static HOSTDEVICE inline void step(Args &&...args) {}
};
template <int Index, int VecSize>
struct Loader {
  template <typename Array, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array &in,
                                               ArgsT *args,
                                               kps::IndexType offset,
                                               int num,
                                               int read_lens,
                                               bool is_boundary) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    kps::Init<Type, ArgsT, Index, VecSize>(
        args, static_cast<Type>(1.0f), read_lens);
@@ -536,7 +552,7 @@ }
  }
};

template <int Index>
struct InputSetter {
  template <typename Array>
  static HOSTDEVICE void Apply(
@@ -545,7 +561,7 @@ struct InputSetter {
  }
};

template <int Index>
struct VecSizeGetter {
  template <typename ArgsT>
  static HOSTDEVICE void Apply(const std::vector<const DenseTensor *> &ins,
@@ -569,8 +585,7 @@ int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
  int vec_size = 4;
  uint64_t addr = static_cast<uint64_t>(0);
  ArgsT arg;
  UnrollerWithoutVecSize<VecSizeGetter, Arity>::step(ins, arg, &vec_size);
  for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
    addr = (addr | reinterpret_cast<uint64_t>((*iter)->data<OutT>()));
  }
@@ -580,73 +595,6 @@ int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
  return vec_size;
}
template <typename InT,
typename OutT,
int VecSize,
typename Functor,
int Arity,
bool CallElementwiseAny = false>
struct ElementwisePrimitiveCaller {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens);
};
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseAny<InT, OutT, VecSize, 1, Arity, Functor>(
result, args, func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseConstant<InT, OutT, VecSize, 1, Functor>(result, func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseUnary<InT, OutT, VecSize, 1, Functor>(
result, args[0], func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseBinary<InT, OutT, VecSize, 1, Functor>(
result, args[0], args[1], func, read_lens);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result,
int read_lens) {
kps::ElementwiseTernary<InT, OutT, VecSize, 1, Functor>(
result, args[0], args[1], args[2], func);
}
};
namespace detail {
template <class F, class Tuple, std::size_t... Index>
// GCC/Clang need the decltype() return type
@@ -802,7 +750,7 @@ void LaunchElementwiseCudaKernel(const KPDevice &ctx,
  phi::Array<const _ptr_ char *__restrict__, Arity> ins_data;
  phi::Array<_ptr_ OutT *, NumOuts> outs_data;

  UnrollerWithoutVecSize<InputSetter, Arity>::step(ins, &ins_data);
  for (int i = 0; i < NumOuts; ++i) {
    outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
  }
......
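The `Apply(functor, args[i])` calls in the broadcast kernels above expand a `std::tuple` of per-input values into a functor invocation, which is what the `detail` helpers in this file provide. A self-contained sketch of that mechanism (simplified; not the Paddle `detail::Apply`):

#include <cstdio>
#include <tuple>
#include <utility>

template <class F, class Tuple, std::size_t... Index>
auto ApplyImpl(F &&f, Tuple &&t, std::index_sequence<Index...>) {
  return std::forward<F>(f)(std::get<Index>(std::forward<Tuple>(t))...);
}

template <class F, class Tuple>
auto Apply(F &&f, Tuple &&t) {
  return ApplyImpl(
      std::forward<F>(f),
      std::forward<Tuple>(t),
      std::make_index_sequence<
          std::tuple_size_v<std::remove_reference_t<Tuple>>>{});
}

int main() {
  // One "register slot": a float gradient and an unsigned char mask element.
  std::tuple<float, unsigned char> arg{3.0f, 1};
  auto functor = [](float dout, unsigned char m) { return dout * m; };
  std::printf("%f\n", Apply(functor, arg));
  return 0;
}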
@@ -255,6 +255,18 @@ __device__ __forceinline__ void Init(ArgsT* dst, T init_data, int read_lens) {
  }
}
/**
* The difference from the above function is that
* it supports different data types of inputs.
*/
template <typename T, typename ArgsT, int Index, int NX>
__device__ __forceinline__ void Init(ArgsT* dst, T init_data) {
#pragma unroll
for (int i = 0; i < NX; i++) {
std::get<Index>(dst[i]) = init_data;
}
}
/**
 * @brief Read 1D data from global memory to register. When IsBoundary = true
 * and (NX % 4 == 0 or Nx % 2 == 0), vectorized load data will be used to
@@ -307,6 +319,23 @@ __device__ __forceinline__ void ReadData(T* dst,
  }
}
/**
* @brief Read 1D data from global memory to register.
* @template paraments
* T: The type of data.
* NX: Each thread load NX data from global memory continuously.
* NY: Each thread need to load NY rows, only NY = 1 was supported.
* IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
* When the number of data processed by this block is less than
* NX x NY x blockDim.x, boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* dst: The register pointer of the thread, the size is NX * NY.
* src: The data pointer of the current block.
* size: The current block needs to load size data continuously.
*/
template <typename T, int NX, int NY, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst,
                                         const T* __restrict__ src,
@@ -347,9 +376,8 @@ __device__ __forceinline__ void ReadData(T* dst,
 * T: The type of data.
 * NX: Each thread load NX data from global memory continuously.
 * NY: Each thread need to load NY rows, only NY = 1 was supported.
 * ArgsT: The Type of dst, ArgsT can be std::tuple<T> or std::tuple<Args>
 * Index: The index of data stored in dst.
 * IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
 * When the number of data processed by this block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
@@ -369,7 +397,7 @@ template <typename T,
__device__ __forceinline__ void ReadData(ArgsT* dst,
                                         const T* __restrict__ src,
                                         int num,
                                         int read_lens = 0) {
  if (IsBoundary) {  // blockDim.x * NX > num
    int thread_offset = threadIdx.x * NX;
#pragma unroll
@@ -743,7 +771,6 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
 * NX: The number of data continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread, only NY = 1 was supported.
 * threadIdx.x is used as the thread index. Currently only GPU was supported.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
@@ -788,6 +815,67 @@ __device__ __forceinline__ void ReadDataBc(
  }
}
/**
* @brief Read 1D data from global memory to register with broadcast form.
* The difference from the above function is that it supports different data
* types of inputs.
*
* @template paraments
* T: The type of data stored in the global memory.
* NX: The number of data continuously loaded by each thread.
* NY: The number of data rows loaded by each thread, only NY = 1 was supported.
* ArgsT: The Type of dst, ArgsT can be std::tuple<T> or std::tuple<Args>
* Index: The index of data stored in dst.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x blockDim.x, boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* dst: The register pointer of the thread, the size is NX * NY.
* src: The original input data pointer of kernel.
* block_offset: The data offset of this block, blockDim.x * blockIdx.x * NX;
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data.
* total_num_output: Total number of original output.
*/
template <typename T,
int NX,
int NY,
typename ArgsT,
int Index,
bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
ArgsT* dst,
const T* __restrict__ src,
uint32_t block_offset,
const details::BroadcastConfig& config,
int total_num_output,
int read_lens = NX) {
uint32_t thread_offset = block_offset + threadIdx.x * NX;
uint32_t index_src = 0;
#pragma unroll
for (uint32_t nx = 0; nx < NX; ++nx) {
uint32_t index_output = thread_offset + nx;
index_src = 0;
if (IsBoundary) {
if (index_output >= total_num_output) {
break;
}
}
#pragma unroll
for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
if (i >= config.rank) break;
auto fast_divmoder = config.divmoders[i].Divmod(index_output);
index_output = fast_divmoder.val[0];
index_src += fast_divmoder.val[1] * config.strides[i];
}
std::get<Index>(dst[nx]) = src[index_src];
}
}
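The inner loop of the tuple-based `ReadDataBc` above maps an output index back to a source index by peeling off one output dimension at a time and accumulating the input strides, which are zero along broadcast dimensions. Below is a host-side sketch of that mapping with plain division standing in for `FastDivMod`; the `BroadcastConfigSketch` type is invented for this example:

#include <cassert>
#include <cstdint>
#include <vector>

struct BroadcastConfigSketch {
  int rank;
  std::vector<uint32_t> divmoders;  // output dim sizes, innermost first
  std::vector<uint32_t> strides;    // input strides, 0 where broadcast
};

uint32_t SrcIndex(uint32_t index_output, const BroadcastConfigSketch &config) {
  uint32_t index_src = 0;
  for (int i = 0; i < config.rank; ++i) {
    uint32_t quotient = index_output / config.divmoders[i];
    uint32_t remainder = index_output % config.divmoders[i];
    index_output = quotient;
    index_src += remainder * config.strides[i];
  }
  return index_src;
}

int main() {
  // in[1, 35] broadcast against out[32, 35]: the row dimension has stride 0.
  BroadcastConfigSketch config{2, {35, 32}, {1, 0}};
  assert(SrcIndex(0, config) == 0);
  assert(SrcIndex(35, config) == 0);  // second row maps back to row 0
  assert(SrcIndex(36, config) == 1);
  return 0;
}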
/**
 * @brief Initialize register with data index.
 *
......