Unverified commit c7855125, authored by shixingbo, committed by GitHub

broadcast_add kp performance optimization (#42097)

Parent: 81078a88
@@ -242,6 +242,27 @@ __device__ __forceinline__ void LoadData(
   }
 }
 
+template <typename T, int VecSize, int Rank, bool IsBoundary = false>
+__device__ __forceinline__ void LoadData(
+    T *dst,
+    const _ptr_ T *src,
+    uint32_t block_offset,
+    const kps::details::BroadcastConfig<Rank> &config,
+    int numel,
+    int num,
+    int need_broadcast,
+    int read_lens) {
+  // numel: total number of elements in the output
+  // num: number of elements to process in this pass
+  if (need_broadcast) {
+    kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
+        dst, src, block_offset, config, numel, read_lens);
+  } else {
+    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(
+        dst, src + block_offset, num, read_lens);
+  }
+}
+
 template <typename InT,
           typename OutT,
           typename Functor,
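Note: the new LoadData overload selects between the broadcast read path (ReadDataBc) and the contiguous read path (ReadData) with a runtime flag, and bounds both with the runtime read_lens rather than the compile-time VecSize. A minimal host-side sketch of that dispatch, with illustrative stand-in functions (not the Paddle API):

#include <cstdio>

// Illustrative stand-ins for kps::ReadDataBc / kps::ReadData.
void read_broadcast(int* dst, const int* src, int offset, int read_lens) {
  (void)offset;  // the real path maps offset through BroadcastConfig
  for (int i = 0; i < read_lens; ++i) dst[i] = src[0];  // scalar broadcast
}
void read_direct(int* dst, const int* src, int offset, int read_lens) {
  for (int i = 0; i < read_lens; ++i) dst[i] = src[offset + i];
}

// Mirrors the new LoadData: a runtime flag picks the read path and
// read_lens bounds the element count in both branches.
void load_data(int* dst, const int* src, int offset,
               bool need_broadcast, int read_lens) {
  if (need_broadcast) {
    read_broadcast(dst, src, offset, read_lens);
  } else {
    read_direct(dst, src, offset, read_lens);
  }
}

int main() {
  int src[8] = {1, 2, 3, 4, 5, 6, 7, 8}, dst[4];
  load_data(dst, src, /*offset=*/4, /*need_broadcast=*/false, /*read_lens=*/4);
  printf("%d %d %d %d\n", dst[0], dst[1], dst[2], dst[3]);  // 5 6 7 8
}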
@@ -258,20 +279,22 @@ __device__ void VectorizedBroadcastKernelImpl(
     const phi::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
     int num,
     int block_offset,
+    int read_lens,
     Functor func) {
-  InT args[Arity][VecSize];
-  ConditionalT<OutT, NumOuts> result[VecSize];
+  __simd__ InT args[Arity][VecSize];
+  __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
 #pragma unroll
   for (int i = 0; i < Arity; i++) {
-    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
+    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f), read_lens);
     LoadData<InT, VecSize, Rank, IsBoundary>(args[i],
                                              ins[i],
                                              block_offset,
                                              configs[i],
                                              numel,
                                              num,
-                                             use_broadcast[i]);
+                                             use_broadcast[i],
+                                             read_lens);
   }
   constexpr bool kCallElementwiseAny =
       paddle::platform::FunctionTraits<Functor>::has_pointer_args;
@@ -281,10 +304,10 @@ __device__ void VectorizedBroadcastKernelImpl(
                                  Functor,
                                  Arity,
                                  kCallElementwiseAny>()(
-      func, args, result);
-  phi::funcs::ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
-      outs, result, block_offset, num);
+      func, args, result, read_lens);
+  phi::funcs::
+      ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, NumOuts>()(
+          outs, result, block_offset, num, read_lens);
 }
 template <typename InT,
@@ -302,9 +325,10 @@ __global__ void VectorizedBroadcastKernel(
     phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
     int main_offset,
     int tail_tid,
+    int read_lens,
     Functor func) {
-  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
-  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
+  int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
 #ifdef PADDLE_WITH_XPU_KP
   for (; block_offset < main_offset; block_offset += stride) {
@@ -320,8 +344,9 @@ __global__ void VectorizedBroadcastKernel(
                                   use_broadcast,
                                   numel,
                                   configs,
-                                  BLOCK_NUM_X * VecSize,
+                                  BLOCK_NUM_X * read_lens,
                                   block_offset,
+                                  read_lens,
                                   func);
   }
   int num = numel - block_offset;
@@ -333,8 +358,15 @@ __global__ void VectorizedBroadcastKernel(
                                   NumOuts,
                                   VecSize,
                                   Rank,
-                                  true>(
-        ins, outs, use_broadcast, numel, configs, num, block_offset, func);
+                                  true>(ins,
+                                        outs,
+                                        use_broadcast,
+                                        numel,
+                                        configs,
+                                        num,
+                                        block_offset,
+                                        read_lens,
+                                        func);
   }
 #else
   if (block_offset < main_offset) {
@@ -352,6 +384,7 @@ __global__ void VectorizedBroadcastKernel(
                                   configs,
                                   BLOCK_NUM_X * VecSize,
                                   block_offset,
+                                  read_lens,
                                   func);
   } else {
     VectorizedBroadcastKernelImpl<InT,
@@ -361,8 +394,15 @@ __global__ void VectorizedBroadcastKernel(
                                   NumOuts,
                                   VecSize,
                                   Rank,
-                                  true>(
-        ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
+                                  true>(ins,
+                                        outs,
+                                        use_broadcast,
+                                        numel,
+                                        configs,
+                                        tail_tid,
+                                        block_offset,
+                                        read_lens,
+                                        func);
   }
 #endif
 }
@@ -392,6 +432,19 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
   for (int i = 0; i < Arity; i++) {
     use_broadcast[i] = (ins[i]->numel() != numel);
     ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>());
+#ifdef PADDLE_WITH_XPU_KP
+    if (i == 0) {
+      configs[i] = kps::details::BroadcastConfig<Rank>(merge_dims.out_dims,
+                                                       merge_dims.in_dims[0],
+                                                       merge_dims.in_dims[1],
+                                                       merge_dims.dim_size);
+    } else if (i == 1) {
+      configs[i] = kps::details::BroadcastConfig<Rank>(merge_dims.out_dims,
+                                                       merge_dims.in_dims[1],
+                                                       merge_dims.in_dims[0],
+                                                       merge_dims.dim_size);
+    }
+#else
     if (use_broadcast[i]) {
       // get the broadcast config,
       // if data shape is [m, n], then you should set data_dim = {n, m}
@@ -399,28 +452,50 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
       configs[i] = kps::details::BroadcastConfig<Rank>(
           merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
     }
+#endif
   }
 #ifdef PADDLE_WITH_XPU_KP
   const int threads = 64;
   const int blocks = 8;
-  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
-  int tail_tid = numel % (VecSize * threads);
+  int read_lens = configs[0].buf_len;
+  int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
+  int tail_tid = numel % (read_lens * threads);
   auto stream = ctx.x_context()->xpu_stream;
-  VectorizedBroadcastKernel<InT,
-                            OutT,
-                            Functor,
-                            Arity,
-                            NumOuts,
-                            VecSize,
-                            Rank><<<blocks, threads, stream>>>(ins_data,
-                                                               outs_data,
-                                                               use_broadcast,
-                                                               numel,
-                                                               configs,
-                                                               main_offset,
-                                                               tail_tid,
-                                                               func);
+  if (configs[0].cmp_type != kps::details::OptType::CanNotOptimize) {
+    main_offset = numel;
+    VectorizedBroadcastKernel<InT,
+                              OutT,
+                              Functor,
+                              Arity,
+                              NumOuts,
+                              512,
+                              Rank><<<blocks, threads, stream>>>(ins_data,
+                                                                 outs_data,
+                                                                 use_broadcast,
+                                                                 numel,
+                                                                 configs,
+                                                                 main_offset,
+                                                                 tail_tid,
+                                                                 read_lens,
+                                                                 func);
+  } else {
+    VectorizedBroadcastKernel<InT,
+                              OutT,
+                              Functor,
+                              Arity,
+                              NumOuts,
+                              256,
+                              Rank><<<blocks, threads, stream>>>(ins_data,
+                                                                 outs_data,
+                                                                 use_broadcast,
+                                                                 numel,
+                                                                 configs,
+                                                                 main_offset,
+                                                                 tail_tid,
+                                                                 read_lens,
+                                                                 func);
+  }
 #else
   const int threads = 256;
   int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
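Note: in the XPU branch above, read_lens comes from configs[0].buf_len, main_offset is rounded down to a whole number of read_lens * threads chunks, and the remainder is left to the IsBoundary = true tail pass; the optimizable case also widens the per-thread register buffers by instantiating the kernel with 512 as the template VecSize instead of 256. A small sketch of the offset arithmetic, with hypothetical numbers:

#include <cstdio>

int main() {
  const int threads = 64;  // matches the XPU launch in this diff
  int numel = 10000;       // hypothetical element count
  int read_lens = 8;       // would come from configs[0].buf_len
  int chunk = read_lens * threads;            // 512 elements per pass
  int main_offset = (numel / chunk) * chunk;  // 9728: full chunks
  int tail_tid = numel % chunk;               // 272: boundary tail pass
  printf("main_offset=%d tail_tid=%d\n", main_offset, tail_tid);
}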
@@ -440,6 +515,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
                                     configs,
                                     main_offset,
                                     tail_tid,
+                                    VecSize,
                                     func);
 #endif
 }
...
@@ -577,14 +577,16 @@ template <typename InT,
 struct ElementwisePrimitiveCaller {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result);
+                                    OutT *result,
+                                    int read_lens);
 };
 
 template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(
         result, args, func);
   }
@@ -594,7 +596,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseConstant<InT, OutT, VecSize, 1, 1, Functor>(result, func);
   }
 };
@@ -603,7 +606,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(
         result, args[0], func);
   }
@@ -613,9 +617,10 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
-        result, args[0], args[1], func);
+        result, args[0], args[1], func, read_lens);
   }
 };
@@ -623,7 +628,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
 struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
   __device__ inline void operator()(Functor func,
                                     InT (*args)[VecSize],
-                                    OutT *result) {
+                                    OutT *result,
+                                    int read_lens) {
     kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
         result, args[0], args[1], args[2], func);
   }
@@ -696,6 +702,42 @@ struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
   }
 };
 
+template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
+struct ElementwiseWriteDataCallerBc {
+  __device__ __forceinline__ void operator()(
+      phi::Array<_ptr_ OutT *, NumOuts> outs,
+      ConditionalT<OutT, NumOuts> src[VecSize],
+      int block_offset,
+      int num,
+      int read_lens) {
+    OutT dst[NumOuts][VecSize];
+#pragma unroll
+    for (int i = 0; i < read_lens; ++i) {
+#pragma unroll
+      for (int j = 0; j < NumOuts; ++j) {
+        dst[j][i] = (src[i])[j];
+      }
+    }
+#pragma unroll
+    for (int i = 0; i < NumOuts; ++i) {
+      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
+          outs[i] + block_offset, dst[i], num, read_lens);
+    }
+  }
+};
+
+template <typename OutT, int VecSize, bool IsBoundary>
+struct ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, 1> {
+  __device__ __forceinline__ void operator()(phi::Array<_ptr_ OutT *, 1> outs,
+                                             OutT src[VecSize],
+                                             int block_offset,
+                                             int num,
+                                             int read_lens) {
+    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
+        outs[0] + block_offset, src, num, read_lens);
+  }
+};
+
 template <typename OutT,
           typename Functor,
           int Arity,
...
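Note: ElementwiseWriteDataCallerBc transposes per-thread results from element-major (src[i] holds all outputs of element i) to output-major before the per-output WriteData calls, touching only the first read_lens slots. A reduced host-side sketch of that transpose, with illustrative constants:

#include <cstdio>

int main() {
  const int kVecSize = 4, kNumOuts = 2;
  int src[kVecSize][kNumOuts] = {{0, 1}, {2, 3}, {4, 5}, {6, 7}};
  int dst[kNumOuts][kVecSize];
  int read_lens = 3;  // runtime bound; trailing slots stay untouched
  for (int i = 0; i < read_lens; ++i)
    for (int j = 0; j < kNumOuts; ++j)
      dst[j][i] = src[i][j];  // gather output j across elements
  printf("%d %d %d\n", dst[0][0], dst[0][1], dst[0][2]);  // 0 2 4
}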
@@ -271,6 +271,20 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
   }
 }
 
+template <typename InT,
+          typename OutT,
+          int NX,
+          int NY,
+          int BlockSize,
+          class OpFunc>
+__device__ __forceinline__ void ElementwiseBinary(
+    OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) {
+#pragma unroll
+  for (int idx = 0; idx < NX * NY; ++idx) {
+    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
+  }
+}
+
 /**
  * @brief Ternary calculation according to OpFunc. Shape of input and output
  * are the same.
...
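Note: this GPU overload accepts read_lens only for signature parity with the XPU version in the next file; the trip count stays the compile-time NX * NY. A host-side sketch of how a caller would exercise it (AddFunctor here is a hypothetical functor, not a Paddle type):

#include <cstdio>

struct AddFunctor {
  int operator()(const int& a, const int& b) const { return a + b; }
};

// Stand-in for the GPU overload above: read_lens is accepted but the
// loop bound remains the compile-time NX * NY.
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
void elementwise_binary(OutT* out, const InT* in1, const InT* in2,
                        OpFunc compute, int /*read_lens*/) {
  for (int idx = 0; idx < NX * NY; ++idx)
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
}

int main() {
  int a[4] = {1, 2, 3, 4}, b[4] = {10, 20, 30, 40}, out[4];
  elementwise_binary<int, int, 4, 1>(out, a, b, AddFunctor(), 4);
  printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);  // 11 22 33 44
}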
@@ -17,6 +17,7 @@
 #include "xpu/kernel/cluster_header.h"
 #include "xpu/kernel/debug.h"
 #include "xpu/kernel/math.h"
+#include "xpu/kernel/simd_header.h"
 
 namespace phi {
 namespace kps {
@@ -158,6 +159,19 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
   }
 }
 
+template <typename InT,
+          typename OutT,
+          int NX,
+          int NY,
+          int BlockSize,
+          class OpFunc>
+__device__ __forceinline__ void ElementwiseBinary(
+    OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) {
+  for (int idx = 0; idx < read_lens; ++idx) {
+    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
+  }
+}
+
 /**
  * @brief Ternary calculation according to OpFunc. Shape of input and output
  * are the same.
...
@@ -246,6 +246,14 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
   }
 }
 
+template <typename T, int NX>
+__device__ __forceinline__ void Init(T* dst, T init_data, int read_lens) {
+#pragma unroll
+  for (int i = 0; i < NX; i++) {
+    dst[i] = init_data;
+  }
+}
+
 /**
  * The difference from the above function is that
  * it supports different data types of inputs.
@@ -311,6 +319,38 @@ __device__ __forceinline__ void ReadData(T* dst,
   }
 }
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+__device__ __forceinline__ void ReadData(T* dst,
+                                         const T* __restrict__ src,
+                                         int num,
+                                         int read_lens) {
+  if (IsBoundary) {  // blockDim.x * NX > num
+    int thread_offset = threadIdx.x * NX;
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      if (idx + thread_offset < num) {
+        dst[idx] = src[thread_offset + idx];
+      }
+    }
+  } else {  // blockDim.x * NX < num
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    constexpr int kVectorsPerThread = NX / kVectorSize;
+    int thread_offset = threadIdx.x * kVectorsPerThread;
+
+    using VecType = details::VectorType<T, kVectorSize>;
+    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
+    VecType vec_temp[kVectorsPerThread];
+
+#pragma unroll
+    for (int i = 0; i < kVectorsPerThread; ++i) {
+      vec_temp[i] = vec_input[thread_offset + i];
+#pragma unroll
+      for (int idx = 0; idx < NX; ++idx) {
+        dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
+      }
+    }
+  }
+}
+
 /**
  * @brief Read 1D data from global memory to register. The difference
  * from the above function is that it supports different data types of inputs.
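Note: in the non-boundary branch above, the widest vector width that divides NX is chosen at compile time, so loads are issued as 4- or 2-element packets where possible. A tiny sketch of that width rule:

#include <cstdio>

// Same width rule as kVectorSize above: prefer 4-wide, then 2-wide loads.
constexpr int vector_size(int nx) {
  return (nx % 4 == 0) ? 4 : (nx % 2 == 0) ? 2 : 1;
}

int main() {
  printf("%d %d %d\n", vector_size(8), vector_size(6), vector_size(5));
  // 4 2 1 -> 8 elements move as two 4-wide loads, 6 as three 2-wide
  // loads, and 5 falls back to scalar loads.
}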
@@ -576,6 +616,36 @@ __device__ __forceinline__ void WriteData(T* dst,
   }
 }
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+__device__ __forceinline__ void WriteData(T* dst,
+                                          T* __restrict__ src,
+                                          int num,
+                                          int read_lens) {
+  if (IsBoundary) {
+    int thread_offset = threadIdx.x * NX;
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      if ((thread_offset + idx) < num) {
+        dst[thread_offset + idx] = src[idx];
+      }
+    }
+  } else {
+    // Vector type
+    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    constexpr int kVectorsPerThread = NX / kVectorSize;
+
+    int thread_offset = threadIdx.x * kVectorsPerThread;
+    using VecType = details::VectorType<T, kVectorSize>;
+    VecType* vec_dst = reinterpret_cast<VecType*>(dst);
+    VecType vec_temp[kVectorsPerThread];
+#pragma unroll
+    for (int idx = 0; idx < kVectorsPerThread; ++idx) {
+      vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
+      vec_dst[thread_offset + idx] = vec_temp[idx];
+    }
+  }
+}
+
 /**
  * @brief Write 2D data from register to global memory according to Tx type, and
  * store it as Ty type.
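Note: paired with the ReadData overload earlier in this file, this WriteData forms the guarded global-to-register-to-global staging used by the tail (tail_tid) pass. A single-thread sketch of the boundary-guarded pairing (illustrative only):

#include <cstdio>

int main() {
  const int kNX = 4;
  int src[kNX] = {7, 8, 9, 10}, reg[kNX] = {0}, dst[kNX] = {0};
  int num = 3;  // tail pass: only 3 valid elements remain
  // Boundary-guarded read into registers, then guarded write back out,
  // mirroring the IsBoundary = true paths of ReadData / WriteData.
  for (int i = 0; i < kNX; ++i)
    if (i < num) reg[i] = src[i];
  for (int i = 0; i < kNX; ++i)
    if (i < num) dst[i] = reg[i];
  printf("%d %d %d\n", dst[0], dst[1], dst[2]);  // 7 8 9
}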
@@ -749,6 +819,40 @@ __device__ __forceinline__ void ReadDataBc(
   }
 }
+template <typename T,
+          int NX,
+          int NY,
+          int BlockSize,
+          int Rank,
+          bool IsBoundary = false>
+__device__ __forceinline__ void ReadDataBc(
+    T* dst,
+    const T* __restrict__ src,
+    uint32_t block_offset,
+    details::BroadcastConfig<Rank> config,
+    int total_num_output,
+    int read_lens) {
+  uint32_t thread_offset = block_offset + threadIdx.x * NX;
+  uint32_t index_src = 0;
+
+#pragma unroll
+  for (uint32_t nx = 0; nx < NX; ++nx) {
+    uint32_t index_output = thread_offset + nx;
+    index_src = 0;
+    if (IsBoundary) {
+      if (index_output >= total_num_output) {
+        break;
+      }
+    }
+#pragma unroll
+    for (int i = 0; i < Rank; ++i) {
+      auto fast_divmoder = config.divmoders[i].Divmod(index_output);
+      index_output = fast_divmoder.val[0];
+      index_src += fast_divmoder.val[1] * config.strides[i];
+    }
+    dst[nx] = src[index_src];
+  }
+}
+
 /**
  * @brief Initialize register with data index.
  *
...
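Note: ReadDataBc walks each output index through the precomputed divmoders, accumulating a source index from the per-axis remainders and strides; a broadcast axis simply carries stride 0. A host-side sketch of the same index mapping for broadcasting shape [1, 3] over [2, 3], with plain division standing in for FastDivMod:

#include <cstdio>

int main() {
  // Broadcast src shape [1, 3] over dst shape [2, 3] (row-major).
  // Per-axis divisors walk the output index innermost-first, like
  // config.divmoders / config.strides in ReadDataBc.
  const int kRank = 2;
  int divisors[kRank] = {3, 2};  // dim sizes, innermost first
  int strides[kRank] = {1, 0};   // stride 0 on the broadcast axis
  for (int out_idx = 0; out_idx < 6; ++out_idx) {
    int rem = out_idx, src_idx = 0;
    for (int i = 0; i < kRank; ++i) {
      src_idx += (rem % divisors[i]) * strides[i];
      rem /= divisors[i];
    }
    printf("out %d <- src %d\n", out_idx, src_idx);  // 0,1,2,0,1,2
  }
}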
@@ -46,6 +46,7 @@
 #define KPStream gpuStream_t
 #define KPDevice phi::GPUContext
 #define _ptr_
+#define __simd__
 
 #define THREAD_ID_X threadIdx.x
 #define THREAD_ID_Y threadIdx.y
...