未验证 提交 c7855125 编写于 作者: S shixingbo 提交者: GitHub

broadcast_add kp performance optimization (#42097)

上级 81078a88
......@@ -242,6 +242,27 @@ __device__ __forceinline__ void LoadData(
}
}
template <typename T, int VecSize, int Rank, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
T *dst,
const _ptr_ T *src,
uint32_t block_offset,
const kps::details::BroadcastConfig<Rank> &config,
int numel,
int num,
int need_broadcast,
int read_lens) {
// numel : whole num of output
// num: how many data will be deal with in this time
if (need_broadcast) {
kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
dst, src, block_offset, config, numel, read_lens);
} else {
kps::ReadData<T, VecSize, 1, 1, IsBoundary>(
dst, src + block_offset, num, read_lens);
}
}
template <typename InT,
typename OutT,
typename Functor,
......@@ -258,20 +279,22 @@ __device__ void VectorizedBroadcastKernelImpl(
const phi::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
int num,
int block_offset,
int read_lens,
Functor func) {
InT args[Arity][VecSize];
ConditionalT<OutT, NumOuts> result[VecSize];
__simd__ InT args[Arity][VecSize];
__simd__ ConditionalT<OutT, NumOuts> result[VecSize];
#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f), read_lens);
LoadData<InT, VecSize, Rank, IsBoundary>(args[i],
ins[i],
block_offset,
configs[i],
numel,
num,
use_broadcast[i]);
use_broadcast[i],
read_lens);
}
constexpr bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args;
......@@ -281,10 +304,10 @@ __device__ void VectorizedBroadcastKernelImpl(
Functor,
Arity,
kCallElementwiseAny>()(
func, args, result);
phi::funcs::ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
outs, result, block_offset, num);
func, args, result, read_lens);
phi::funcs::
ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, NumOuts>()(
outs, result, block_offset, num, read_lens);
}
template <typename InT,
......@@ -302,9 +325,10 @@ __global__ void VectorizedBroadcastKernel(
phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
int main_offset,
int tail_tid,
int read_lens,
Functor func) {
int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
#ifdef PADDLE_WITH_XPU_KP
for (; block_offset < main_offset; block_offset += stride) {
......@@ -320,8 +344,9 @@ __global__ void VectorizedBroadcastKernel(
use_broadcast,
numel,
configs,
BLOCK_NUM_X * VecSize,
BLOCK_NUM_X * read_lens,
block_offset,
read_lens,
func);
}
int num = numel - block_offset;
......@@ -333,8 +358,15 @@ __global__ void VectorizedBroadcastKernel(
NumOuts,
VecSize,
Rank,
true>(
ins, outs, use_broadcast, numel, configs, num, block_offset, func);
true>(ins,
outs,
use_broadcast,
numel,
configs,
num,
block_offset,
read_lens,
func);
}
#else
if (block_offset < main_offset) {
......@@ -352,6 +384,7 @@ __global__ void VectorizedBroadcastKernel(
configs,
BLOCK_NUM_X * VecSize,
block_offset,
read_lens,
func);
} else {
VectorizedBroadcastKernelImpl<InT,
......@@ -361,8 +394,15 @@ __global__ void VectorizedBroadcastKernel(
NumOuts,
VecSize,
Rank,
true>(
ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
true>(ins,
outs,
use_broadcast,
numel,
configs,
tail_tid,
block_offset,
read_lens,
func);
}
#endif
}
......@@ -392,6 +432,19 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
for (int i = 0; i < Arity; i++) {
use_broadcast[i] = (ins[i]->numel() != numel);
ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>());
#ifdef PADDLE_WITH_XPU_KP
if (i == 0) {
configs[i] = kps::details::BroadcastConfig<Rank>(merge_dims.out_dims,
merge_dims.in_dims[0],
merge_dims.in_dims[1],
merge_dims.dim_size);
} else if (i == 1) {
configs[i] = kps::details::BroadcastConfig<Rank>(merge_dims.out_dims,
merge_dims.in_dims[1],
merge_dims.in_dims[0],
merge_dims.dim_size);
}
#else
if (use_broadcast[i]) {
// get the broadcast config,
// if data shape is[m, n], then you should set data_dim = {n, m}
......@@ -399,20 +452,24 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
configs[i] = kps::details::BroadcastConfig<Rank>(
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
}
#endif
}
#ifdef PADDLE_WITH_XPU_KP
const int threads = 64;
const int blocks = 8;
int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
int tail_tid = numel % (VecSize * threads);
int read_lens = configs[0].buf_len;
int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
int tail_tid = numel % (read_lens * threads);
auto stream = ctx.x_context()->xpu_stream;
if (configs[0].cmp_type != kps::details::OptType::CanNotOptimize) {
main_offset = numel;
VectorizedBroadcastKernel<InT,
OutT,
Functor,
Arity,
NumOuts,
VecSize,
512,
Rank><<<blocks, threads, stream>>>(ins_data,
outs_data,
use_broadcast,
......@@ -420,7 +477,25 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
configs,
main_offset,
tail_tid,
read_lens,
func);
} else {
VectorizedBroadcastKernel<InT,
OutT,
Functor,
Arity,
NumOuts,
256,
Rank><<<blocks, threads, stream>>>(ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
read_lens,
func);
}
#else
const int threads = 256;
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
......@@ -440,6 +515,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
configs,
main_offset,
tail_tid,
VecSize,
func);
#endif
}
......
......@@ -577,14 +577,16 @@ template <typename InT,
struct ElementwisePrimitiveCaller {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result);
OutT *result,
int read_lens);
};
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result) {
OutT *result,
int read_lens) {
kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(
result, args, func);
}
......@@ -594,7 +596,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result) {
OutT *result,
int read_lens) {
kps::ElementwiseConstant<InT, OutT, VecSize, 1, 1, Functor>(result, func);
}
};
......@@ -603,7 +606,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result) {
OutT *result,
int read_lens) {
kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], func);
}
......@@ -613,9 +617,10 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result) {
OutT *result,
int read_lens) {
kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], func);
result, args[0], args[1], func, read_lens);
}
};
......@@ -623,7 +628,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result) {
OutT *result,
int read_lens) {
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func);
}
......@@ -696,6 +702,42 @@ struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
}
};
template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
struct ElementwiseWriteDataCallerBc {
__device__ __forceinline__ void operator()(
phi::Array<_ptr_ OutT *, NumOuts> outs,
ConditionalT<OutT, NumOuts> src[VecSize],
int block_offset,
int num,
int read_lens) {
OutT dst[NumOuts][VecSize];
#pragma unroll
for (int i = 0; i < read_lens; ++i) {
#pragma unroll
for (int j = 0; j < NumOuts; ++j) {
dst[j][i] = (src[i])[j];
}
}
#pragma unroll
for (int i = 0; i < NumOuts; ++i) {
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
outs[i] + block_offset, dst[i], num, read_lens);
}
}
};
template <typename OutT, int VecSize, bool IsBoundary>
struct ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, 1> {
__device__ __forceinline__ void operator()(phi::Array<_ptr_ OutT *, 1> outs,
OutT src[VecSize],
int block_offset,
int num,
int read_lens) {
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
outs[0] + block_offset, src, num, read_lens);
}
};
template <typename OutT,
typename Functor,
int Arity,
......
......@@ -271,6 +271,20 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
}
}
template <typename InT,
typename OutT,
int NX,
int NY,
int BlockSize,
class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(
OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) {
#pragma unroll
for (int idx = 0; idx < NX * NY; ++idx) {
out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
}
}
/**
* @brief Ternary calculation according to OpFunc. Shape of input and output
* are the same.
......
......@@ -17,6 +17,7 @@
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h"
#include "xpu/kernel/simd_header.h"
namespace phi {
namespace kps {
......@@ -158,6 +159,19 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
}
}
template <typename InT,
typename OutT,
int NX,
int NY,
int BlockSize,
class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(
OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) {
for (int idx = 0; idx < read_lens; ++idx) {
out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
}
}
/**
* @brief Ternary calculation according to OpFunc. Shape of input and output
* are the same.
......
......@@ -246,6 +246,14 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
}
}
template <typename T, int NX>
__device__ __forceinline__ void Init(T* dst, T init_data, int read_lens) {
#pragma unroll
for (int i = 0; i < NX; i++) {
dst[i] = init_data;
}
}
/**
* The difference from the above function is that
* it supports different data types of inputs.
......@@ -311,6 +319,38 @@ __device__ __forceinline__ void ReadData(T* dst,
}
}
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst,
const T* __restrict__ src,
int num,
int read_lens) {
if (IsBoundary) { // blockDim.x * NX > num
int thread_offset = threadIdx.x * NX;
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if (idx + thread_offset < num) {
dst[idx] = src[thread_offset + idx];
}
}
} else { // blockDim,x * NX < num
constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
constexpr int kVectorsPerThread = NX / kVectorSize;
int thread_offset = threadIdx.x * kVectorsPerThread;
using VecType = details::VectorType<T, kVectorSize>;
const VecType* vec_input = reinterpret_cast<const VecType*>(src);
VecType vec_temp[kVectorsPerThread];
#pragma unroll
for (int i = 0; i < kVectorsPerThread; ++i) {
vec_temp[i] = vec_input[thread_offset + i];
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
}
}
}
}
/**
* @brief Read 1D data from global memory to register. The difference
* from the above function is that it supports different data types of inputs.
......@@ -576,6 +616,36 @@ __device__ __forceinline__ void WriteData(T* dst,
}
}
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void WriteData(T* dst,
T* __restrict__ src,
int num,
int read_lens) {
if (IsBoundary) {
int thread_offset = threadIdx.x * NX;
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if ((thread_offset + idx) < num) {
dst[thread_offset + idx] = src[idx];
}
}
} else {
// Vector type
constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
constexpr int kVectorsPerThread = NX / kVectorSize;
int thread_offset = threadIdx.x * kVectorsPerThread;
using VecType = details::VectorType<T, kVectorSize>;
VecType* vec_dst = reinterpret_cast<VecType*>(dst);
VecType vec_temp[kVectorsPerThread];
#pragma unroll
for (int idx = 0; idx < kVectorsPerThread; ++idx) {
vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
vec_dst[thread_offset + idx] = vec_temp[idx];
}
}
}
/**
* @brief Write 2D data from register to global memory according to Tx type, and
* store it as Ty type.
......@@ -749,6 +819,40 @@ __device__ __forceinline__ void ReadDataBc(
}
}
template <typename T,
int NX,
int NY,
int BlockSize,
int Rank,
bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
T* dst,
const T* __restrict__ src,
uint32_t block_offset,
details::BroadcastConfig<Rank> config,
int total_num_output,
int read_lens) {
uint32_t thread_offset = block_offset + threadIdx.x * NX;
uint32_t index_src = 0;
#pragma unroll
for (uint32_t nx = 0; nx < NX; ++nx) {
uint32_t index_output = thread_offset + nx;
index_src = 0;
if (IsBoundary) {
if (index_output >= total_num_output) {
break;
}
}
#pragma unroll
for (int i = 0; i < Rank; ++i) {
auto fast_divmoder = config.divmoders[i].Divmod(index_output);
index_output = fast_divmoder.val[0];
index_src += fast_divmoder.val[1] * config.strides[i];
}
dst[nx] = src[index_src];
}
}
/**
* @brief Initialize register with data index.
*
......
......@@ -46,6 +46,7 @@
#define KPStream gpuStream_t
#define KPDevice phi::GPUContext
#define _ptr_
#define __simd__
#define THREAD_ID_X threadIdx.x
#define THREAD_ID_Y threadIdx.y
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册