未验证 提交 c7855125 编写于 作者: S shixingbo 提交者: GitHub

broadcast_add kp performance optimization (#42097)

上级 81078a88
...@@ -242,6 +242,27 @@ __device__ __forceinline__ void LoadData( ...@@ -242,6 +242,27 @@ __device__ __forceinline__ void LoadData(
} }
} }
/**
 * @brief Load one thread's portion of an input into `dst`, dispatching to the
 * broadcast reader or to the plain contiguous reader.
 *
 * @param dst            Destination buffer of the calling thread.
 * @param src            Base pointer of the input tensor.
 * @param block_offset   Element offset of the current block.
 * @param config         Broadcast configuration, forwarded to kps::ReadDataBc.
 * @param numel          Total number of output elements.
 * @param num            Number of elements handled in this pass.
 * @param need_broadcast Non-zero selects kps::ReadDataBc, zero selects
 *                       kps::ReadData.
 * @param read_lens      Forwarded unchanged to the selected reader.
 */
template <typename T, int VecSize, int Rank, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
    T *dst,
    const _ptr_ T *src,
    uint32_t block_offset,
    const kps::details::BroadcastConfig<Rank> &config,
    int numel,
    int num,
    int need_broadcast,
    int read_lens) {
  if (!need_broadcast) {
    // No broadcast needed: read contiguously, starting at this block's offset.
    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(
        dst, src + block_offset, num, read_lens);
    return;
  }
  // Broadcast input: the reader maps output indices to input indices itself,
  // so it receives the unshifted base pointer plus the block offset.
  kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
      dst, src, block_offset, config, numel, read_lens);
}
template <typename InT, template <typename InT,
typename OutT, typename OutT,
typename Functor, typename Functor,
...@@ -258,20 +279,22 @@ __device__ void VectorizedBroadcastKernelImpl( ...@@ -258,20 +279,22 @@ __device__ void VectorizedBroadcastKernelImpl(
const phi::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs, const phi::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
int num, int num,
int block_offset, int block_offset,
int read_lens,
Functor func) { Functor func) {
InT args[Arity][VecSize]; __simd__ InT args[Arity][VecSize];
ConditionalT<OutT, NumOuts> result[VecSize]; __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
#pragma unroll #pragma unroll
for (int i = 0; i < Arity; i++) { for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f)); kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f), read_lens);
LoadData<InT, VecSize, Rank, IsBoundary>(args[i], LoadData<InT, VecSize, Rank, IsBoundary>(args[i],
ins[i], ins[i],
block_offset, block_offset,
configs[i], configs[i],
numel, numel,
num, num,
use_broadcast[i]); use_broadcast[i],
read_lens);
} }
constexpr bool kCallElementwiseAny = constexpr bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args; paddle::platform::FunctionTraits<Functor>::has_pointer_args;
...@@ -281,10 +304,10 @@ __device__ void VectorizedBroadcastKernelImpl( ...@@ -281,10 +304,10 @@ __device__ void VectorizedBroadcastKernelImpl(
Functor, Functor,
Arity, Arity,
kCallElementwiseAny>()( kCallElementwiseAny>()(
func, args, result); func, args, result, read_lens);
phi::funcs::
phi::funcs::ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()( ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, NumOuts>()(
outs, result, block_offset, num); outs, result, block_offset, num, read_lens);
} }
template <typename InT, template <typename InT,
...@@ -302,9 +325,10 @@ __global__ void VectorizedBroadcastKernel( ...@@ -302,9 +325,10 @@ __global__ void VectorizedBroadcastKernel(
phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs, phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
int main_offset, int main_offset,
int tail_tid, int tail_tid,
int read_lens,
Functor func) { Functor func) {
int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize; int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
for (; block_offset < main_offset; block_offset += stride) { for (; block_offset < main_offset; block_offset += stride) {
...@@ -320,8 +344,9 @@ __global__ void VectorizedBroadcastKernel( ...@@ -320,8 +344,9 @@ __global__ void VectorizedBroadcastKernel(
use_broadcast, use_broadcast,
numel, numel,
configs, configs,
BLOCK_NUM_X * VecSize, BLOCK_NUM_X * read_lens,
block_offset, block_offset,
read_lens,
func); func);
} }
int num = numel - block_offset; int num = numel - block_offset;
...@@ -333,8 +358,15 @@ __global__ void VectorizedBroadcastKernel( ...@@ -333,8 +358,15 @@ __global__ void VectorizedBroadcastKernel(
NumOuts, NumOuts,
VecSize, VecSize,
Rank, Rank,
true>( true>(ins,
ins, outs, use_broadcast, numel, configs, num, block_offset, func); outs,
use_broadcast,
numel,
configs,
num,
block_offset,
read_lens,
func);
} }
#else #else
if (block_offset < main_offset) { if (block_offset < main_offset) {
...@@ -352,6 +384,7 @@ __global__ void VectorizedBroadcastKernel( ...@@ -352,6 +384,7 @@ __global__ void VectorizedBroadcastKernel(
configs, configs,
BLOCK_NUM_X * VecSize, BLOCK_NUM_X * VecSize,
block_offset, block_offset,
read_lens,
func); func);
} else { } else {
VectorizedBroadcastKernelImpl<InT, VectorizedBroadcastKernelImpl<InT,
...@@ -361,8 +394,15 @@ __global__ void VectorizedBroadcastKernel( ...@@ -361,8 +394,15 @@ __global__ void VectorizedBroadcastKernel(
NumOuts, NumOuts,
VecSize, VecSize,
Rank, Rank,
true>( true>(ins,
ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func); outs,
use_broadcast,
numel,
configs,
tail_tid,
block_offset,
read_lens,
func);
} }
#endif #endif
} }
...@@ -392,6 +432,19 @@ void LaunchBroadcastKernel(const KPDevice &ctx, ...@@ -392,6 +432,19 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
for (int i = 0; i < Arity; i++) { for (int i = 0; i < Arity; i++) {
use_broadcast[i] = (ins[i]->numel() != numel); use_broadcast[i] = (ins[i]->numel() != numel);
ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>()); ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>());
#ifdef PADDLE_WITH_XPU_KP
if (i == 0) {
configs[i] = kps::details::BroadcastConfig<Rank>(merge_dims.out_dims,
merge_dims.in_dims[0],
merge_dims.in_dims[1],
merge_dims.dim_size);
} else if (i == 1) {
configs[i] = kps::details::BroadcastConfig<Rank>(merge_dims.out_dims,
merge_dims.in_dims[1],
merge_dims.in_dims[0],
merge_dims.dim_size);
}
#else
if (use_broadcast[i]) { if (use_broadcast[i]) {
// get the broadcast config, // get the broadcast config,
// if data shape is[m, n], then you should set data_dim = {n, m} // if data shape is[m, n], then you should set data_dim = {n, m}
...@@ -399,28 +452,50 @@ void LaunchBroadcastKernel(const KPDevice &ctx, ...@@ -399,28 +452,50 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
configs[i] = kps::details::BroadcastConfig<Rank>( configs[i] = kps::details::BroadcastConfig<Rank>(
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size); merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
} }
#endif
} }
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
const int threads = 64; const int threads = 64;
const int blocks = 8; const int blocks = 8;
int main_offset = (numel / (VecSize * threads)) * VecSize * threads; int read_lens = configs[0].buf_len;
int tail_tid = numel % (VecSize * threads); int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
int tail_tid = numel % (read_lens * threads);
auto stream = ctx.x_context()->xpu_stream; auto stream = ctx.x_context()->xpu_stream;
VectorizedBroadcastKernel<InT, if (configs[0].cmp_type != kps::details::OptType::CanNotOptimize) {
OutT, main_offset = numel;
Functor, VectorizedBroadcastKernel<InT,
Arity, OutT,
NumOuts, Functor,
VecSize, Arity,
Rank><<<blocks, threads, stream>>>(ins_data, NumOuts,
outs_data, 512,
use_broadcast, Rank><<<blocks, threads, stream>>>(ins_data,
numel, outs_data,
configs, use_broadcast,
main_offset, numel,
tail_tid, configs,
func); main_offset,
tail_tid,
read_lens,
func);
} else {
VectorizedBroadcastKernel<InT,
OutT,
Functor,
Arity,
NumOuts,
256,
Rank><<<blocks, threads, stream>>>(ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
read_lens,
func);
}
#else #else
const int threads = 256; const int threads = 256;
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads; int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
...@@ -440,6 +515,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx, ...@@ -440,6 +515,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
configs, configs,
main_offset, main_offset,
tail_tid, tail_tid,
VecSize,
func); func);
#endif #endif
} }
......
...@@ -577,14 +577,16 @@ template <typename InT, ...@@ -577,14 +577,16 @@ template <typename InT,
struct ElementwisePrimitiveCaller { struct ElementwisePrimitiveCaller {
__device__ inline void operator()(Functor func, __device__ inline void operator()(Functor func,
InT (*args)[VecSize], InT (*args)[VecSize],
OutT *result); OutT *result,
int read_lens);
}; };
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity> template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> { struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
__device__ inline void operator()(Functor func, __device__ inline void operator()(Functor func,
InT (*args)[VecSize], InT (*args)[VecSize],
OutT *result) { OutT *result,
int read_lens) {
kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>( kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(
result, args, func); result, args, func);
} }
...@@ -594,7 +596,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor> ...@@ -594,7 +596,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> { struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> {
__device__ inline void operator()(Functor func, __device__ inline void operator()(Functor func,
InT (*args)[VecSize], InT (*args)[VecSize],
OutT *result) { OutT *result,
int read_lens) {
kps::ElementwiseConstant<InT, OutT, VecSize, 1, 1, Functor>(result, func); kps::ElementwiseConstant<InT, OutT, VecSize, 1, 1, Functor>(result, func);
} }
}; };
...@@ -603,7 +606,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor> ...@@ -603,7 +606,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> { struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
__device__ inline void operator()(Functor func, __device__ inline void operator()(Functor func,
InT (*args)[VecSize], InT (*args)[VecSize],
OutT *result) { OutT *result,
int read_lens) {
kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>( kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], func); result, args[0], func);
} }
...@@ -613,9 +617,10 @@ template <typename InT, typename OutT, int VecSize, typename Functor> ...@@ -613,9 +617,10 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> { struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
__device__ inline void operator()(Functor func, __device__ inline void operator()(Functor func,
InT (*args)[VecSize], InT (*args)[VecSize],
OutT *result) { OutT *result,
int read_lens) {
kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>( kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], func); result, args[0], args[1], func, read_lens);
} }
}; };
...@@ -623,7 +628,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor> ...@@ -623,7 +628,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> { struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
__device__ inline void operator()(Functor func, __device__ inline void operator()(Functor func,
InT (*args)[VecSize], InT (*args)[VecSize],
OutT *result) { OutT *result,
int read_lens) {
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>( kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func); result, args[0], args[1], args[2], func);
} }
...@@ -696,6 +702,42 @@ struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> { ...@@ -696,6 +702,42 @@ struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
} }
}; };
/**
 * @brief Write the results of a multi-output functor back to the outputs.
 *
 * `src` holds, per element, one value for each of the NumOuts outputs. The
 * values are first regrouped into one row per output and each row is then
 * written with kps::WriteData at `outs[i] + block_offset`.
 *
 * @param outs         Output base pointers, one per output.
 * @param src          Per-element result tuples produced by the functor.
 * @param block_offset Element offset of the current block.
 * @param num          Element count forwarded to kps::WriteData.
 * @param read_lens    Number of entries of `src` that are regrouped; also
 *                     forwarded to kps::WriteData.
 */
template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
struct ElementwiseWriteDataCallerBc {
  __device__ __forceinline__ void operator()(
      phi::Array<_ptr_ OutT *, NumOuts> outs,
      ConditionalT<OutT, NumOuts> src[VecSize],
      int block_offset,
      int num,
      int read_lens) {
    // One row of VecSize slots per output.
    OutT per_output[NumOuts][VecSize];
#pragma unroll
    for (int elem = 0; elem < read_lens; ++elem) {
#pragma unroll
      for (int out_idx = 0; out_idx < NumOuts; ++out_idx) {
        per_output[out_idx][elem] = (src[elem])[out_idx];
      }
    }
#pragma unroll
    for (int out_idx = 0; out_idx < NumOuts; ++out_idx) {
      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
          outs[out_idx] + block_offset, per_output[out_idx], num, read_lens);
    }
  }
};
/**
 * @brief Single-output specialization: the functor results are already a
 * plain OutT array, so they are handed to kps::WriteData directly.
 *
 * @param outs         Array holding the single output base pointer.
 * @param src          Results to write.
 * @param block_offset Element offset of the current block.
 * @param num          Element count forwarded to kps::WriteData.
 * @param read_lens    Forwarded to kps::WriteData.
 */
template <typename OutT, int VecSize, bool IsBoundary>
struct ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, 1> {
  __device__ __forceinline__ void operator()(phi::Array<_ptr_ OutT *, 1> outs,
                                             OutT src[VecSize],
                                             int block_offset,
                                             int num,
                                             int read_lens) {
    auto block_dst = outs[0] + block_offset;
    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
        block_dst, src, num, read_lens);
  }
};
template <typename OutT, template <typename OutT,
typename Functor, typename Functor,
int Arity, int Arity,
......
...@@ -271,6 +271,20 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, ...@@ -271,6 +271,20 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
} }
} }
/**
 * @brief Binary elementwise computation over NX * NY elements.
 *
 * This overload accepts `read_lens` so that it shares a signature with the
 * variable-length variant, but the parameter is not used here: the trip count
 * is the compile-time constant NX * NY.
 *
 * @param out       Destination of NX * NY results.
 * @param in1       First operand array.
 * @param in2       Second operand array.
 * @param compute   Callable applied as compute(in1[i], in2[i]).
 * @param read_lens Unused by this implementation.
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(
    OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) {
  constexpr int kCount = NX * NY;
#pragma unroll
  for (int pos = 0; pos < kCount; ++pos) {
    out[pos] = static_cast<OutT>(compute(in1[pos], in2[pos]));
  }
}
/** /**
* @brief Ternary calculation according to OpFunc. Shape of input and output * @brief Ternary calculation according to OpFunc. Shape of input and output
* are the same. * are the same.
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "xpu/kernel/cluster_header.h" #include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h" #include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h" #include "xpu/kernel/math.h"
#include "xpu/kernel/simd_header.h"
namespace phi { namespace phi {
namespace kps { namespace kps {
...@@ -158,6 +159,19 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, ...@@ -158,6 +159,19 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
} }
} }
/**
 * @brief Binary elementwise computation over a run-time number of elements.
 *
 * @param out       Destination of `read_lens` results.
 * @param in1       First operand array.
 * @param in2       Second operand array.
 * @param compute   Callable applied as compute(in1[i], in2[i]).
 * @param read_lens Number of elements to process; NX and NY are not consulted.
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(
    OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) {
  int pos = 0;
  while (pos < read_lens) {
    out[pos] = static_cast<OutT>(compute(in1[pos], in2[pos]));
    ++pos;
  }
}
/** /**
* @brief Ternary calculation according to OpFunc. Shape of input and output * @brief Ternary calculation according to OpFunc. Shape of input and output
* are the same. * are the same.
......
...@@ -246,6 +246,14 @@ __device__ __forceinline__ void Init(T* dst, T init_data) { ...@@ -246,6 +246,14 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
} }
} }
/**
 * @brief Fill the first NX entries of `dst` with `init_data`.
 *
 * `read_lens` is accepted for signature compatibility with the
 * variable-length variant and is not used by this implementation.
 */
template <typename T, int NX>
__device__ __forceinline__ void Init(T* dst, T init_data, int read_lens) {
#pragma unroll
  for (int slot = 0; slot < NX; ++slot) {
    dst[slot] = init_data;
  }
}
/** /**
* The difference from the above function is that * The difference from the above function is that
* it supports different data types of inputs. * it supports different data types of inputs.
...@@ -311,6 +319,38 @@ __device__ __forceinline__ void ReadData(T* dst, ...@@ -311,6 +319,38 @@ __device__ __forceinline__ void ReadData(T* dst,
} }
} }
/**
 * @brief Read NX elements per thread from `src` into `dst`.
 *
 * In the boundary case every element is read individually and guarded
 * against `num`. Otherwise the thread's elements are fetched through
 * details::VectorType loads of width 4, 2 or 1 (the widest that divides NX)
 * and then unpacked into `dst`.
 *
 * @param dst       Per-thread destination buffer with room for NX elements.
 * @param src       Source pointer for the current block.
 * @param num       Number of valid elements remaining from `src`; only
 *                  consulted when IsBoundary is true.
 * @param read_lens Unused by this implementation; the per-thread length is
 *                  the compile-time NX.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst,
                                         const T* __restrict__ src,
                                         int num,
                                         int read_lens) {
  if (IsBoundary) {  // blockDim.x * NX > num
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        dst[idx] = src[thread_offset + idx];
      }
    }
  } else {  // blockDim.x * NX < num
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    int thread_offset = threadIdx.x * kVectorsPerThread;

    using VecType = details::VectorType<T, kVectorSize>;
    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
    VecType vec_temp[kVectorsPerThread];

    // Load every vector first ...
#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      vec_temp[i] = vec_input[thread_offset + i];
    }
    // ... then unpack once. Unpacking inside the load loop would copy all NX
    // elements on every iteration and read vec_temp entries that have not
    // been loaded yet.
    const T* unpacked = reinterpret_cast<const T*>(vec_temp);
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      dst[idx] = unpacked[idx];
    }
  }
}
/** /**
* @brief Read 1D data from global memory to register. The difference * @brief Read 1D data from global memory to register. The difference
* from the above function is that it supports different data types of inputs. * from the above function is that it supports different data types of inputs.
...@@ -576,6 +616,36 @@ __device__ __forceinline__ void WriteData(T* dst, ...@@ -576,6 +616,36 @@ __device__ __forceinline__ void WriteData(T* dst,
} }
} }
/**
 * @brief Write NX elements per thread from `src` to `dst`.
 *
 * In the boundary case each element is stored individually and guarded
 * against `num`. Otherwise the elements are stored through
 * details::VectorType writes of width 4, 2 or 1 (the widest that divides NX).
 *
 * @param dst       Destination pointer for the current block.
 * @param src       Per-thread source buffer holding NX elements.
 * @param num       Number of valid elements remaining in `dst`; only
 *                  consulted when IsBoundary is true.
 * @param read_lens Unused by this implementation; the per-thread length is
 *                  the compile-time NX.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void WriteData(T* dst,
                                          T* __restrict__ src,
                                          int num,
                                          int read_lens) {
  if (IsBoundary) {
    const int first = threadIdx.x * NX;
#pragma unroll
    for (int k = 0; k < NX; ++k) {
      const int pos = first + k;
      if (pos < num) {
        dst[pos] = src[k];
      }
    }
    return;
  }

  // Vectorized path.
  constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
  constexpr int kVectorsPerThread = NX / kVectorSize;
  using VecType = details::VectorType<T, kVectorSize>;

  const int first_vec = threadIdx.x * kVectorsPerThread;
  VecType* vec_dst = reinterpret_cast<VecType*>(dst);
  const VecType* vec_src = reinterpret_cast<const VecType*>(src);
#pragma unroll
  for (int v = 0; v < kVectorsPerThread; ++v) {
    vec_dst[first_vec + v] = vec_src[v];
  }
}
/** /**
* @brief Write 2D data from register to global memory according to Tx type, and * @brief Write 2D data from register to global memory according to Tx type, and
* store it as Ty type. * store it as Ty type.
...@@ -749,6 +819,40 @@ __device__ __forceinline__ void ReadDataBc( ...@@ -749,6 +819,40 @@ __device__ __forceinline__ void ReadDataBc(
} }
} }
/**
 * @brief Read NX broadcast elements per thread from `src` into `dst`.
 *
 * For each of the thread's output indices, the index is decomposed with
 * `config.divmoders` and the remainders are combined with `config.strides`
 * to obtain the source index.
 *
 * @param dst              Per-thread destination buffer (NX elements).
 * @param src              Base pointer of the input.
 * @param block_offset     Element offset of the current block in the output.
 * @param config           Broadcast configuration (divmoders and strides).
 * @param total_num_output Total number of output elements; in the boundary
 *                         case reading stops at the first index that reaches
 *                         this value.
 * @param read_lens        Unused by this implementation.
 */
template <typename T,
          int NX,
          int NY,
          int BlockSize,
          int Rank,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    T* dst,
    const T* __restrict__ src,
    uint32_t block_offset,
    details::BroadcastConfig<Rank> config,
    int total_num_output,
    int read_lens) {
  const uint32_t first_output = block_offset + threadIdx.x * NX;

#pragma unroll
  for (uint32_t nx = 0; nx < NX; ++nx) {
    uint32_t remaining = first_output + nx;
    if (IsBoundary && remaining >= total_num_output) {
      break;
    }
    uint32_t src_index = 0;
#pragma unroll
    for (int axis = 0; axis < Rank; ++axis) {
      // val[0] is carried into the next axis; val[1] is weighted by the
      // stride of this axis.
      auto div_mod = config.divmoders[axis].Divmod(remaining);
      remaining = div_mod.val[0];
      src_index += div_mod.val[1] * config.strides[axis];
    }
    dst[nx] = src[src_index];
  }
}
/** /**
* @brief Initialize register with data index. * @brief Initialize register with data index.
* *
......
...@@ -21,6 +21,39 @@ namespace phi { ...@@ -21,6 +21,39 @@ namespace phi {
namespace kps { namespace kps {
namespace details { namespace details {
// Kind of optimized computation that is possible once the input shapes have
// been compressed. The shape examples are the original author's.
enum class OptType {
CanNotOptimize = -1, // cannot be optimized; broadcast first
N_1, // e.g. {1} op {100} or {100} op {1}
MN_N, // e.g. {100} op {3, 100} or {3, 100} op {100}
MN_M, // e.g. {3} op {3, 100} or {3, 100} op {3}
MNK_1N1, // e.g. {3} op {2, 3, 100} or {2, 3, 100} op {3}
MNK_M1K, // e.g. {2, 1, 100} op {2, 3, 100} or
// {2, 3, 100} op {2, 1, 100}
};
// Classifies one pair of dimension extents; adjacent dimensions whose pairs
// fall into the same class can be merged.
//    0 : a == b
//    1 : a == 1 and b != 1
//    2 : a != 1 and b == 1
//   -1 : the extents differ and neither of them is 1
static int judge_case(int a, int b) {
  if (a == b) return 0;
  // From here on a != b, so at most one of the two can be 1.
  if (a == 1) return 1;
  if (b == 1) return 2;
  return -1;
}
// Returns true when two dimension classes (as produced by judge_case) are
// equal.
static bool case_is_same(int case_front, int case_back) {
  return case_front == case_back;
}
template <typename T, int VecSize> template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) VectorType { struct alignas(sizeof(T) * VecSize) VectorType {
T val[VecSize]; T val[VecSize];
...@@ -37,11 +70,20 @@ struct BroadcastConfig { ...@@ -37,11 +70,20 @@ struct BroadcastConfig {
int strides_in[phi::DDim::kMaxRank]; int strides_in[phi::DDim::kMaxRank];
int strides_out[phi::DDim::kMaxRank]; int strides_out[phi::DDim::kMaxRank];
int in_dim[phi::DDim::kMaxRank]; int in_dim[phi::DDim::kMaxRank];
int dim_after_cmp[phi::DDim::kMaxRank];
int dim_size_after_cmp = 0;
int cmp_res = 0;
OptType cmp_type = OptType::CanNotOptimize;
int m = 1;
int n = 1;
int k = 1;
int buf_len = 0;
HOSTDEVICE BroadcastConfig() {} HOSTDEVICE BroadcastConfig() {}
HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims, HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
const std::vector<int64_t>& in_dims, const std::vector<int64_t>& in_dims,
const std::vector<int64_t>& another_in_dims,
int dim_size) { int dim_size) {
std::vector<int> strides_in_tmp; std::vector<int> strides_in_tmp;
std::vector<int> strides_out_tmp; std::vector<int> strides_out_tmp;
...@@ -61,18 +103,187 @@ struct BroadcastConfig { ...@@ -61,18 +103,187 @@ struct BroadcastConfig {
memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int)); memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int));
memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int)); memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int));
memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int)); memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int));
cmp_res = get_mnk_for_broadcast_ops(in_dims, another_in_dims);
get_opt_type(another_in_dims);
buf_len = get_buf_len();
}
// Chooses the value stored in the buf_len member.
// Returns 256 when no optimized broadcast pattern was detected. Otherwise
// returns m rounded down to a multiple of 16 (or m itself when that rounding
// gives 0), capped at 512.
int get_buf_len() {
  if (cmp_type == OptType::CanNotOptimize) {
    return 256;
  }
  const int upper_bound = 512;
  const int rounded_down = m / 16 * 16;
  // The local is named differently from the buf_len member on purpose.
  const int candidate = (rounded_down == 0) ? m : rounded_down;
  return std::min(upper_bound, candidate);
}
__device__ inline int operator()(int index_output) const { __device__ inline int operator()(int index_output) const {
int index_src = 0; int index_src = 0;
#pragma unroll
for (int i = kDims - 1; i >= 0; --i) { switch (cmp_type) {
int tmp_index = (index_output / strides_out[i]); int div, mod, tmp_index;
index_output = index_output - tmp_index * strides_out[i]; case OptType::MNK_M1K:
index_src += (tmp_index % in_dim[i]) * strides_in[i]; div = index_output / (m * n);
mod = index_output % (m * n) % m;
index_src = div * m + mod;
break;
case OptType::MNK_1N1:
// index_src = index_output / m % n;
index_src = index_output % (m * n) / m;
break;
case OptType::N_1:
index_src = 0;
break;
case OptType::MN_N:
index_src = index_output / m;
break;
case OptType::MN_M:
index_src = index_output % m;
break;
case OptType::CanNotOptimize:
for (int i = kDims - 1; i >= 0; --i) {
tmp_index = (index_output / strides_out[i]);
index_output = index_output - tmp_index * strides_out[i];
index_src += (tmp_index % in_dim[i]) * strides_in[i];
}
break;
} }
return index_src; return index_src;
} }
// Derives cmp_type and the m / n / k extents from the compressed dims of
// this input (dim_after_cmp, dim_size_after_cmp) and the dims of the other
// operand.
//
// Only compressed ranks 1, 2 and 3 are examined; for any other rank nothing
// is modified. Within a rank, a combination that matches none of the listed
// patterns sets cmp_type to CanNotOptimize.
//
// NOTE(review): the constructor in this struct calls this with the other
// operand's dims as received; confirm that they line up index-for-index with
// dim_after_cmp.
void get_opt_type(const std::vector<int64_t>& y_dim_after_cmp) {
  const int* x = dim_after_cmp;
  const std::vector<int64_t>& y = y_dim_after_cmp;

  switch (dim_size_after_cmp) {
    case 1:
      if (x[0] == 1 && y[0] != 1) {  // {1} op {n}
        n = y[0];
        cmp_type = OptType::N_1;
      } else if (x[0] != 1 && y[0] == 1) {  // {n} op {1}
        n = x[0];
        cmp_type = OptType::N_1;
      } else {  // xshape == yshape
        cmp_type = OptType::CanNotOptimize;
      }
      break;

    case 2:
      if (x[0] == 1 && x[1] != 1 && y[0] != 1 &&
          y[1] != 1) {  // {n} op {m, n}
        m = y[0];
        n = y[1];
        cmp_type = OptType::MN_N;
      } else if (x[0] != 1 && x[1] == 1 && y[0] != 1 &&
                 y[1] != 1) {  // {m} op {m, n}
        m = y[0];
        n = y[1];
        cmp_type = OptType::MN_M;
      } else if (x[0] != 1 && x[1] != 1 && y[0] == 1 &&
                 y[1] != 1) {  // {m, n} op {n}
        m = x[0];
        n = x[1];
        cmp_type = OptType::MN_N;
      } else if (x[0] != 1 && x[1] != 1 && y[0] != 1 &&
                 y[1] == 1) {  // {m, n} op {m}
        m = x[0];
        n = x[1];
        cmp_type = OptType::MN_M;
      } else {
        cmp_type = OptType::CanNotOptimize;
      }
      break;

    case 3:
      if (x[0] == 1 && x[1] != 1 && x[2] == 1 && y[0] != 1 && y[1] != 1 &&
          y[2] != 1) {  // {1, n, 1} op {m, n, k}
        m = y[0];
        n = y[1];
        k = y[2];
        cmp_type = OptType::MNK_1N1;
      } else if (x[0] != 1 && x[1] != 1 && x[2] != 1 && y[0] == 1 &&
                 y[1] != 1 && y[2] == 1) {  // {m, n, k} op {1, n, 1}
        m = x[0];
        n = x[1];
        k = x[2];
        cmp_type = OptType::MNK_1N1;
      } else if (x[0] != 1 && x[1] == 1 && x[2] != 1 && y[0] != 1 &&
                 y[1] != 1 && y[2] != 1) {  // {m, 1, k} op {m, n, k}
        m = y[0];
        n = y[1];
        k = y[2];
        cmp_type = OptType::MNK_M1K;
      } else if (x[0] != 1 && x[1] != 1 && x[2] != 1 && y[0] != 1 &&
                 y[1] == 1 && y[2] != 1) {  // {m, n, k} op {m, 1, k}
        m = x[0];
        n = x[1];
        k = x[2];
        cmp_type = OptType::MNK_M1K;
      } else {
        cmp_type = OptType::CanNotOptimize;
      }
      break;

    default:
      // Other ranks: leave cmp_type and m / n / k untouched.
      break;
  }
}
// Compresses the shape of this input (xshape) against the other operand's
// shape (yshape) and stores the result in dim_after_cmp / dim_size_after_cmp.
//
// Step 1 drops every position where both shapes have extent 1.
// Step 2 merges runs of adjacent dimensions whose (x, y) pairs fall into the
// same judge_case class, multiplying the x extents together.
//
// Returns 0 when the shape compresses to at most three dimensions; in that
// case dim_size_after_cmp is set to the number of compressed dimensions.
// Returns -1 otherwise, and also when nothing is left to compress (all
// pairs were {1, 1} or xshape is empty) or when yshape has fewer entries than
// xshape. dim_size_after_cmp is not modified on the -1 paths.
int get_mnk_for_broadcast_ops(const std::vector<int64_t>& xshape,
                              const std::vector<int64_t>& yshape) {
  // Every x entry is paired with the y entry at the same position, so y must
  // be at least as long as x; otherwise y would be read out of bounds.
  if (yshape.size() < xshape.size()) {
    return -1;
  }

  // first step: remove excess ones
  std::vector<int64_t> x_dims;
  std::vector<int64_t> y_dims;
  x_dims.reserve(xshape.size());
  y_dims.reserve(xshape.size());
  for (size_t i = 0; i < xshape.size(); ++i) {
    if (xshape[i] == 1 && yshape[i] == 1) {
      continue;
    }
    x_dims.push_back(xshape[i]);
    y_dims.push_back(yshape[i]);
  }

  // Nothing left: indexing x_dims[0] below would be out of bounds.
  if (x_dims.empty()) {
    return -1;
  }

  // second step: compress dims
  const int rank = static_cast<int>(x_dims.size());
  int idx = 0;
  for (int group = 0; group < 3; ++group) {
    int cmp_x = x_dims[idx];
    // Absorb following dimensions while they belong to the same class.
    while ((idx + 1) < rank &&
           case_is_same(judge_case(x_dims[idx], y_dims[idx]),
                        judge_case(x_dims[idx + 1], y_dims[idx + 1]))) {
      cmp_x = cmp_x * x_dims[idx + 1];
      ++idx;
    }
    ++idx;
    dim_after_cmp[group] = cmp_x;
    if (idx == rank) {
      dim_size_after_cmp = group + 1;
      return 0;
    }
  }
  return -1;  // can not compress dims
}
}; };
#pragma pack() #pragma pack()
...@@ -199,6 +410,14 @@ __device__ __inline__ void Init(T* dst, T init_data) { ...@@ -199,6 +410,14 @@ __device__ __inline__ void Init(T* dst, T init_data) {
} }
} }
/**
 * @brief Fill the first `read_lens` entries of `dst` with `init_data`.
 *
 * The template parameter NX is not consulted by this overload; the length is
 * taken from `read_lens` at run time.
 */
template <typename T, int NX>
__device__ __inline__ void Init(T* dst, T init_data, int read_lens) {
#pragma unroll
  for (int slot = 0; slot < read_lens; ++slot) {
    dst[slot] = init_data;
  }
}
/** /**
* The difference from the above function is that * The difference from the above function is that
* it supports different data types of inputs. * it supports different data types of inputs.
...@@ -251,6 +470,26 @@ __device__ __inline__ void ReadData(T* dst, ...@@ -251,6 +470,26 @@ __device__ __inline__ void ReadData(T* dst,
} }
} }
/**
 * @brief Read `read_lens` elements per core from `src` into `dst`.
 *
 * Each core starts at `core_id() * read_lens`. Without a boundary the whole
 * run is fetched with one GM2LM call; with a boundary the elements are
 * fetched one at a time and only while their index is below `num`.
 *
 * @param dst       Per-core destination buffer.
 * @param src       Source pointer for the current block.
 * @param num       Number of valid elements remaining from `src`; only
 *                  consulted when IsBoundary is true.
 * @param read_lens Number of elements handled by each core.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ __inline__ void ReadData(T* dst,
                                    const T _global_ptr_* src,
                                    int num,
                                    int read_lens) {
  const int core_begin = core_id() * read_lens;
  __local__ T in_temp[1];

  if (!IsBoundary) {  // core_num() * read_lens < num
    GM2LM(src + core_begin, dst, read_lens * sizeof(T));
    return;
  }

  // core_num() * read_lens > num
#pragma unroll
  for (int k = 0; k < read_lens; ++k) {
    if (k + core_begin < num) {
      GM2LM(src + core_begin + k, in_temp, sizeof(T));
      dst[k] = in_temp[0];
    }
  }
}
/** /**
* @brief Read 1D data from global memory to register. The difference * @brief Read 1D data from global memory to register. The difference
* from the above function is that it supports different data types of inputs. * from the above function is that it supports different data types of inputs.
...@@ -479,10 +718,32 @@ __device__ __forceinline__ void ReadDataReduce( ...@@ -479,10 +718,32 @@ __device__ __forceinline__ void ReadDataReduce(
* size: The current block needs to load size elements continuously. * size: The current block needs to load size elements continuously.
*/ */
/**
 * @brief Write `read_lens` elements per core from `src` to `dst`.
 *
 * Each core starts at `core_id() * read_lens`. Without a boundary the whole
 * run is stored with one LM2GM call; with a boundary the elements are stored
 * one at a time and only while their index is below `num`.
 *
 * @param dst       Destination pointer for the current block.
 * @param src       Per-core source buffer.
 * @param num       Number of valid elements remaining in `dst`; only
 *                  consulted when IsBoundary is true.
 * @param read_lens Number of elements handled by each core.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ void WriteData(T _global_ptr_* dst,
                          const T* src,
                          int num,
                          int read_lens) {
  const int core_begin = core_id() * read_lens;
  __local__ T in_temp[1];

  if (!IsBoundary) {  // core_num() * read_lens < num
    LM2GM(src, dst + core_begin, read_lens * sizeof(T));
    return;
  }

  // core_num() * read_lens > num
#pragma unroll
  for (int k = 0; k < read_lens; ++k) {
    if (k + core_begin < num) {
      in_temp[0] = src[k];
      LM2GM(in_temp, dst + k + core_begin, sizeof(T));
    }
  }
}
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary> template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ void WriteData(T _global_ptr_* dst, const T* src, int num) { __device__ void WriteData(T _global_ptr_* dst, const T* src, int num) {
int thread_offset = core_id() * NX; int thread_offset = core_id() * NX;
__local__ T in_temp[1]; __local__ T in_temp[1];
if (IsBoundary) { // core_num() * NX > num if (IsBoundary) { // core_num() * NX > num
#pragma unroll #pragma unroll
for (int idx = 0; idx < NX; ++idx) { for (int idx = 0; idx < NX; ++idx) {
...@@ -675,6 +936,331 @@ __device__ __inline__ void ReadDataBc( ...@@ -675,6 +936,331 @@ __device__ __inline__ void ReadDataBc(
} }
} }
/**
 * @brief Read `read_lens` elements with GM2LM for the broadcast form
 * {m, 1, k} -> {m, n, k}.
 *
 * The source index of the first element is obtained from `config`. If the
 * `read_lens` elements fit before the next multiple of `config.m` in the
 * source, they are fetched with one copy. Otherwise the copy is split: the
 * first part runs up to that multiple of m, and the second part starts either
 * at the beginning of the same group of m source elements or, when the
 * output position is in the last of the n repetitions, at the beginning of
 * the following group.
 *
 * @tparam T    Element type.
 * @tparam Rank Rank parameter of the broadcast configuration.
 *
 * @param dst           Destination buffer of this thread.
 * @param src           Base pointer of the input.
 * @param thread_offset Output element offset of this thread.
 * @param config        Broadcast configuration; provides m, n and the
 *                      output-to-input index mapping.
 * @param read_lens     Number of elements loaded by this thread.
 */
template <typename T, int Rank>
__device__ __inline__ void ReadDataBcM1kMnk(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
    const details::BroadcastConfig<Rank>& config,
    int read_lens) {
  const int src_begin = config(thread_offset);
  const int m = config.m;
  const int n = config.n;
  // Elements left before the next multiple of m in the source.
  const int room = m - src_begin % m;

  if (room >= read_lens) {
    GM2LM(src + src_begin, dst, read_lens * sizeof(T));
    return;
  }

  GM2LM(src + src_begin, dst, room * sizeof(T));
  const int n_pos = thread_offset % (m * n) / m;
  const int src_group = src_begin / m;
  // Not yet in the last repetition: re-read the same group of m elements.
  // Last repetition: continue with the next group.
  const int next_begin = (n_pos != n - 1) ? src_group * m : (src_group + 1) * m;
  GM2LM(src + next_begin, dst + room, (read_lens - room) * sizeof(T));
}
/**
 * @brief Read `read_lens` elements with GM2LM for the broadcast form
 * {m, 1} -> {m, n}.
 *
 * The source index of the first element is obtained from `config`. If the
 * `read_lens` elements fit before the next multiple of `config.m` in the
 * source, they are fetched with one copy. Otherwise the first part is read
 * up to that point and the remainder is read from the start of `src`.
 *
 * @tparam T    Element type.
 * @tparam Rank Rank parameter of the broadcast configuration.
 *
 * @param dst           Destination buffer of this thread.
 * @param src           Base pointer of the input.
 * @param thread_offset Output element offset of this thread.
 * @param config        Broadcast configuration; provides m and the
 *                      output-to-input index mapping.
 * @param read_lens     Number of elements loaded by this thread.
 */
template <typename T, int Rank>
__device__ __inline__ void ReadDataBcM1Mn(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
    const details::BroadcastConfig<Rank>& config,
    int read_lens) {
  const int src_begin = config(thread_offset);
  const int m = config.m;
  // Elements left before the next multiple of m in the source.
  const int room = m - src_begin % m;

  if (room >= read_lens) {
    GM2LM(src + src_begin, dst, read_lens * sizeof(T));
    return;
  }

  // Wrap around: tail of the source first, then continue from its start.
  GM2LM(src + src_begin, dst, room * sizeof(T));
  GM2LM(src, dst + room, (read_lens - room) * sizeof(T));
}
/**
 * @brief Read data from global memory to local memory with broadcast
 * {1, n}-> {m, n} form.
 *
 * One source element is replicated over the output elements up to the next
 * multiple of config.m; if the request extends past that point, the rest is
 * filled with the source element that immediately follows.
 *
 * @template parameters
 * T: Data type of register.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 *
 * @param:
 * dst: The register pointer of the thread; read_lens elements are written.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * read_lens: The number of data continuously loaded by each thread.
 */
template <typename T, int Rank>
__device__ __inline__ void ReadDataBc1NMn(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
    const details::BroadcastConfig<Rank>& config,
    int read_lens) {
  const int m = config.m;
  // Source index that corresponds to the first output element of this thread.
  const int src_index = config(thread_offset);
  // Output elements left before the next multiple of m is reached.
  const int remain_in_row = m - thread_offset % m;
  // Number of leading outputs that share the first source value.
  const int head = (remain_in_row < read_lens) ? remain_in_row : read_lens;

  T value;
  GM2LM(src + src_index, &value, sizeof(T));
  for (int i = 0; i < head; ++i) {
    dst[i] = value;
  }

  // Only when the request crosses the boundary is the next source element
  // loaded and used for the remaining outputs.
  if (head < read_lens) {
    GM2LM(src + src_index + 1, &value, sizeof(T));
    for (int i = head; i < read_lens; ++i) {
      dst[i] = value;
    }
  }
}
/**
 * @brief Read data from global memory to local memory with broadcast
 * {1, n, 1}-> {m, n, k} form.
 *
 * One source element is replicated over the output elements up to the next
 * multiple of config.m; if the request extends past that point, the rest is
 * filled with the source element of the following n position, wrapping to
 * index 0 after the last one.
 *
 * @template parameters
 * T: Data type of register.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 *
 * @param:
 * dst: The register pointer of the thread; read_lens elements are written.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * read_lens: The number of data continuously loaded by each thread.
 */
template <typename T, int Rank>
__device__ __inline__ void ReadDataBc1N1Mnk(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
    const details::BroadcastConfig<Rank>& config,
    int read_lens) {
  const int m = config.m;
  const int n = config.n;
  // Source index that corresponds to the first output element of this thread.
  const int src_index = config(thread_offset);
  // Output elements left before the next multiple of m is reached.
  const int remain_in_row = m - thread_offset % m;
  // Number of leading outputs that share the first source value.
  const int head = (remain_in_row < read_lens) ? remain_in_row : read_lens;

  T value;
  GM2LM(src + src_index, &value, sizeof(T));
  for (int i = 0; i < head; ++i) {
    dst[i] = value;
  }

  if (head < read_lens) {
    // Position of this thread along the n dimension; the remaining outputs
    // belong to the following position, or to position 0 after the last one.
    const int n_pos = thread_offset % (m * n) / m;
    const int src_next = (n_pos != n - 1) ? n_pos + 1 : 0;
    GM2LM(src + src_next, &value, sizeof(T));
    for (int i = head; i < read_lens; ++i) {
      dst[i] = value;
    }
  }
}
/**
 * @brief Read data from global memory to local memory with broadcast
 * {1}-> {n} form.
 *
 * A single source element is loaded and replicated into every one of the
 * read_lens destination slots.
 *
 * @template parameters
 * T: Data type of register.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 *
 * @param:
 * dst: The register pointer of the thread; read_lens elements are written.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * read_lens: The number of data continuously loaded by each thread.
 */
template <typename T, int Rank>
__device__ __inline__ void ReadDataBc1N(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
    const details::BroadcastConfig<Rank>& config,
    int read_lens) {
  // Fetch the one source element selected by the broadcast mapping.
  T value;
  GM2LM(src + config(thread_offset), &value, sizeof(T));

  // Replicate it across the thread's destination buffer.
  int filled = 0;
  while (filled < read_lens) {
    dst[filled] = value;
    ++filled;
  }
}
/**
 * @brief Read data from global memory to local memory with broadcast
 * form which can not compress.
 *
 * A window of 256 source elements, starting at the source index mapped from
 * thread_offset, is copied into a local buffer first. Every output element
 * whose mapped source index falls inside that window is served from the
 * buffer; any other element is loaded from src individually.
 *
 * @template parameters
 * T: Data type of register.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 * IsBoundary: When true, loading stops at the first output index that is not
 * smaller than total_num_output; the remaining dst slots are left untouched.
 *
 * @param:
 * dst: The register pointer of the thread; at most read_lens elements are
 * written.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * total_num_output: Total number of original output.
 * read_lens: The number of data continuously loaded by each thread.
 */
template <typename T, int Rank, bool IsBoundary = false>
__device__ __inline__ void ReadDataBcCanNotCmp(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
    const details::BroadcastConfig<Rank>& config,
    int total_num_output,
    int read_lens) {
  // The buffer size must be a compile-time constant: an array bound taken
  // from a non-const int would declare a variable-length array, which is not
  // standard C++.
  constexpr int kCacheSize = 256;
  __local__ T src_temp[kCacheSize];

  // Source index of the first output element; start of the cached window.
  const int index_base = config(thread_offset);
  // NOTE(review): this always copies kCacheSize elements starting at
  // src + index_base, regardless of how many elements remain in src --
  // confirm that reading past the end of the input is acceptable here.
  GM2LM(src + index_base, src_temp, kCacheSize * sizeof(T));

  T in_temp;
  for (int nx = 0; nx < read_lens; ++nx) {
    const int index_output = thread_offset + nx;
    // Boundary blocks must not touch outputs beyond the real output size.
    if (IsBoundary && index_output >= total_num_output) {
      break;
    }
    const int index_src = config(index_output);
    if (index_src >= index_base && index_src < index_base + kCacheSize) {
      // Hit: the element is already in the cached window.
      in_temp = src_temp[index_src - index_base];
    } else {
      // Miss: fall back to a single-element load from global memory.
      GM2LM(src + index_src, &in_temp, sizeof(T));
    }
    dst[nx] = in_temp;
  }
}
/**
 * @brief Read 1D data from global memory to register with broadcast form.
 *
 * Computes the calling thread's start offset as
 * block_offset + core_id() * read_lens and forwards to the reader that
 * matches config.cmp_type; any type without a dedicated reader is handled by
 * ReadDataBcCanNotCmp.
 *
 * @template parameters
 * T: The type of data stored in the global memory.
 * NX: The number of data continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread, only NY = 1 was supported.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x core_num(), boundary judgment is required to avoid memory access
 * crossing the boundary. It is only forwarded to ReadDataBcCanNotCmp.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The original input data pointer of kernel.
 * block_offset: The data offset of this block, core_num() * blockIdx.x * NX;
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * total_num_output: Total number of original output.
 * read_lens: The number of data continuously loaded by each thread.
 */
template <typename T,
          int NX,
          int NY,
          int BlockSize,
          int Rank,
          bool IsBoundary = false>
__device__ __inline__ void ReadDataBc(
    T* dst,
    const T _global_ptr_* src,
    uint32_t block_offset,
    const details::BroadcastConfig<Rank>& config,
    int total_num_output,
    int read_lens) {
  // Each core handles read_lens consecutive outputs inside the block's range.
  int thread_offset = block_offset + core_id() * read_lens;

  switch (config.cmp_type) {
    case details::OptType::MNK_M1K:
      ReadDataBcM1kMnk<T, Rank>(dst, src, thread_offset, config, read_lens);
      break;
    case details::OptType::N_1:
      ReadDataBc1N<T, Rank>(dst, src, thread_offset, config, read_lens);
      break;
    case details::OptType::MN_M:
      ReadDataBcM1Mn<T, Rank>(dst, src, thread_offset, config, read_lens);
      break;
    case details::OptType::MN_N:
      ReadDataBc1NMn<T, Rank>(dst, src, thread_offset, config, read_lens);
      break;
    case details::OptType::MNK_1N1:
      ReadDataBc1N1Mnk<T, Rank>(dst, src, thread_offset, config, read_lens);
      break;
    default:
      // No specialised reader for this broadcast pattern: use the generic
      // per-element mapping.
      ReadDataBcCanNotCmp<T, Rank, IsBoundary>(
          dst, src, thread_offset, config, total_num_output, read_lens);
      break;
  }
}
/** /**
* @brief Initialize register with data index. * @brief Initialize register with data index.
* *
......
...@@ -46,6 +46,7 @@ ...@@ -46,6 +46,7 @@
#define KPStream gpuStream_t #define KPStream gpuStream_t
#define KPDevice phi::GPUContext #define KPDevice phi::GPUContext
#define _ptr_ #define _ptr_
#define __simd__
#define THREAD_ID_X threadIdx.x #define THREAD_ID_X threadIdx.x
#define THREAD_ID_Y threadIdx.y #define THREAD_ID_Y threadIdx.y
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册