未验证 提交 8501fb00 编写于 作者: N niuliling123 提交者: GitHub

delete rank switch in broadcast_function.h for compile (#42645)

上级 8ffebb5a
......@@ -51,8 +51,7 @@ template <typename InT, typename OutT, int ShapeSize, int VecSize,
__global__ void BroadcastKernelBinary(
const InT* __restrict__ in0, const InT* __restrict__ in1, OutT* out,
phi::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
phi::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
configlists,
phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists,
int main_tid, int tail_tid, Functor func) {
int fix = blockIdx.x * blockDim.x * VecSize;
int num = tail_tid;
......@@ -65,14 +64,14 @@ __global__ void BroadcastKernelBinary(
// load in0
if (use_broadcast[0]) {
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1>(
arg0, in0, fix, configlists[0], numel);
} else {
kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg0, in0 + fix, num);
}
// load in1
if (use_broadcast[1]) {
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1>(
arg1, in1, fix, configlists[1], numel);
} else {
kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg1, in1 + fix, num);
......@@ -104,7 +103,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
int main_tid = numel / (data_per_thread * vec_size * threads);
int tail_tid = numel % (data_per_thread * vec_size * threads);
phi::Array<kps::details::BroadcastConfig<2>, MAX_INPUT_NUM> configlists;
phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists;
phi::Array<bool, MAX_INPUT_NUM> use_broadcast;
use_broadcast[0] = false;
......@@ -115,7 +114,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
// Here, dims are transposed due to the logic in BroadcastConfig.
std::vector<int64_t> input1_dims = {n, 1};
std::vector<int64_t> out_dims = {n, m};
configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2);
configlists[1] = kps::details::BroadcastConfig(out_dims, input1_dims, 2);
auto func = AddFunctor<T>();
auto stream = ctx.stream();
......
......@@ -82,10 +82,10 @@ struct FastDivMod {
* index of the output data. If the input or output shape is [dim0, dim1], then
* dims must be [dim1, dim0].
*/
template <int kDims>
struct BroadcastConfig {
FastDivMod divmoders[kDims];
FastDivMod divmoders[phi::DDim::kMaxRank];
uint32_t strides[phi::DDim::kMaxRank];
int kDims;
HOSTDEVICE BroadcastConfig() {}
HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
......@@ -109,7 +109,7 @@ struct BroadcastConfig {
std::multiplies<int64_t>())
: strides_in[i];
}
kDims = dim_size;
memcpy(strides, strides_in.data(), kDims * sizeof(uint32_t));
memcpy(divmoders, divmoders_in.data(), kDims * sizeof(FastDivMod));
}
......@@ -436,17 +436,12 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
* stride_nx: Each read of one element strides stride_nx elements in the last dim.
* stride_ny: Each read of one element strides stride_ny elements in the first dim.
*/
template <typename T,
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
T* dst,
const T* __restrict__ src,
uint32_t block_offset,
details::BroadcastConfig<Rank> config,
const details::BroadcastConfig& config,
int total_num_output,
int stride_nx,
int stride_ny) {
......@@ -465,7 +460,8 @@ __device__ __forceinline__ void ReadDataBc(
}
}
#pragma unroll
for (int i = 0; i < Rank; ++i) {
for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
if (i >= config.kDims) break;
auto fast_divmoder = config.divmoders[i].Divmod(index_output);
index_output = fast_divmoder.val[0];
index_src += fast_divmoder.val[1] * config.strides[i];
......@@ -785,53 +781,14 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
* coordinate mapping relationship between output data and input data.
* total_num_output: Total number of original output.
*/
template <typename T,
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
T* dst,
const T* __restrict__ src,
uint32_t block_offset,
details::BroadcastConfig<Rank> config,
int total_num_output) {
uint32_t thread_offset = block_offset + threadIdx.x * NX;
uint32_t index_src = 0;
#pragma unroll
for (uint32_t nx = 0; nx < NX; ++nx) {
uint32_t index_output = thread_offset + nx;
index_src = 0;
if (IsBoundary) {
if (index_output >= total_num_output) {
break;
}
}
#pragma unroll
for (int i = 0; i < Rank; ++i) {
auto fast_divmoder = config.divmoders[i].Divmod(index_output);
index_output = fast_divmoder.val[0];
index_src += fast_divmoder.val[1] * config.strides[i];
}
dst[nx] = src[index_src];
}
}
template <typename T,
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
T* dst,
const T* __restrict__ src,
uint32_t block_offset,
details::BroadcastConfig<Rank> config,
const details::BroadcastConfig& config,
int total_num_output,
int read_lens) {
int read_lens = NX) {
uint32_t thread_offset = block_offset + threadIdx.x * NX;
uint32_t index_src = 0;
......@@ -845,7 +802,8 @@ __device__ __forceinline__ void ReadDataBc(
}
}
#pragma unroll
for (int i = 0; i < Rank; ++i) {
for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
if (i >= config.kDims) break;
auto fast_divmoder = config.divmoders[i].Divmod(index_output);
index_output = fast_divmoder.val[0];
index_src += fast_divmoder.val[1] * config.strides[i];
......@@ -853,6 +811,7 @@ __device__ __forceinline__ void ReadDataBc(
dst[nx] = src[index_src];
}
}
/**
* @brief Initialize register with data index.
*
......
......@@ -65,7 +65,6 @@ struct alignas(sizeof(T) * VecSize) VectorType {
* must be [dim1, dim0].
*/
#pragma pack(4)
template <int kDims>
struct BroadcastConfig {
int strides_in[phi::DDim::kMaxRank];
int strides_out[phi::DDim::kMaxRank];
......@@ -78,7 +77,7 @@ struct BroadcastConfig {
int n = 1;
int k = 1;
int buf_len = 0;
int kDims;
HOSTDEVICE BroadcastConfig() {}
HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
......@@ -99,7 +98,7 @@ struct BroadcastConfig {
for (int i = 0; i < dim_size; i++) {
dim_tmp[i] = in_dims[i];
}
kDims = dim_size;
memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int));
memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int));
memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int));
......@@ -551,7 +550,6 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
* NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x core_num(), boundary judgment is required to avoid memory access
......@@ -567,16 +565,11 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
* stride_nx: Each read of one element strides stride_nx elements in the last dim.
* stride_ny: Each read of one element strides stride_ny elements in the first dim.
*/
template <typename T,
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __inline__ void ReadDataBc(T* dst,
const T _global_ptr_* src,
uint32_t block_offset,
details::BroadcastConfig<Rank> config,
const details::BroadcastConfig& config,
int total_num_output,
int stride_nx,
int stride_ny) {
......@@ -882,60 +875,6 @@ __device__ __inline__ void Init(T* dst, T* init_data, int num) {
}
}
/**
* @brief Read 1D data from global memory to register with broadcast form.
*
* @template parameters
* T: The type of data stored in the global memory.
* NX: The number of data continuously loaded by each thread.
* NY: The number of data rows loaded by each thread; only NY = 1 is supported.
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x core_num(), boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* dst: The register pointer of the thread, the size is NX * NY.
* src: The original input data pointer of kernel.
* block_offset: The data offset of this block, core_num() * blockIdx.x * NX;
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data.
* total_num_output: Total number of original output.
*/
template <typename T,
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
__device__ __inline__ void ReadDataBc(
T* dst,
const T _global_ptr_* src,
uint32_t block_offset,
const details::BroadcastConfig<Rank>& config,
int total_num_output) {
int thread_offset = block_offset + core_id() * NX;
int index_src = 0;
__local__ T in_temp;
#pragma unroll
for (int nx = 0; nx < NX; ++nx) {
int index_output = thread_offset + nx;
index_src = 0;
if (IsBoundary) {
if (index_output >= total_num_output) {
break;
}
}
index_src = config(index_output);
GM2LM(src + index_src, &in_temp, sizeof(T));
dst[nx] = in_temp;
}
}
/**
* @brief Read data from global memory to local memory with broadcast
* {m, 1, k}-> {m, n, k} form.
......@@ -952,12 +891,12 @@ __device__ __inline__ void ReadDataBc(
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
*/
template <typename T, int Rank>
template <typename T>
__device__ __inline__ void ReadDataBcM1kMnk(
T* dst,
const T _global_ptr_* src,
int thread_offset,
const details::BroadcastConfig<Rank>& config,
const details::BroadcastConfig& config,
int read_lens) {
int index_output = thread_offset;
int index_base = config(index_output);
......@@ -999,12 +938,12 @@ __device__ __inline__ void ReadDataBcM1kMnk(
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
*/
template <typename T, int Rank>
template <typename T>
__device__ __inline__ void ReadDataBcM1Mn(
T* dst,
const T _global_ptr_* src,
int thread_offset,
const details::BroadcastConfig<Rank>& config,
const details::BroadcastConfig& config,
int read_lens) {
int index_output = thread_offset;
int index_base = config(index_output);
......@@ -1027,7 +966,6 @@ __device__ __inline__ void ReadDataBcM1Mn(
*
* @template parameters
* T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
*
* @param:
* dst: The register pointer of the thread, the size is NX.
......@@ -1037,12 +975,12 @@ __device__ __inline__ void ReadDataBcM1Mn(
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
*/
template <typename T, int Rank>
template <typename T>
__device__ __inline__ void ReadDataBc1NMn(
T* dst,
const T _global_ptr_* src,
int thread_offset,
const details::BroadcastConfig<Rank>& config,
const details::BroadcastConfig& config,
int read_lens) {
int index_output = thread_offset;
int index_base = config(index_output);
......@@ -1075,7 +1013,6 @@ __device__ __inline__ void ReadDataBc1NMn(
*
* @template parameters
* T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
*
* @param:
* dst: The register pointer of the thread, the size is NX.
......@@ -1085,12 +1022,12 @@ __device__ __inline__ void ReadDataBc1NMn(
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
*/
template <typename T, int Rank>
template <typename T>
__device__ __inline__ void ReadDataBc1N1Mnk(
T* dst,
const T _global_ptr_* src,
int thread_offset,
const details::BroadcastConfig<Rank>& config,
const details::BroadcastConfig& config,
int read_lens) {
int index_output = thread_offset;
int index_base = config(index_output);
......@@ -1130,7 +1067,6 @@ __device__ __inline__ void ReadDataBc1N1Mnk(
*
* @template parameters
* T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
*
* @param:
* dst: The register pointer of the thread, the size is NX.
......@@ -1140,13 +1076,12 @@ __device__ __inline__ void ReadDataBc1N1Mnk(
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
*/
template <typename T, int Rank>
__device__ __inline__ void ReadDataBc1N(
T* dst,
const T _global_ptr_* src,
int thread_offset,
const details::BroadcastConfig<Rank>& config,
int read_lens) {
template <typename T>
__device__ __inline__ void ReadDataBc1N(T* dst,
const T _global_ptr_* src,
int thread_offset,
const details::BroadcastConfig& config,
int read_lens) {
int index_output = thread_offset;
int index_base = config(index_output);
T in_temp;
......@@ -1174,12 +1109,12 @@ __device__ __inline__ void ReadDataBc1N(
* total_num_output: Total number of original output.
* read_lens: The number of data continuously loaded by each thread.
*/
template <typename T, int Rank, bool IsBoundary = false>
template <typename T, bool IsBoundary = false>
__device__ __inline__ void ReadDataBcCanNotCmp(
T* dst,
const T _global_ptr_* src,
int thread_offset,
const details::BroadcastConfig<Rank>& config,
const details::BroadcastConfig& config,
int total_num_output,
int read_lens) {
int index_output = thread_offset;
......@@ -1215,7 +1150,6 @@ __device__ __inline__ void ReadDataBcCanNotCmp(
* NY: The number of data rows loaded by each thread; only NY = 1 is supported.
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x core_num(), boundary judgment is required to avoid memory access
......@@ -1230,33 +1164,27 @@ __device__ __inline__ void ReadDataBcCanNotCmp(
* read_lens: The number of data continuously loaded by each thread.
* total_num_output: Total number of original output.
*/
template <typename T,
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
__device__ __inline__ void ReadDataBc(
T* dst,
const T _global_ptr_* src,
uint32_t block_offset,
const details::BroadcastConfig<Rank>& config,
int total_num_output,
int read_lens) {
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __inline__ void ReadDataBc(T* dst,
const T _global_ptr_* src,
uint32_t block_offset,
const details::BroadcastConfig& config,
int total_num_output,
int read_lens) {
int thread_offset = block_offset + core_id() * read_lens;
if (config.cmp_type == details::OptType::MNK_M1K) {
ReadDataBcM1kMnk<T, Rank>(dst, src, thread_offset, config, read_lens);
ReadDataBcM1kMnk<T>(dst, src, thread_offset, config, read_lens);
} else if (config.cmp_type == details::OptType::N_1) {
ReadDataBc1N<T, Rank>(dst, src, thread_offset, config, read_lens);
ReadDataBc1N<T>(dst, src, thread_offset, config, read_lens);
} else if (config.cmp_type == details::OptType::MN_M) {
ReadDataBcM1Mn<T, Rank>(dst, src, thread_offset, config, read_lens);
ReadDataBcM1Mn<T>(dst, src, thread_offset, config, read_lens);
} else if (config.cmp_type == details::OptType::MN_N) {
ReadDataBc1NMn<T, Rank>(dst, src, thread_offset, config, read_lens);
ReadDataBc1NMn<T>(dst, src, thread_offset, config, read_lens);
} else if (config.cmp_type == details::OptType::MNK_1N1) {
ReadDataBc1N1Mnk<T, Rank>(dst, src, thread_offset, config, read_lens);
ReadDataBc1N1Mnk<T>(dst, src, thread_offset, config, read_lens);
} else {
ReadDataBcCanNotCmp<T, Rank, IsBoundary>(
ReadDataBcCanNotCmp<T, IsBoundary>(
dst, src, thread_offset, config, total_num_output, read_lens);
}
}
......
......@@ -40,7 +40,9 @@
#define GRID_NUM_X cluster_num()
#define GRID_NUM_Y 0
#define GRID_NUM_Z 0
#define VecSizeL 512
#define VecSizeM 256
#define VecSizeS 128
#else
#define KPStream gpuStream_t
......@@ -64,6 +66,9 @@
#define GRID_NUM_Y gridDim.y
#define GRID_NUM_Z gridDim.z
#define VecSizeL 4
#define VecSizeM 2
#define VecSizeS 1
#endif
// include file
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册