未验证 提交 8501fb00 编写于 作者: N niuliling123 提交者: GitHub

delete rank switch in broadcast_function.h for compile (#42645)

上级 8ffebb5a
...@@ -51,8 +51,7 @@ template <typename InT, typename OutT, int ShapeSize, int VecSize, ...@@ -51,8 +51,7 @@ template <typename InT, typename OutT, int ShapeSize, int VecSize,
__global__ void BroadcastKernelBinary( __global__ void BroadcastKernelBinary(
const InT* __restrict__ in0, const InT* __restrict__ in1, OutT* out, const InT* __restrict__ in0, const InT* __restrict__ in1, OutT* out,
phi::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel, phi::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
phi::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM> phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists,
configlists,
int main_tid, int tail_tid, Functor func) { int main_tid, int tail_tid, Functor func) {
int fix = blockIdx.x * blockDim.x * VecSize; int fix = blockIdx.x * blockDim.x * VecSize;
int num = tail_tid; int num = tail_tid;
...@@ -65,14 +64,14 @@ __global__ void BroadcastKernelBinary( ...@@ -65,14 +64,14 @@ __global__ void BroadcastKernelBinary(
// load in0 // load in0
if (use_broadcast[0]) { if (use_broadcast[0]) {
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>( kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1>(
arg0, in0, fix, configlists[0], numel); arg0, in0, fix, configlists[0], numel);
} else { } else {
kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg0, in0 + fix, num); kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg0, in0 + fix, num);
} }
// load in1 // load in1
if (use_broadcast[1]) { if (use_broadcast[1]) {
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>( kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1>(
arg1, in1, fix, configlists[1], numel); arg1, in1, fix, configlists[1], numel);
} else { } else {
kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg1, in1 + fix, num); kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg1, in1 + fix, num);
...@@ -104,7 +103,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n, ...@@ -104,7 +103,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
int main_tid = numel / (data_per_thread * vec_size * threads); int main_tid = numel / (data_per_thread * vec_size * threads);
int tail_tid = numel % (data_per_thread * vec_size * threads); int tail_tid = numel % (data_per_thread * vec_size * threads);
phi::Array<kps::details::BroadcastConfig<2>, MAX_INPUT_NUM> configlists; phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists;
phi::Array<bool, MAX_INPUT_NUM> use_broadcast; phi::Array<bool, MAX_INPUT_NUM> use_broadcast;
use_broadcast[0] = false; use_broadcast[0] = false;
...@@ -115,7 +114,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n, ...@@ -115,7 +114,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
// Here, dims are transposed due to the logic in BroadcastConfig. // Here, dims are transposed due to the logic in BroadcastConfig.
std::vector<int64_t> input1_dims = {n, 1}; std::vector<int64_t> input1_dims = {n, 1};
std::vector<int64_t> out_dims = {n, m}; std::vector<int64_t> out_dims = {n, m};
configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2); configlists[1] = kps::details::BroadcastConfig(out_dims, input1_dims, 2);
auto func = AddFunctor<T>(); auto func = AddFunctor<T>();
auto stream = ctx.stream(); auto stream = ctx.stream();
......
...@@ -82,10 +82,10 @@ struct FastDivMod { ...@@ -82,10 +82,10 @@ struct FastDivMod {
* index of the output data. if input or output shape is [dim0, dim1] then dims * index of the output data. if input or output shape is [dim0, dim1] then dims
* must be [dim1, dim0]. * must be [dim1, dim0].
*/ */
template <int kDims>
struct BroadcastConfig { struct BroadcastConfig {
FastDivMod divmoders[kDims]; FastDivMod divmoders[phi::DDim::kMaxRank];
uint32_t strides[phi::DDim::kMaxRank]; uint32_t strides[phi::DDim::kMaxRank];
int kDims;
HOSTDEVICE BroadcastConfig() {} HOSTDEVICE BroadcastConfig() {}
HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims, HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
...@@ -109,7 +109,7 @@ struct BroadcastConfig { ...@@ -109,7 +109,7 @@ struct BroadcastConfig {
std::multiplies<int64_t>()) std::multiplies<int64_t>())
: strides_in[i]; : strides_in[i];
} }
kDims = dim_size;
memcpy(strides, strides_in.data(), kDims * sizeof(uint32_t)); memcpy(strides, strides_in.data(), kDims * sizeof(uint32_t));
memcpy(divmoders, divmoders_in.data(), kDims * sizeof(FastDivMod)); memcpy(divmoders, divmoders_in.data(), kDims * sizeof(FastDivMod));
} }
...@@ -436,17 +436,12 @@ __device__ __forceinline__ void ReadData(ArgsT* dst, ...@@ -436,17 +436,12 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
* stride_nx: Each read one element stride stride_nx elements in the last dim. * stride_nx: Each read one element stride stride_nx elements in the last dim.
* stride_ny: Each read one element stride stride_ny elements in the first dim. * stride_ny: Each read one element stride stride_ny elements in the first dim.
*/ */
template <typename T, template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc( __device__ __forceinline__ void ReadDataBc(
T* dst, T* dst,
const T* __restrict__ src, const T* __restrict__ src,
uint32_t block_offset, uint32_t block_offset,
details::BroadcastConfig<Rank> config, const details::BroadcastConfig& config,
int total_num_output, int total_num_output,
int stride_nx, int stride_nx,
int stride_ny) { int stride_ny) {
...@@ -465,7 +460,8 @@ __device__ __forceinline__ void ReadDataBc( ...@@ -465,7 +460,8 @@ __device__ __forceinline__ void ReadDataBc(
} }
} }
#pragma unroll #pragma unroll
for (int i = 0; i < Rank; ++i) { for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
if (i >= config.kDims) break;
auto fast_divmoder = config.divmoders[i].Divmod(index_output); auto fast_divmoder = config.divmoders[i].Divmod(index_output);
index_output = fast_divmoder.val[0]; index_output = fast_divmoder.val[0];
index_src += fast_divmoder.val[1] * config.strides[i]; index_src += fast_divmoder.val[1] * config.strides[i];
...@@ -785,53 +781,14 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) { ...@@ -785,53 +781,14 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
* coordinate mapping relationship between output data and input data. * coordinate mapping relationship between output data and input data.
* total_num_output: Total number of original output. * total_num_output: Total number of original output.
*/ */
template <typename T, template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
T* dst,
const T* __restrict__ src,
uint32_t block_offset,
details::BroadcastConfig<Rank> config,
int total_num_output) {
uint32_t thread_offset = block_offset + threadIdx.x * NX;
uint32_t index_src = 0;
#pragma unroll
for (uint32_t nx = 0; nx < NX; ++nx) {
uint32_t index_output = thread_offset + nx;
index_src = 0;
if (IsBoundary) {
if (index_output >= total_num_output) {
break;
}
}
#pragma unroll
for (int i = 0; i < Rank; ++i) {
auto fast_divmoder = config.divmoders[i].Divmod(index_output);
index_output = fast_divmoder.val[0];
index_src += fast_divmoder.val[1] * config.strides[i];
}
dst[nx] = src[index_src];
}
}
template <typename T,
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc( __device__ __forceinline__ void ReadDataBc(
T* dst, T* dst,
const T* __restrict__ src, const T* __restrict__ src,
uint32_t block_offset, uint32_t block_offset,
details::BroadcastConfig<Rank> config, const details::BroadcastConfig& config,
int total_num_output, int total_num_output,
int read_lens) { int read_lens = NX) {
uint32_t thread_offset = block_offset + threadIdx.x * NX; uint32_t thread_offset = block_offset + threadIdx.x * NX;
uint32_t index_src = 0; uint32_t index_src = 0;
...@@ -845,7 +802,8 @@ __device__ __forceinline__ void ReadDataBc( ...@@ -845,7 +802,8 @@ __device__ __forceinline__ void ReadDataBc(
} }
} }
#pragma unroll #pragma unroll
for (int i = 0; i < Rank; ++i) { for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
if (i >= config.kDims) break;
auto fast_divmoder = config.divmoders[i].Divmod(index_output); auto fast_divmoder = config.divmoders[i].Divmod(index_output);
index_output = fast_divmoder.val[0]; index_output = fast_divmoder.val[0];
index_src += fast_divmoder.val[1] * config.strides[i]; index_src += fast_divmoder.val[1] * config.strides[i];
...@@ -853,6 +811,7 @@ __device__ __forceinline__ void ReadDataBc( ...@@ -853,6 +811,7 @@ __device__ __forceinline__ void ReadDataBc(
dst[nx] = src[index_src]; dst[nx] = src[index_src];
} }
} }
/** /**
* @brief Initialize register with data index. * @brief Initialize register with data index.
* *
......
...@@ -65,7 +65,6 @@ struct alignas(sizeof(T) * VecSize) VectorType { ...@@ -65,7 +65,6 @@ struct alignas(sizeof(T) * VecSize) VectorType {
* must be [dim1, dim0]. * must be [dim1, dim0].
*/ */
#pragma pack(4) #pragma pack(4)
template <int kDims>
struct BroadcastConfig { struct BroadcastConfig {
int strides_in[phi::DDim::kMaxRank]; int strides_in[phi::DDim::kMaxRank];
int strides_out[phi::DDim::kMaxRank]; int strides_out[phi::DDim::kMaxRank];
...@@ -78,7 +77,7 @@ struct BroadcastConfig { ...@@ -78,7 +77,7 @@ struct BroadcastConfig {
int n = 1; int n = 1;
int k = 1; int k = 1;
int buf_len = 0; int buf_len = 0;
int kDims;
HOSTDEVICE BroadcastConfig() {} HOSTDEVICE BroadcastConfig() {}
HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims, HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
...@@ -99,7 +98,7 @@ struct BroadcastConfig { ...@@ -99,7 +98,7 @@ struct BroadcastConfig {
for (int i = 0; i < dim_size; i++) { for (int i = 0; i < dim_size; i++) {
dim_tmp[i] = in_dims[i]; dim_tmp[i] = in_dims[i];
} }
kDims = dim_size;
memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int)); memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int));
memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int)); memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int));
memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int)); memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int));
...@@ -551,7 +550,6 @@ __device__ __forceinline__ void ReadData(ArgsT* dst, ...@@ -551,7 +550,6 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
* NY: The number of data rows loaded by each thread. * NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For xpu, * BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index. * core_id() is used as the index.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds * IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than * judgment. When the number of data processed by the block is less than
* NX x NY x core_num(), boundary judgment is required to avoid memory access * NX x NY x core_num(), boundary judgment is required to avoid memory access
...@@ -567,16 +565,11 @@ __device__ __forceinline__ void ReadData(ArgsT* dst, ...@@ -567,16 +565,11 @@ __device__ __forceinline__ void ReadData(ArgsT* dst,
* stride_nx: Each read one element stride stride_nx elements in the last dim. * stride_nx: Each read one element stride stride_nx elements in the last dim.
* stride_ny: Each read one element stride stride_ny elements in the first dim. * stride_ny: Each read one element stride stride_ny elements in the first dim.
*/ */
template <typename T, template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
__device__ __inline__ void ReadDataBc(T* dst, __device__ __inline__ void ReadDataBc(T* dst,
const T _global_ptr_* src, const T _global_ptr_* src,
uint32_t block_offset, uint32_t block_offset,
details::BroadcastConfig<Rank> config, const details::BroadcastConfig& config,
int total_num_output, int total_num_output,
int stride_nx, int stride_nx,
int stride_ny) { int stride_ny) {
...@@ -882,60 +875,6 @@ __device__ __inline__ void Init(T* dst, T* init_data, int num) { ...@@ -882,60 +875,6 @@ __device__ __inline__ void Init(T* dst, T* init_data, int num) {
} }
} }
/**
* @brief Read 1D data from global memory to register with broadcast form.
*
* @template paraments
* T: The type of data stored in the global memory.
* NX: The number of data continuously loaded by each thread.
* NY: The number of data rows loaded by each thread, only NY = 1 was supported.
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x core_num(), boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* dst: The register pointer of the thread, the size is NX * NY.
* src: The original input data pointer of kernel.
* block_offset: The data offset of this block, core_num() * blockIdx.x * NX;
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data.
* total_num_output: Total number of original output.
*/
template <typename T,
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
__device__ __inline__ void ReadDataBc(
T* dst,
const T _global_ptr_* src,
uint32_t block_offset,
const details::BroadcastConfig<Rank>& config,
int total_num_output) {
int thread_offset = block_offset + core_id() * NX;
int index_src = 0;
__local__ T in_temp;
#pragma unroll
for (int nx = 0; nx < NX; ++nx) {
int index_output = thread_offset + nx;
index_src = 0;
if (IsBoundary) {
if (index_output >= total_num_output) {
break;
}
}
index_src = config(index_output);
GM2LM(src + index_src, &in_temp, sizeof(T));
dst[nx] = in_temp;
}
}
/** /**
* @brief Read data from global memory to local memory with broadcast * @brief Read data from global memory to local memory with broadcast
* {m, 1, k}-> {m, n, k} form. * {m, 1, k}-> {m, n, k} form.
...@@ -952,12 +891,12 @@ __device__ __inline__ void ReadDataBc( ...@@ -952,12 +891,12 @@ __device__ __inline__ void ReadDataBc(
* coordinate mapping relationship between output data and input data. * coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread. * read_lens: The number of data continuously loaded by each thread.
*/ */
template <typename T, int Rank> template <typename T>
__device__ __inline__ void ReadDataBcM1kMnk( __device__ __inline__ void ReadDataBcM1kMnk(
T* dst, T* dst,
const T _global_ptr_* src, const T _global_ptr_* src,
int thread_offset, int thread_offset,
const details::BroadcastConfig<Rank>& config, const details::BroadcastConfig& config,
int read_lens) { int read_lens) {
int index_output = thread_offset; int index_output = thread_offset;
int index_base = config(index_output); int index_base = config(index_output);
...@@ -999,12 +938,12 @@ __device__ __inline__ void ReadDataBcM1kMnk( ...@@ -999,12 +938,12 @@ __device__ __inline__ void ReadDataBcM1kMnk(
* coordinate mapping relationship between output data and input data. * coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread. * read_lens: The number of data continuously loaded by each thread.
*/ */
template <typename T, int Rank> template <typename T>
__device__ __inline__ void ReadDataBcM1Mn( __device__ __inline__ void ReadDataBcM1Mn(
T* dst, T* dst,
const T _global_ptr_* src, const T _global_ptr_* src,
int thread_offset, int thread_offset,
const details::BroadcastConfig<Rank>& config, const details::BroadcastConfig& config,
int read_lens) { int read_lens) {
int index_output = thread_offset; int index_output = thread_offset;
int index_base = config(index_output); int index_base = config(index_output);
...@@ -1027,7 +966,6 @@ __device__ __inline__ void ReadDataBcM1Mn( ...@@ -1027,7 +966,6 @@ __device__ __inline__ void ReadDataBcM1Mn(
* *
* @template paraments * @template paraments
* T: Data type of register. * T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* *
* @param: * @param:
* dst: The register pointer of the thread, the size is NX. * dst: The register pointer of the thread, the size is NX.
...@@ -1037,12 +975,12 @@ __device__ __inline__ void ReadDataBcM1Mn( ...@@ -1037,12 +975,12 @@ __device__ __inline__ void ReadDataBcM1Mn(
* coordinate mapping relationship between output data and input data. * coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread. * read_lens: The number of data continuously loaded by each thread.
*/ */
template <typename T, int Rank> template <typename T>
__device__ __inline__ void ReadDataBc1NMn( __device__ __inline__ void ReadDataBc1NMn(
T* dst, T* dst,
const T _global_ptr_* src, const T _global_ptr_* src,
int thread_offset, int thread_offset,
const details::BroadcastConfig<Rank>& config, const details::BroadcastConfig& config,
int read_lens) { int read_lens) {
int index_output = thread_offset; int index_output = thread_offset;
int index_base = config(index_output); int index_base = config(index_output);
...@@ -1075,7 +1013,6 @@ __device__ __inline__ void ReadDataBc1NMn( ...@@ -1075,7 +1013,6 @@ __device__ __inline__ void ReadDataBc1NMn(
* *
* @template paraments * @template paraments
* T: Data type of register. * T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* *
* @param: * @param:
* dst: The register pointer of the thread, the size is NX. * dst: The register pointer of the thread, the size is NX.
...@@ -1085,12 +1022,12 @@ __device__ __inline__ void ReadDataBc1NMn( ...@@ -1085,12 +1022,12 @@ __device__ __inline__ void ReadDataBc1NMn(
* coordinate mapping relationship between output data and input data. * coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread. * read_lens: The number of data continuously loaded by each thread.
*/ */
template <typename T, int Rank> template <typename T>
__device__ __inline__ void ReadDataBc1N1Mnk( __device__ __inline__ void ReadDataBc1N1Mnk(
T* dst, T* dst,
const T _global_ptr_* src, const T _global_ptr_* src,
int thread_offset, int thread_offset,
const details::BroadcastConfig<Rank>& config, const details::BroadcastConfig& config,
int read_lens) { int read_lens) {
int index_output = thread_offset; int index_output = thread_offset;
int index_base = config(index_output); int index_base = config(index_output);
...@@ -1130,7 +1067,6 @@ __device__ __inline__ void ReadDataBc1N1Mnk( ...@@ -1130,7 +1067,6 @@ __device__ __inline__ void ReadDataBc1N1Mnk(
* *
* @template paraments * @template paraments
* T: Data type of register. * T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* *
* @param: * @param:
* dst: The register pointer of the thread, the size is NX. * dst: The register pointer of the thread, the size is NX.
...@@ -1140,13 +1076,12 @@ __device__ __inline__ void ReadDataBc1N1Mnk( ...@@ -1140,13 +1076,12 @@ __device__ __inline__ void ReadDataBc1N1Mnk(
* coordinate mapping relationship between output data and input data. * coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread. * read_lens: The number of data continuously loaded by each thread.
*/ */
template <typename T, int Rank> template <typename T>
__device__ __inline__ void ReadDataBc1N( __device__ __inline__ void ReadDataBc1N(T* dst,
T* dst, const T _global_ptr_* src,
const T _global_ptr_* src, int thread_offset,
int thread_offset, const details::BroadcastConfig& config,
const details::BroadcastConfig<Rank>& config, int read_lens) {
int read_lens) {
int index_output = thread_offset; int index_output = thread_offset;
int index_base = config(index_output); int index_base = config(index_output);
T in_temp; T in_temp;
...@@ -1174,12 +1109,12 @@ __device__ __inline__ void ReadDataBc1N( ...@@ -1174,12 +1109,12 @@ __device__ __inline__ void ReadDataBc1N(
* total_num_output: Total number of original output. * total_num_output: Total number of original output.
* read_lens: The number of data continuously loaded by each thread. * read_lens: The number of data continuously loaded by each thread.
*/ */
template <typename T, int Rank, bool IsBoundary = false> template <typename T, bool IsBoundary = false>
__device__ __inline__ void ReadDataBcCanNotCmp( __device__ __inline__ void ReadDataBcCanNotCmp(
T* dst, T* dst,
const T _global_ptr_* src, const T _global_ptr_* src,
int thread_offset, int thread_offset,
const details::BroadcastConfig<Rank>& config, const details::BroadcastConfig& config,
int total_num_output, int total_num_output,
int read_lens) { int read_lens) {
int index_output = thread_offset; int index_output = thread_offset;
...@@ -1215,7 +1150,6 @@ __device__ __inline__ void ReadDataBcCanNotCmp( ...@@ -1215,7 +1150,6 @@ __device__ __inline__ void ReadDataBcCanNotCmp(
* NY: The number of data rows loaded by each thread, only NY = 1 was supported. * NY: The number of data rows loaded by each thread, only NY = 1 was supported.
* BlockSize: Identifies the current device thread index method. For xpu, * BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index. * core_id() is used as the index.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds * IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than * judgment. When the number of data processed by the block is less than
* NX x NY x core_num(), boundary judgment is required to avoid memory access * NX x NY x core_num(), boundary judgment is required to avoid memory access
...@@ -1230,33 +1164,27 @@ __device__ __inline__ void ReadDataBcCanNotCmp( ...@@ -1230,33 +1164,27 @@ __device__ __inline__ void ReadDataBcCanNotCmp(
* read_lens: The number of data continuously loaded by each thread. * read_lens: The number of data continuously loaded by each thread.
* total_num_output: Total number of original output. * total_num_output: Total number of original output.
*/ */
template <typename T, template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
int NX, __device__ __inline__ void ReadDataBc(T* dst,
int NY, const T _global_ptr_* src,
int BlockSize, uint32_t block_offset,
int Rank, const details::BroadcastConfig& config,
bool IsBoundary = false> int total_num_output,
__device__ __inline__ void ReadDataBc( int read_lens) {
T* dst,
const T _global_ptr_* src,
uint32_t block_offset,
const details::BroadcastConfig<Rank>& config,
int total_num_output,
int read_lens) {
int thread_offset = block_offset + core_id() * read_lens; int thread_offset = block_offset + core_id() * read_lens;
if (config.cmp_type == details::OptType::MNK_M1K) { if (config.cmp_type == details::OptType::MNK_M1K) {
ReadDataBcM1kMnk<T, Rank>(dst, src, thread_offset, config, read_lens); ReadDataBcM1kMnk<T>(dst, src, thread_offset, config, read_lens);
} else if (config.cmp_type == details::OptType::N_1) { } else if (config.cmp_type == details::OptType::N_1) {
ReadDataBc1N<T, Rank>(dst, src, thread_offset, config, read_lens); ReadDataBc1N<T>(dst, src, thread_offset, config, read_lens);
} else if (config.cmp_type == details::OptType::MN_M) { } else if (config.cmp_type == details::OptType::MN_M) {
ReadDataBcM1Mn<T, Rank>(dst, src, thread_offset, config, read_lens); ReadDataBcM1Mn<T>(dst, src, thread_offset, config, read_lens);
} else if (config.cmp_type == details::OptType::MN_N) { } else if (config.cmp_type == details::OptType::MN_N) {
ReadDataBc1NMn<T, Rank>(dst, src, thread_offset, config, read_lens); ReadDataBc1NMn<T>(dst, src, thread_offset, config, read_lens);
} else if (config.cmp_type == details::OptType::MNK_1N1) { } else if (config.cmp_type == details::OptType::MNK_1N1) {
ReadDataBc1N1Mnk<T, Rank>(dst, src, thread_offset, config, read_lens); ReadDataBc1N1Mnk<T>(dst, src, thread_offset, config, read_lens);
} else { } else {
ReadDataBcCanNotCmp<T, Rank, IsBoundary>( ReadDataBcCanNotCmp<T, IsBoundary>(
dst, src, thread_offset, config, total_num_output, read_lens); dst, src, thread_offset, config, total_num_output, read_lens);
} }
} }
......
...@@ -40,7 +40,9 @@ ...@@ -40,7 +40,9 @@
#define GRID_NUM_X cluster_num() #define GRID_NUM_X cluster_num()
#define GRID_NUM_Y 0 #define GRID_NUM_Y 0
#define GRID_NUM_Z 0 #define GRID_NUM_Z 0
#define VecSizeL 512
#define VecSizeM 256
#define VecSizeS 128
#else #else
#define KPStream gpuStream_t #define KPStream gpuStream_t
...@@ -64,6 +66,9 @@ ...@@ -64,6 +66,9 @@
#define GRID_NUM_Y gridDim.y #define GRID_NUM_Y gridDim.y
#define GRID_NUM_Z gridDim.z #define GRID_NUM_Z gridDim.z
#define VecSizeL 4
#define VecSizeM 2
#define VecSizeS 1
#endif #endif
// include file // include file
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册