Unverified commit 835a1888 authored by niuliling123, committed by GitHub

cherry-pick 42645 (#43205)

Remove the per-rank instantiation in the broadcast functions and the extra Elementwise dispatch to reduce compile time.

Adapted from PR #42645 on the develop branch; because develop has diverged substantially from the release branch, a direct cherry-pick was not possible, so this PR is resubmitted against release/2.3.

Instantiating the broadcast path once per rank expands a large number of underlying templates, which makes reduce_sum_grad_kernel.cu.o excessively large. With this change, both the .o size and the compile time drop.

Parent commit: 40a7e0ad
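For context, the core refactor is to move the tensor rank from a compile-time template parameter to a runtime field, so each dependent kernel is instantiated once instead of once per rank (the diff below removes exactly such an 8-way rank switch). A minimal sketch of the before/after pattern, with hypothetical names (`ConfigOld`, `ConfigNew`, `LinearIndex`) standing in for `BroadcastConfig` and its callers:

```cpp
// Sketch only: hypothetical names illustrating the refactor in this PR.

// Before: rank is a template parameter, so every caller such as
// LaunchBroadcastKernel<..., Rank> is stamped out for Rank = 1..8.
template <int Rank>
struct ConfigOld {
  int strides[Rank];
};

// After: one non-templated struct with a fixed-capacity array plus a
// runtime rank field (mirroring the kDims member added to BroadcastConfig).
struct ConfigNew {
  static constexpr int kMaxRank = 9;  // stand-in for phi::DDim::kMaxRank
  int strides[kMaxRank];
  int dims = 0;  // actual rank, known only at runtime
};

// Only one copy of this function is compiled regardless of rank; the loop
// keeps a compile-time bound and breaks early at the runtime rank.
inline int LinearIndex(const ConfigNew& cfg, const int* coords) {
  int idx = 0;
  for (int i = 0; i < ConfigNew::kMaxRank; ++i) {
    if (i >= cfg.dims) break;
    idx += coords[i] * cfg.strides[i];
  }
  return idx;
}
```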
@@ -51,8 +51,7 @@ template <typename InT, typename OutT, int ShapeSize, int VecSize,
 __global__ void BroadcastKernelBinary(
     const InT* __restrict__ in0, const InT* __restrict__ in1, OutT* out,
     phi::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
-    phi::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
-        configlists,
+    phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists,
     int main_tid, int tail_tid, Functor func) {
   int fix = blockIdx.x * blockDim.x * VecSize;
   int num = tail_tid;
@@ -65,14 +64,14 @@ __global__ void BroadcastKernelBinary(
   // load in0
   if (use_broadcast[0]) {
-    kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
+    kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1>(
         arg0, in0, fix, configlists[0], numel);
   } else {
     kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg0, in0 + fix, num);
   }
   // load in1
   if (use_broadcast[1]) {
-    kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
+    kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1>(
         arg1, in1, fix, configlists[1], numel);
   } else {
     kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg1, in1 + fix, num);
@@ -104,7 +103,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
   int main_tid = numel / (data_per_thread * vec_size * threads);
   int tail_tid = numel % (data_per_thread * vec_size * threads);

-  phi::Array<kps::details::BroadcastConfig<2>, MAX_INPUT_NUM> configlists;
+  phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists;
   phi::Array<bool, MAX_INPUT_NUM> use_broadcast;

   use_broadcast[0] = false;
@@ -115,7 +114,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
   // Here, dims are transposed due to the logic in BroadcastConfig.
   std::vector<int64_t> input1_dims = {n, 1};
   std::vector<int64_t> out_dims = {n, m};
-  configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2);
+  configlists[1] = kps::details::BroadcastConfig(out_dims, input1_dims, 2);

   auto func = AddFunctor<T>();
   auto stream = ctx.stream();
@@ -185,19 +185,19 @@ struct DimensionsTransform {
   }
 };

-template <typename T, int VecSize, int Rank, bool IsBoundary = false>
+template <typename T, int VecSize, bool IsBoundary = false>
 __device__ __forceinline__ void LoadData(
     T *dst,
     const _ptr_ T *src,
     uint32_t block_offset,
-    const kps::details::BroadcastConfig<Rank> &config,
+    const kps::details::BroadcastConfig &config,
     int numel,
     int num,
     int need_broadcast) {
   // numel : whole num of output
   // num: how many data will be deal with in this time
   if (need_broadcast) {
-    kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
+    kps::ReadDataBc<T, VecSize, 1, 1, IsBoundary>(
         dst, src, block_offset, config, numel);
   } else {
     kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
@@ -210,14 +210,13 @@ template <typename InT,
           int Arity,
           int NumOuts,
           int VecSize,
-          int Rank,
           bool IsBoundary = false>
 __device__ void VectorizedBroadcastKernelImpl(
     const phi::Array<const _ptr_ InT *__restrict__, Arity> &ins,
     phi::Array<_ptr_ OutT *, NumOuts> outs,
     const phi::Array<int, Arity> &use_broadcast,
     uint32_t numel,
-    const phi::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
+    const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
     int num,
     int block_offset,
     Functor func) {
@@ -227,13 +226,13 @@ __device__ void VectorizedBroadcastKernelImpl(
 #pragma unroll
   for (int i = 0; i < Arity; i++) {
     kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
-    LoadData<InT, VecSize, Rank, IsBoundary>(args[i],
-                                             ins[i],
-                                             block_offset,
-                                             configs[i],
-                                             numel,
-                                             num,
-                                             use_broadcast[i]);
+    LoadData<InT, VecSize, IsBoundary>(args[i],
+                                       ins[i],
+                                       block_offset,
+                                       configs[i],
+                                       numel,
+                                       num,
+                                       use_broadcast[i]);
   }
   constexpr bool kCallElementwiseAny =
       paddle::platform::FunctionTraits<Functor>::has_pointer_args;
@@ -254,14 +253,13 @@ template <typename InT,
           typename Functor,
           int Arity,
           int NumOuts,
-          int VecSize,
-          int Rank>
+          int VecSize>
 __global__ void VectorizedBroadcastKernel(
     phi::Array<const _ptr_ InT *__restrict__, Arity> ins,
     phi::Array<_ptr_ OutT *, NumOuts> outs,
     phi::Array<int, Arity> use_broadcast,
     uint32_t numel,
-    phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
+    phi::Array<kps::details::BroadcastConfig, Arity> configs,
     int main_offset,
     int tail_tid,
     Functor func) {
@@ -276,7 +274,6 @@ __global__ void VectorizedBroadcastKernel(
                                   Arity,
                                   NumOuts,
                                   VecSize,
-                                  Rank,
                                   false>(ins,
                                          outs,
                                          use_broadcast,
@@ -294,7 +291,6 @@ __global__ void VectorizedBroadcastKernel(
                                   Arity,
                                   NumOuts,
                                   VecSize,
-                                  Rank,
                                   true>(
         ins, outs, use_broadcast, numel, configs, num, block_offset, func);
   }
@@ -306,7 +302,6 @@ __global__ void VectorizedBroadcastKernel(
                                   Arity,
                                   NumOuts,
                                   VecSize,
-                                  Rank,
                                   false>(ins,
                                          outs,
                                          use_broadcast,
@@ -322,7 +317,6 @@ __global__ void VectorizedBroadcastKernel(
                                   Arity,
                                   NumOuts,
                                   VecSize,
-                                  Rank,
                                   true>(
         ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
   }
@@ -334,15 +328,14 @@ template <typename InT,
           typename Functor,
           int Arity,
           int NumOuts,
-          int VecSize,
-          int Rank>
+          int VecSize>
 void LaunchBroadcastKernel(const KPDevice &ctx,
                            const std::vector<const DenseTensor *> &ins,
                            std::vector<DenseTensor *> *outs,
                            Functor func,
                            DimensionsTransform merge_dims) {
   int numel = (*outs)[0]->numel();
-  phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
+  phi::Array<kps::details::BroadcastConfig, Arity> configs;
   phi::Array<int, Arity> use_broadcast;
   phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
   phi::Array<_ptr_ OutT *, NumOuts> outs_data;
@@ -358,7 +351,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
       // get the broadcast config,
       // if data shape is[m, n], then you should set data_dim = {n, m}
      // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
-      configs[i] = kps::details::BroadcastConfig<Rank>(
+      configs[i] = kps::details::BroadcastConfig(
           merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
     }
   }
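The comment in the hunk above is easy to miss: BroadcastConfig takes its dimension vectors in reversed (fastest-varying first) order. A minimal hedged sketch of constructing one by hand with the release/2.3 signature shown in this diff, assuming the kernel-primitives headers are included; the shapes come from the example in the comment and the hypothetical input shape is for illustration only:

```cpp
#include <cstdint>
#include <vector>

// Sketch: build a config for an output of logical shape [3, 45, 1]
// broadcast from a hypothetical input of shape [1, 45, 1]. The vectors
// passed to BroadcastConfig are written fastest dim first.
void BuildExampleConfig() {
  std::vector<int64_t> out_dims = {1, 45, 3};  // reversed [3, 45, 1]
  std::vector<int64_t> in_dims = {1, 45, 1};   // reversed [1, 45, 1]
  int dim_size = 3;                            // runtime rank, stored in kDims
  auto config = kps::details::BroadcastConfig(out_dims, in_dims, dim_size);
  (void)config;
}
```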
@@ -374,15 +367,14 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
                             Functor,
                             Arity,
                             NumOuts,
-                            VecSize,
-                            Rank><<<blocks, threads, stream>>>(ins_data,
-                                                               outs_data,
-                                                               use_broadcast,
-                                                               numel,
-                                                               configs,
-                                                               main_offset,
-                                                               tail_tid,
-                                                               func);
+                            VecSize><<<blocks, threads, stream>>>(ins_data,
+                                                                  outs_data,
+                                                                  use_broadcast,
+                                                                  numel,
+                                                                  configs,
+                                                                  main_offset,
+                                                                  tail_tid,
+                                                                  func);
 #else
   const int threads = 256;
   int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
@@ -394,58 +386,18 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
                             Functor,
                             Arity,
                             NumOuts,
-                            VecSize,
-                            Rank><<<blocks, threads, 0, stream>>>(ins_data,
-                                                                  outs_data,
-                                                                  use_broadcast,
-                                                                  numel,
-                                                                  configs,
-                                                                  main_offset,
-                                                                  tail_tid,
-                                                                  func);
+                            VecSize><<<blocks, threads, 0, stream>>>(
+      ins_data,
+      outs_data,
+      use_broadcast,
+      numel,
+      configs,
+      main_offset,
+      tail_tid,
+      func);
 #endif
 }

-template <typename InT,
-          typename OutT,
-          typename Functor,
-          int Arity,
-          int NumOuts,
-          int VecSize>
-void BroadcastKernelForDifferentDimSize(
-    const KPDevice &ctx,
-    const std::vector<const DenseTensor *> &ins,
-    std::vector<DenseTensor *> *outs,
-    int axis,
-    Functor func) {
-  const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
-
-#define CALL_BROADCAST_FOR_DIM_SIZE(rank)                                     \
-  case rank: {                                                                \
-    LaunchBroadcastKernel<InT, OutT, Functor, Arity, NumOuts, VecSize, rank>( \
-        ctx, ins, outs, func, merge_dims);                                    \
-  } break;
-
-  switch (merge_dims.dim_size) {
-    CALL_BROADCAST_FOR_DIM_SIZE(1);
-    CALL_BROADCAST_FOR_DIM_SIZE(2);
-    CALL_BROADCAST_FOR_DIM_SIZE(3);
-    CALL_BROADCAST_FOR_DIM_SIZE(4);
-    CALL_BROADCAST_FOR_DIM_SIZE(5);
-    CALL_BROADCAST_FOR_DIM_SIZE(6);
-    CALL_BROADCAST_FOR_DIM_SIZE(7);
-    CALL_BROADCAST_FOR_DIM_SIZE(8);
-    default: {
-      PADDLE_THROW(phi::errors::InvalidArgument(
-          "The maximum dimension of input tensor is expected to be less than "
-          "%d, but recieved %d.",
-          merge_dims.dim_size,
-          phi::DDim::kMaxRank));
-    }
-  }
-#undef CALL_BROADCAST_FOR_DIM_SIZE
-}
-
 template <ElementwiseType ET,
           typename InT,
           typename OutT,
@@ -506,33 +458,22 @@ void BroadcastKernelForDifferentVecSize(
                        : in_vec_size;
   }
   int vec_size = std::min(out_vec_size, in_vec_size);
+  const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);

   switch (vec_size) {
     case 4: {
-      BroadcastKernelForDifferentDimSize<InT,
-                                         OutT,
-                                         Functor,
-                                         kArity,
-                                         NumOuts,
-                                         4>(ctx, ins, outs, axis, func);
+      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 4>(
+          ctx, ins, outs, func, merge_dims);
       break;
     }
     case 2: {
-      BroadcastKernelForDifferentDimSize<InT,
-                                         OutT,
-                                         Functor,
-                                         kArity,
-                                         NumOuts,
-                                         2>(ctx, ins, outs, axis, func);
+      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 2>(
+          ctx, ins, outs, func, merge_dims);
       break;
     }
     case 1: {
-      BroadcastKernelForDifferentDimSize<InT,
-                                         OutT,
-                                         Functor,
-                                         kArity,
-                                         NumOuts,
-                                         1>(ctx, ins, outs, axis, func);
+      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 1>(
+          ctx, ins, outs, func, merge_dims);
       break;
     }
     default: {
@@ -82,10 +82,10 @@ struct FastDivMod {
  * index of the output data. if input or output shape is [dim0, dim1] then dims
  * must be [dim1, dim0].
  */
-template <int kDims>
 struct BroadcastConfig {
-  FastDivMod divmoders[kDims];
+  FastDivMod divmoders[phi::DDim::kMaxRank];
   uint32_t strides[phi::DDim::kMaxRank];
+  int kDims;
   HOSTDEVICE BroadcastConfig() {}

   HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
@@ -109,7 +109,7 @@ struct BroadcastConfig {
                           std::multiplies<int64_t>())
                     : strides_in[i];
     }
+    kDims = dim_size;
     memcpy(strides, strides_in.data(), kDims * sizeof(uint32_t));
     memcpy(divmoders, divmoders_in.data(), kDims * sizeof(FastDivMod));
   }
@@ -246,6 +246,14 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
   }
 }

+template <typename T, int NX>
+__device__ __forceinline__ void Init(T* dst, T init_data, int read_lens) {
+#pragma unroll
+  for (int i = 0; i < NX; i++) {
+    dst[i] = init_data;
+  }
+}
+
 /**
  * The difference from the above function is that
  * it supports different data types of inputs.
@@ -311,6 +319,38 @@ __device__ __forceinline__ void ReadData(T* dst,
   }
 }

+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+__device__ __forceinline__ void ReadData(T* dst,
+                                         const T* __restrict__ src,
+                                         int num,
+                                         int read_lens) {
+  if (IsBoundary) {  // blockDim.x * NX > num
+    int thread_offset = threadIdx.x * NX;
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      if (idx + thread_offset < num) {
+        dst[idx] = src[thread_offset + idx];
+      }
+    }
+  } else {  // blockDim,x * NX < num
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    constexpr int kVectorsPerThread = NX / kVectorSize;
+    int thread_offset = threadIdx.x * kVectorsPerThread;
+
+    using VecType = details::VectorType<T, kVectorSize>;
+    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
+    VecType vec_temp[kVectorsPerThread];
+
+#pragma unroll
+    for (int i = 0; i < kVectorsPerThread; ++i) {
+      vec_temp[i] = vec_input[thread_offset + i];
+#pragma unroll
+      for (int idx = 0; idx < NX; ++idx) {
+        dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
+      }
+    }
+  }
+}
+
 /**
  * @brief Read 1D data from global memory to register. The difference
  * from the above function is that it supports different data types of inputs.
@@ -396,17 +436,12 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
 * stride_nx: Each read one element stride stride_nx elements in the last dim.
 * stride_ny: Each read one element stride stride_ny elements in the first dim.
 */
-template <typename T,
-          int NX,
-          int NY,
-          int BlockSize,
-          int Rank,
-          bool IsBoundary = false>
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
 __device__ __forceinline__ void ReadDataBc(
     T* dst,
     const T* __restrict__ src,
     uint32_t block_offset,
-    details::BroadcastConfig<Rank> config,
+    const details::BroadcastConfig& config,
     int total_num_output,
     int stride_nx,
     int stride_ny) {
@@ -425,7 +460,8 @@ __device__ __forceinline__ void ReadDataBc(
     }
   }
 #pragma unroll
-  for (int i = 0; i < Rank; ++i) {
+  for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
+    if (i >= config.kDims) break;
     auto fast_divmoder = config.divmoders[i].Divmod(index_output);
     index_output = fast_divmoder.val[0];
     index_src += fast_divmoder.val[1] * config.strides[i];
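This loop rewrite is where the runtime rank pays for itself: the bound is now the compile-time constant phi::DDim::kMaxRank, so `#pragma unroll` remains legal, and the break at config.kDims stops at the real rank. A standalone sketch of the same pattern, with plain `%` and `/` substituted for the FastDivMod fast division used by the real kernel:

```cpp
// Sketch: fixed compile-time bound + early break, as in ReadDataBc above.
// kMaxRank stands in for phi::DDim::kMaxRank; names are hypothetical.
constexpr int kMaxRank = 9;

__device__ __forceinline__ int SrcIndex(const int* strides,
                                        const int* out_dims,
                                        int dims,  // runtime rank (kDims)
                                        int index_output) {
  int index_src = 0;
#pragma unroll  // valid: the trip count is a compile-time constant
  for (int i = 0; i < kMaxRank; ++i) {
    if (i >= dims) break;  // exit at the actual rank
    index_src += (index_output % out_dims[i]) * strides[i];
    index_output /= out_dims[i];
  }
  return index_src;
}
```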
@@ -576,6 +612,36 @@ __device__ __forceinline__ void WriteData(T* dst,
   }
 }

+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+__device__ __forceinline__ void WriteData(T* dst,
+                                          T* __restrict__ src,
+                                          int num,
+                                          int read_lens) {
+  if (IsBoundary) {
+    int thread_offset = threadIdx.x * NX;
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      if ((thread_offset + idx) < num) {
+        dst[thread_offset + idx] = src[idx];
+      }
+    }
+  } else {
+    // Vector type
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    constexpr int kVectorsPerThread = NX / kVectorSize;
+
+    int thread_offset = threadIdx.x * kVectorsPerThread;
+    using VecType = details::VectorType<T, kVectorSize>;
+    VecType* vec_dst = reinterpret_cast<VecType*>(dst);
+    VecType vec_temp[kVectorsPerThread];
+#pragma unroll
+    for (int idx = 0; idx < kVectorsPerThread; ++idx) {
+      vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
+      vec_dst[thread_offset + idx] = vec_temp[idx];
+    }
+  }
+}
+
 /**
  * @brief Write 2D data from register to global memory according to Tx type, and
 * store it as Ty type.
@@ -715,18 +781,14 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
 * coordinate mapping relationship between output data and input data.
 * total_num_output: Total number of original output.
 */
-template <typename T,
-          int NX,
-          int NY,
-          int BlockSize,
-          int Rank,
-          bool IsBoundary = false>
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
 __device__ __forceinline__ void ReadDataBc(
     T* dst,
     const T* __restrict__ src,
     uint32_t block_offset,
-    details::BroadcastConfig<Rank> config,
-    int total_num_output) {
+    const details::BroadcastConfig& config,
+    int total_num_output,
+    int read_lens = NX) {
   uint32_t thread_offset = block_offset + threadIdx.x * NX;
   uint32_t index_src = 0;
@@ -740,7 +802,8 @@ __device__ __forceinline__ void ReadDataBc(
     }
   }
 #pragma unroll
-  for (int i = 0; i < Rank; ++i) {
+  for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
+    if (i >= config.kDims) break;
     auto fast_divmoder = config.divmoders[i].Divmod(index_output);
     index_output = fast_divmoder.val[0];
     index_src += fast_divmoder.val[1] * config.strides[i];
@@ -32,11 +32,11 @@ struct alignas(sizeof(T) * VecSize) VectorType {
 * must be [dim1, dim0].
 */
 #pragma pack(4)
-template <int kDims>
 struct BroadcastConfig {
   int strides_in[phi::DDim::kMaxRank];
   int strides_out[phi::DDim::kMaxRank];
   int in_dim[phi::DDim::kMaxRank];
+  int kDims;

   HOSTDEVICE BroadcastConfig() {}
@@ -58,6 +58,7 @@ struct BroadcastConfig {
       dim_tmp[i] = in_dims[i];
     }
+    kDims = dim_size;
     memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int));
     memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int));
     memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int));
@@ -328,16 +329,11 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
 * stride_nx: Each read one element stride stride_nx elements in the last dim.
 * stride_ny: Each read one element stride stride_ny elements in the first dim.
 */
-template <typename T,
-          int NX,
-          int NY,
-          int BlockSize,
-          int Rank,
-          bool IsBoundary = false>
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
 __device__ __inline__ void ReadDataBc(T* dst,
                                       const T _global_ptr_* src,
                                       uint32_t block_offset,
-                                      details::BroadcastConfig<Rank> config,
+                                      details::BroadcastConfig config,
                                       int total_num_output,
                                       int stride_nx,
                                       int stride_ny) {
@@ -643,18 +639,12 @@ __device__ __inline__ void Init(T* dst, T* init_data, int num) {
 * coordinate mapping relationship between output data and input data.
 * total_num_output: Total number of original output.
 */
-template <typename T,
-          int NX,
-          int NY,
-          int BlockSize,
-          int Rank,
-          bool IsBoundary = false>
-__device__ __inline__ void ReadDataBc(
-    T* dst,
-    const T _global_ptr_* src,
-    uint32_t block_offset,
-    const details::BroadcastConfig<Rank>& config,
-    int total_num_output) {
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+__device__ __inline__ void ReadDataBc(T* dst,
+                                      const T _global_ptr_* src,
+                                      uint32_t block_offset,
+                                      const details::BroadcastConfig& config,
+                                      int total_num_output) {
   int thread_offset = block_offset + core_id() * NX;
   int index_src = 0;