Unverified commit aa0c885a authored by S shixingbo, committed by GitHub

Optimized the performance of broadcast for kp XPU2 (#44091)

Parent 1e6137b5
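For context, here is a minimal self-contained sketch (not part of the commit) of the read-length heuristic this change introduces. It mirrors kps::details::GetXpuReadLens from the diff below and shows how the launch code derives main_offset from the chosen read length; the sample numel, block, and grid values are illustrative assumptions only.

#include <algorithm>
#include <cstdio>

// Mirrors kps::details::RoundUpDiv / GetXpuReadLens from the diff below.
static inline int RoundUpDiv(int n, int k) { return (n + k - 1) / k; }

static inline int GetXpuReadLens(int numel, int block_num, int grid_num) {
  const int buf_size = 256;        // per-thread local buffer cap
  int nthreads = block_num * grid_num;
  if (numel / nthreads == 1) {
    return numel / nthreads * 4;   // tiny per-thread share: read 4 elements
  }
  // Otherwise round the per-thread share up to a multiple of 32, capped at 256.
  return std::min(buf_size, RoundUpDiv(numel, 32 * nthreads) * 32);
}

int main() {
  int numel = 100000;                  // illustrative element count
  int block_size = 64, grid_size = 8;  // values used by the XPU_KP launch path
  int read_lens = GetXpuReadLens(numel, block_size, grid_size);
  // main_offset is the largest multiple of (read_lens * block_size) elements,
  // i.e. the portion the vectorized main loop covers.
  long long main_offset =
      (long long)(numel / (read_lens * block_size)) * read_lens * block_size;
  std::printf("read_lens=%d main_offset=%lld\n", read_lens, main_offset);
  return 0;
}

With numel = 100000 and 64 x 8 threads this picks read_lens = 224, so the vectorized main loop covers the first 86016 elements and the remainder is left to a tail pass.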
@@ -558,6 +558,9 @@ struct VecSizeGetter {
template <typename OutT, typename Functor>
int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
const std::vector<DenseTensor *> &outs) {
#ifdef PADDLE_WITH_XPU_KP
int vec_size = 256;
#else
using Traits = paddle::platform::FunctionTraits<Functor>;
using ArgsT = typename Traits::ArgsTuple;
const int Arity = Traits::arity;
@@ -569,6 +572,7 @@ int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
vec_size =
std::min<int>(vec_size, phi::GetVectorizedSize((*iter)->data<OutT>()));
}
#endif
return vec_size;
}
@@ -784,7 +788,6 @@ template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
void LaunchElementwiseCudaKernel(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
int read_lens,
Functor func) {
// There are at least 1 output, but maybe 0 input (ins.size() == 0).
// For large tensor numel * sizeof(T) > 2^31, we must use int64_t as index
@@ -800,6 +803,7 @@ void LaunchElementwiseCudaKernel(const KPDevice &ctx,
#ifdef PADDLE_WITH_XPU_KP
int block_size = 64;
int grid_size = 8;
int read_lens = kps::details::GetXpuReadLens(numel, block_size, grid_size);
auto stream = ctx.x_context()->xpu_stream;
int64_t main_offset =
(numel / (read_lens * block_size)) * read_lens * block_size;
@@ -853,32 +857,20 @@ void ElementwiseKernel(const KPDevice &ctx,
}
}
#ifdef PADDLE_WITH_XPU_KP
const int buf_size = 256;
int numel = (*outs)[0]->numel();
int block_size = 64;
int grid_size = 8;
int nthreads = block_size * grid_size;
int read_lens =
std::min(buf_size, kps::details::RoundUpDiv(numel, 32 * nthreads) * 32);
int vec_size = buf_size;
#else
// calculate the max vec_size for all ins and outs
int vec_size = GetVectorizedSizeForTensors<OutT, Functor>(ins, *outs);
int read_lens = vec_size;
#endif
switch (vec_size) {
case VecSizeL:
LaunchElementwiseCudaKernel<OutT, Functor, kArity, NumOuts, VecSizeL>(
ctx, ins, outs, read_lens, func);
ctx, ins, outs, func);
break;
case VecSizeM:
LaunchElementwiseCudaKernel<OutT, Functor, kArity, NumOuts, VecSizeM>(
ctx, ins, outs, read_lens, func);
ctx, ins, outs, func);
break;
case VecSizeS:
LaunchElementwiseCudaKernel<OutT, Functor, kArity, NumOuts, VecSizeS>(
ctx, ins, outs, read_lens, func);
ctx, ins, outs, func);
break;
default: {
PADDLE_THROW(phi::errors::Unimplemented(
......
@@ -21,7 +21,17 @@ namespace phi {
namespace kps {
namespace details {
int RoundUpDiv(int n, int k) { return (n + k - 1) / k; }
static inline int RoundUpDiv(int n, int k) { return (n + k - 1) / k; }
static inline int GetXpuReadLens(int numel, int block_num, int grid_num) {
const int buf_size = 256;
int nthreads = block_num * grid_num;
if (numel / nthreads == 1) {
return numel / nthreads * 4;
}
int read_lens = std::min(buf_size, RoundUpDiv(numel, 32 * nthreads) * 32);
return read_lens;
}
enum class OptType { // Optimize type of calc after input shape compressed
CanNotOptimize = -1, // can not optimize, broadcast first
@@ -98,8 +108,10 @@ struct BroadcastConfig {
strides_out_tmp[i] = strides_out_tmp[i - 1] * out_dims[i - 1];
}
int numel_out = 1;
for (int i = 0; i < dim_size; i++) {
dim_tmp[i] = in_dims[i];
numel_out = out_dims[i] * numel_out;
}
kDims = dim_size;
memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int));
@@ -108,13 +120,25 @@
cmp_res = get_mnk_for_broadcast_ops(in_dims, y_in_dims);
get_opt_type();
buf_len = get_buf_len();
buf_len = get_buf_len(numel_out);
int numel_x = 1;
int numel_y = 1;
for (int i = 0; i < dim_size; i++) {
numel_x = in_dims[i] * numel_x;
numel_y = y_in_dims[i] * numel_y;
}
if (numel_out == numel_x && numel_out == numel_y) {
buf_len = GetXpuReadLens(numel_out, 8, 64);
}
}
int get_buf_len() {
int get_buf_len(int numel) {
if (cmp_type == OptType::CanNotOptimize) {
return 256;
}
if (cmp_type == OptType::N_1) {
return kps::details::GetXpuReadLens(numel, 8, 64);
}
int max_buf_len = 512;
int buf_len = m / 16 * 16;
if (buf_len == 0) {
......