未验证 提交 835a1888 编写于 作者: N niuliling123 提交者: GitHub

cherry-pick 42645 (#43205)

删除 Broadcast 函数中针对 rank 的模板实例化以及 Elementwise 调用,以降低编译时间。
本 PR 由 develop 分支中的 #42645 修改而来;由于 develop 分支与 release 分支差异较大,无法直接 cherry-pick,因此针对 release/2.3 重新提交 PR。
Broadcast 中针对 rank 的模板实例化会导致底层模板展开过多,使 reduce_sum_grad_kernel.cu.o 文件体积过大;修改后可以减小 .o 文件体积并缩短编译时间。
上级 40a7e0ad
......@@ -51,8 +51,7 @@ template <typename InT, typename OutT, int ShapeSize, int VecSize,
__global__ void BroadcastKernelBinary(
const InT* __restrict__ in0, const InT* __restrict__ in1, OutT* out,
phi::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
phi::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
configlists,
phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists,
int main_tid, int tail_tid, Functor func) {
int fix = blockIdx.x * blockDim.x * VecSize;
int num = tail_tid;
......@@ -65,14 +64,14 @@ __global__ void BroadcastKernelBinary(
// load in0
if (use_broadcast[0]) {
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1>(
arg0, in0, fix, configlists[0], numel);
} else {
kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg0, in0 + fix, num);
}
// load in1
if (use_broadcast[1]) {
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1>(
arg1, in1, fix, configlists[1], numel);
} else {
kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg1, in1 + fix, num);
......@@ -104,7 +103,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
int main_tid = numel / (data_per_thread * vec_size * threads);
int tail_tid = numel % (data_per_thread * vec_size * threads);
phi::Array<kps::details::BroadcastConfig<2>, MAX_INPUT_NUM> configlists;
phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists;
phi::Array<bool, MAX_INPUT_NUM> use_broadcast;
use_broadcast[0] = false;
......@@ -115,7 +114,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
// Here, dims are transposed due to the logic in BroadcastConfig.
std::vector<int64_t> input1_dims = {n, 1};
std::vector<int64_t> out_dims = {n, m};
configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2);
configlists[1] = kps::details::BroadcastConfig(out_dims, input1_dims, 2);
auto func = AddFunctor<T>();
auto stream = ctx.stream();
......
......@@ -185,19 +185,19 @@ struct DimensionsTransform {
}
};
template <typename T, int VecSize, int Rank, bool IsBoundary = false>
template <typename T, int VecSize, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
T *dst,
const _ptr_ T *src,
uint32_t block_offset,
const kps::details::BroadcastConfig<Rank> &config,
const kps::details::BroadcastConfig &config,
int numel,
int num,
int need_broadcast) {
// numel: total number of output elements
// num: number of elements to be processed in this call
if (need_broadcast) {
kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
kps::ReadDataBc<T, VecSize, 1, 1, IsBoundary>(
dst, src, block_offset, config, numel);
} else {
kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
......@@ -210,14 +210,13 @@ template <typename InT,
int Arity,
int NumOuts,
int VecSize,
int Rank,
bool IsBoundary = false>
__device__ void VectorizedBroadcastKernelImpl(
const phi::Array<const _ptr_ InT *__restrict__, Arity> &ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
const phi::Array<int, Arity> &use_broadcast,
uint32_t numel,
const phi::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
int num,
int block_offset,
Functor func) {
......@@ -227,13 +226,13 @@ __device__ void VectorizedBroadcastKernelImpl(
#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
LoadData<InT, VecSize, Rank, IsBoundary>(args[i],
ins[i],
block_offset,
configs[i],
numel,
num,
use_broadcast[i]);
LoadData<InT, VecSize, IsBoundary>(args[i],
ins[i],
block_offset,
configs[i],
numel,
num,
use_broadcast[i]);
}
constexpr bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args;
......@@ -254,14 +253,13 @@ template <typename InT,
typename Functor,
int Arity,
int NumOuts,
int VecSize,
int Rank>
int VecSize>
__global__ void VectorizedBroadcastKernel(
phi::Array<const _ptr_ InT *__restrict__, Arity> ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
phi::Array<int, Arity> use_broadcast,
uint32_t numel,
phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
phi::Array<kps::details::BroadcastConfig, Arity> configs,
int main_offset,
int tail_tid,
Functor func) {
......@@ -276,7 +274,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
Rank,
false>(ins,
outs,
use_broadcast,
......@@ -294,7 +291,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
Rank,
true>(
ins, outs, use_broadcast, numel, configs, num, block_offset, func);
}
......@@ -306,7 +302,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
Rank,
false>(ins,
outs,
use_broadcast,
......@@ -322,7 +317,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
Rank,
true>(
ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
}
......@@ -334,15 +328,14 @@ template <typename InT,
typename Functor,
int Arity,
int NumOuts,
int VecSize,
int Rank>
int VecSize>
void LaunchBroadcastKernel(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
Functor func,
DimensionsTransform merge_dims) {
int numel = (*outs)[0]->numel();
phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
phi::Array<kps::details::BroadcastConfig, Arity> configs;
phi::Array<int, Arity> use_broadcast;
phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
phi::Array<_ptr_ OutT *, NumOuts> outs_data;
......@@ -358,7 +351,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
// Get the broadcast config.
// If the data shape is [m, n], then you should set data_dim = {n, m}.
// E.g. if out's shape is [3, 45, 1], then out_dims = {1, 45, 3}.
configs[i] = kps::details::BroadcastConfig<Rank>(
configs[i] = kps::details::BroadcastConfig(
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
}
}
......@@ -374,15 +367,14 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
Functor,
Arity,
NumOuts,
VecSize,
Rank><<<blocks, threads, stream>>>(ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
func);
VecSize><<<blocks, threads, stream>>>(ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
func);
#else
const int threads = 256;
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
......@@ -394,58 +386,18 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
Functor,
Arity,
NumOuts,
VecSize,
Rank><<<blocks, threads, 0, stream>>>(ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
func);
VecSize><<<blocks, threads, 0, stream>>>(
ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
func);
#endif
}
/**
 * @brief Select the LaunchBroadcastKernel instantiation that matches the
 * merged rank of the operands.
 *
 * The inputs and the first output's dims are merged by DimensionsTransform;
 * the resulting merge_dims.dim_size is then mapped onto a compile-time rank
 * (1 through 8) and forwarded to LaunchBroadcastKernel.
 *
 * @param ctx   Device context forwarded to the launcher.
 * @param ins   Input tensors.
 * @param outs  Output tensors; (*outs)[0]->dims() is used for dim merging.
 * @param axis  Broadcast axis forwarded to DimensionsTransform.
 * @param func  Functor forwarded to the launcher.
 *
 * Throws phi::errors::InvalidArgument (via PADDLE_THROW) when
 * merge_dims.dim_size is not one of the handled ranks.
 */
template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize>
void BroadcastKernelForDifferentDimSize(
    const KPDevice &ctx,
    const std::vector<const DenseTensor *> &ins,
    std::vector<DenseTensor *> *outs,
    int axis,
    Functor func) {
  const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);

// The rank must be a compile-time constant for the launcher's template
// argument, hence one switch case per supported rank.
#define CALL_BROADCAST_FOR_DIM_SIZE(rank)                                     \
  case rank: {                                                                \
    LaunchBroadcastKernel<InT, OutT, Functor, Arity, NumOuts, VecSize, rank>( \
        ctx, ins, outs, func, merge_dims);                                    \
  } break;

  switch (merge_dims.dim_size) {
    CALL_BROADCAST_FOR_DIM_SIZE(1);
    CALL_BROADCAST_FOR_DIM_SIZE(2);
    CALL_BROADCAST_FOR_DIM_SIZE(3);
    CALL_BROADCAST_FOR_DIM_SIZE(4);
    CALL_BROADCAST_FOR_DIM_SIZE(5);
    CALL_BROADCAST_FOR_DIM_SIZE(6);
    CALL_BROADCAST_FOR_DIM_SIZE(7);
    CALL_BROADCAST_FOR_DIM_SIZE(8);
    default: {
      // Argument order follows the format string: first the limit, then the
      // rank that was actually received.
      PADDLE_THROW(phi::errors::InvalidArgument(
          "The maximum dimension of input tensor is expected to be less than "
          "%d, but received %d.",
          phi::DDim::kMaxRank,
          merge_dims.dim_size));
    }
  }
#undef CALL_BROADCAST_FOR_DIM_SIZE
}
template <ElementwiseType ET,
typename InT,
typename OutT,
......@@ -506,33 +458,22 @@ void BroadcastKernelForDifferentVecSize(
: in_vec_size;
}
int vec_size = std::min(out_vec_size, in_vec_size);
const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
switch (vec_size) {
case 4: {
BroadcastKernelForDifferentDimSize<InT,
OutT,
Functor,
kArity,
NumOuts,
4>(ctx, ins, outs, axis, func);
LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 4>(
ctx, ins, outs, func, merge_dims);
break;
}
case 2: {
BroadcastKernelForDifferentDimSize<InT,
OutT,
Functor,
kArity,
NumOuts,
2>(ctx, ins, outs, axis, func);
LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 2>(
ctx, ins, outs, func, merge_dims);
break;
}
case 1: {
BroadcastKernelForDifferentDimSize<InT,
OutT,
Functor,
kArity,
NumOuts,
1>(ctx, ins, outs, axis, func);
LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 1>(
ctx, ins, outs, func, merge_dims);
break;
}
default: {
......
......@@ -82,10 +82,10 @@ struct FastDivMod {
* index of the output data. If the input or output shape is [dim0, dim1], then
* dims must be [dim1, dim0].
*/
template <int kDims>
struct BroadcastConfig {
FastDivMod divmoders[kDims];
FastDivMod divmoders[phi::DDim::kMaxRank];
uint32_t strides[phi::DDim::kMaxRank];
int kDims;
HOSTDEVICE BroadcastConfig() {}
HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
......@@ -109,7 +109,7 @@ struct BroadcastConfig {
std::multiplies<int64_t>())
: strides_in[i];
}
kDims = dim_size;
memcpy(strides, strides_in.data(), kDims * sizeof(uint32_t));
memcpy(divmoders, divmoders_in.data(), kDims * sizeof(FastDivMod));
}
......@@ -246,6 +246,14 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
}
}
/**
 * @brief Set the first NX entries of dst to init_data.
 *
 * @tparam T   Element type.
 * @tparam NX  Number of entries to initialize.
 *
 * @param dst        Destination buffer; entries [0, NX) are written.
 * @param init_data  Value stored into every entry.
 * @param read_lens  Accepted for signature compatibility; not read by this
 *                   overload.
 */
template <typename T, int NX>
__device__ __forceinline__ void Init(T* dst, T init_data, int read_lens) {
#pragma unroll
  for (int slot = 0; slot < NX; ++slot) {
    *(dst + slot) = init_data;
  }
}
/**
* The difference from the above function is that
* it supports different data types of inputs.
......@@ -311,6 +319,38 @@ __device__ __forceinline__ void ReadData(T* dst,
}
}
/**
 * @brief Read NX contiguous elements per thread from src into dst.
 *
 * Each thread's elements start at threadIdx.x * NX in src. When IsBoundary is
 * true, elements are read one by one and only indices below num are read;
 * the remaining dst entries are left untouched. When IsBoundary is false, the
 * data is read through VecType loads of width kVectorSize (4, 2 or 1,
 * chosen from NX) without any bounds check.
 *
 * @tparam T           Element type.
 * @tparam NX          Number of elements handled by each thread.
 * @tparam NY          Not used by this overload.
 * @tparam BlockSize   Not used by this overload.
 * @tparam IsBoundary  Whether the per-element bounds check against num is
 *                     required.
 *
 * @param dst        Per-thread destination buffer with at least NX entries.
 * @param src        Source pointer for the current block.
 * @param num        Number of valid elements reachable from src; only
 *                   consulted when IsBoundary is true.
 * @param read_lens  Accepted for signature compatibility; not read by this
 *                   overload.
 *
 * NOTE(review): the non-boundary path reinterprets src as VecType, so src is
 * presumably required to be aligned for VecType - confirm with callers.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst,
                                         const T* __restrict__ src,
                                         int num,
                                         int read_lens) {
  if (IsBoundary) {  // blockDim.x * NX > num
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        dst[idx] = src[thread_offset + idx];
      }
    }
  } else {  // blockDim.x * NX < num
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    int thread_offset = threadIdx.x * kVectorsPerThread;

    using VecType = details::VectorType<T, kVectorSize>;
    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
    VecType vec_temp[kVectorsPerThread];

    // Load every vector first ...
#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      vec_temp[i] = vec_input[thread_offset + i];
    }
    // ... then unpack once, so that no vec_temp slot is read before it has
    // been loaded and the copy is not repeated for every vector.
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
    }
  }
}
/**
* @brief Read 1D data from global memory to register. The difference
* from the above function is that it supports different data types of inputs.
......@@ -396,17 +436,12 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
* stride_nx: Each read of one element advances by stride_nx elements in the last dim.
* stride_ny: Each read of one element advances by stride_ny elements in the first dim.
*/
template <typename T,
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
T* dst,
const T* __restrict__ src,
uint32_t block_offset,
details::BroadcastConfig<Rank> config,
const details::BroadcastConfig& config,
int total_num_output,
int stride_nx,
int stride_ny) {
......@@ -425,7 +460,8 @@ __device__ __forceinline__ void ReadDataBc(
}
}
#pragma unroll
for (int i = 0; i < Rank; ++i) {
for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
if (i >= config.kDims) break;
auto fast_divmoder = config.divmoders[i].Divmod(index_output);
index_output = fast_divmoder.val[0];
index_src += fast_divmoder.val[1] * config.strides[i];
......@@ -576,6 +612,36 @@ __device__ __forceinline__ void WriteData(T* dst,
}
}
/**
 * @brief Write NX elements per thread from src to dst.
 *
 * Each thread's elements start at threadIdx.x * NX in dst. With IsBoundary
 * set, elements are stored one at a time and only positions below num are
 * written. Otherwise the elements are stored as VecType values of width
 * kVectorSize (4, 2 or 1, chosen from NX) and num is not consulted.
 *
 * @tparam T           Element type.
 * @tparam NX          Number of elements handled by each thread.
 * @tparam NY          Not used by this overload.
 * @tparam BlockSize   Not used by this overload.
 * @tparam IsBoundary  Whether stores must be bounds-checked against num.
 *
 * @param dst        Destination pointer for the current block.
 * @param src        Per-thread source buffer holding NX elements.
 * @param num        Number of valid destination elements; only consulted when
 *                   IsBoundary is true.
 * @param read_lens  Accepted for signature compatibility; not read by this
 *                   overload.
 *
 * NOTE(review): the vectorized path reinterprets both pointers as VecType, so
 * they are presumably required to be aligned for VecType - confirm.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void WriteData(T* dst,
                                          T* __restrict__ src,
                                          int num,
                                          int read_lens) {
  if (!IsBoundary) {
    // Vectorized stores, no bounds check.
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    using VecType = details::VectorType<T, kVectorSize>;

    const int first_vec = threadIdx.x * kVectorsPerThread;
    VecType* packed_src = reinterpret_cast<VecType*>(src);
    VecType* packed_dst = reinterpret_cast<VecType*>(dst);
#pragma unroll
    for (int v = 0; v < kVectorsPerThread; ++v) {
      VecType packed = packed_src[v];
      packed_dst[first_vec + v] = packed;
    }
    return;
  }

  // Element-wise stores guarded against running past num.
  const int first_elem = threadIdx.x * NX;
#pragma unroll
  for (int e = 0; e < NX; ++e) {
    const int pos = first_elem + e;
    if (pos < num) {
      dst[pos] = src[e];
    }
  }
}
/**
* @brief Write 2D data from register to global memory according to Tx type, and
* store it as Ty type.
......@@ -715,18 +781,14 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
* coordinate mapping relationship between output data and input data.
* total_num_output: Total number of original output.
*/
template <typename T,
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
T* dst,
const T* __restrict__ src,
uint32_t block_offset,
details::BroadcastConfig<Rank> config,
int total_num_output) {
const details::BroadcastConfig& config,
int total_num_output,
int read_lens = NX) {
uint32_t thread_offset = block_offset + threadIdx.x * NX;
uint32_t index_src = 0;
......@@ -740,7 +802,8 @@ __device__ __forceinline__ void ReadDataBc(
}
}
#pragma unroll
for (int i = 0; i < Rank; ++i) {
for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
if (i >= config.kDims) break;
auto fast_divmoder = config.divmoders[i].Divmod(index_output);
index_output = fast_divmoder.val[0];
index_src += fast_divmoder.val[1] * config.strides[i];
......
......@@ -32,11 +32,11 @@ struct alignas(sizeof(T) * VecSize) VectorType {
* must be [dim1, dim0].
*/
#pragma pack(4)
template <int kDims>
struct BroadcastConfig {
int strides_in[phi::DDim::kMaxRank];
int strides_out[phi::DDim::kMaxRank];
int in_dim[phi::DDim::kMaxRank];
int kDims;
HOSTDEVICE BroadcastConfig() {}
......@@ -58,6 +58,7 @@ struct BroadcastConfig {
dim_tmp[i] = in_dims[i];
}
kDims = dim_size;
memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int));
memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int));
memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int));
......@@ -328,16 +329,11 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
* stride_nx: Each read of one element advances by stride_nx elements in the last dim.
* stride_ny: Each read of one element advances by stride_ny elements in the first dim.
*/
template <typename T,
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __inline__ void ReadDataBc(T* dst,
const T _global_ptr_* src,
uint32_t block_offset,
details::BroadcastConfig<Rank> config,
details::BroadcastConfig config,
int total_num_output,
int stride_nx,
int stride_ny) {
......@@ -643,18 +639,12 @@ __device__ __inline__ void Init(T* dst, T* init_data, int num) {
* coordinate mapping relationship between output data and input data.
* total_num_output: Total number of original output.
*/
template <typename T,
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
__device__ __inline__ void ReadDataBc(
T* dst,
const T _global_ptr_* src,
uint32_t block_offset,
const details::BroadcastConfig<Rank>& config,
int total_num_output) {
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __inline__ void ReadDataBc(T* dst,
const T _global_ptr_* src,
uint32_t block_offset,
const details::BroadcastConfig& config,
int total_num_output) {
int thread_offset = block_offset + core_id() * NX;
int index_src = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册