未验证 提交 aa0c885a 编写于 作者: S shixingbo 提交者: GitHub

Optimized the performance of broadcast for kp XPU2 (#44091)

上级 1e6137b5
...@@ -558,6 +558,9 @@ struct VecSizeGetter { ...@@ -558,6 +558,9 @@ struct VecSizeGetter {
template <typename OutT, typename Functor> template <typename OutT, typename Functor>
int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins, int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
const std::vector<DenseTensor *> &outs) { const std::vector<DenseTensor *> &outs) {
#ifdef PADDLE_WITH_XPU_KP
int vec_size = 256;
#else
using Traits = paddle::platform::FunctionTraits<Functor>; using Traits = paddle::platform::FunctionTraits<Functor>;
using ArgsT = typename Traits::ArgsTuple; using ArgsT = typename Traits::ArgsTuple;
const int Arity = Traits::arity; const int Arity = Traits::arity;
...@@ -569,6 +572,7 @@ int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins, ...@@ -569,6 +572,7 @@ int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
vec_size = vec_size =
std::min<int>(vec_size, phi::GetVectorizedSize((*iter)->data<OutT>())); std::min<int>(vec_size, phi::GetVectorizedSize((*iter)->data<OutT>()));
} }
#endif
return vec_size; return vec_size;
} }
...@@ -784,7 +788,6 @@ template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize> ...@@ -784,7 +788,6 @@ template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
void LaunchElementwiseCudaKernel(const KPDevice &ctx, void LaunchElementwiseCudaKernel(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins, const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs, std::vector<DenseTensor *> *outs,
int read_lens,
Functor func) { Functor func) {
// There are at least 1 output, but maybe 0 input (ins.size() == 0). // There are at least 1 output, but maybe 0 input (ins.size() == 0).
// For large tensor numel * sizeof(T) > 2^31, we must use int64_t as index // For large tensor numel * sizeof(T) > 2^31, we must use int64_t as index
...@@ -800,6 +803,7 @@ void LaunchElementwiseCudaKernel(const KPDevice &ctx, ...@@ -800,6 +803,7 @@ void LaunchElementwiseCudaKernel(const KPDevice &ctx,
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
int block_size = 64; int block_size = 64;
int grid_size = 8; int grid_size = 8;
int read_lens = kps::details::GetXpuReadLens(numel, block_size, grid_size);
auto stream = ctx.x_context()->xpu_stream; auto stream = ctx.x_context()->xpu_stream;
int64_t main_offset = int64_t main_offset =
(numel / (read_lens * block_size)) * read_lens * block_size; (numel / (read_lens * block_size)) * read_lens * block_size;
...@@ -853,32 +857,20 @@ void ElementwiseKernel(const KPDevice &ctx, ...@@ -853,32 +857,20 @@ void ElementwiseKernel(const KPDevice &ctx,
} }
} }
#ifdef PADDLE_WITH_XPU_KP
const int buf_size = 256;
int numel = (*outs)[0]->numel();
int block_size = 64;
int grid_size = 8;
int nthreads = block_size * grid_size;
int read_lens =
std::min(buf_size, kps::details::RoundUpDiv(numel, 32 * nthreads) * 32);
int vec_size = buf_size;
#else
// calculate the max vec_size for all ins and outs // calculate the max vec_size for all ins and outs
int vec_size = GetVectorizedSizeForTensors<OutT, Functor>(ins, *outs); int vec_size = GetVectorizedSizeForTensors<OutT, Functor>(ins, *outs);
int read_lens = vec_size;
#endif
switch (vec_size) { switch (vec_size) {
case VecSizeL: case VecSizeL:
LaunchElementwiseCudaKernel<OutT, Functor, kArity, NumOuts, VecSizeL>( LaunchElementwiseCudaKernel<OutT, Functor, kArity, NumOuts, VecSizeL>(
ctx, ins, outs, read_lens, func); ctx, ins, outs, func);
break; break;
case VecSizeM: case VecSizeM:
LaunchElementwiseCudaKernel<OutT, Functor, kArity, NumOuts, VecSizeM>( LaunchElementwiseCudaKernel<OutT, Functor, kArity, NumOuts, VecSizeM>(
ctx, ins, outs, read_lens, func); ctx, ins, outs, func);
break; break;
case VecSizeS: case VecSizeS:
LaunchElementwiseCudaKernel<OutT, Functor, kArity, NumOuts, VecSizeS>( LaunchElementwiseCudaKernel<OutT, Functor, kArity, NumOuts, VecSizeS>(
ctx, ins, outs, read_lens, func); ctx, ins, outs, func);
break; break;
default: { default: {
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
......
...@@ -21,7 +21,17 @@ namespace phi { ...@@ -21,7 +21,17 @@ namespace phi {
namespace kps { namespace kps {
namespace details { namespace details {
int RoundUpDiv(int n, int k) { return (n + k - 1) / k; } static inline int RoundUpDiv(int n, int k) { return (n + k - 1) / k; }
// Choose the per-thread read length (element count) for an XPU kernel
// launch over `numel` elements distributed across block_num * grid_num
// threads. Heuristic (behavior-preserving restyle of the original):
//   * if each thread's integer share (numel / threads) is exactly 1,
//     use a small fixed length of 4;
//   * otherwise round the per-thread share up to a multiple of 32 and
//     cap it at a 256-element buffer.
static inline int GetXpuReadLens(int numel, int block_num, int grid_num) {
  const int kBufCap = 256;  // upper bound on the per-thread buffer length
  const int kAlign = 32;    // read-length granularity
  const int total_threads = block_num * grid_num;
  const int per_thread = numel / total_threads;
  if (per_thread == 1) {
    // per_thread is known to be 1 here, so this evaluates to 4.
    return per_thread * 4;
  }
  // Ceil-divide numel over (kAlign * total_threads), then scale back up so
  // the result is a multiple of kAlign; clamp to the buffer cap.
  const int stride = kAlign * total_threads;
  const int aligned_len = ((numel + stride - 1) / stride) * kAlign;
  return aligned_len < kBufCap ? aligned_len : kBufCap;
}
enum class OptType { // Optimize type of calc after input shape compressed enum class OptType { // Optimize type of calc after input shape compressed
CanNotOptimize = -1, // can not optimize, broadcast first CanNotOptimize = -1, // can not optimize, broadcast first
...@@ -98,8 +108,10 @@ struct BroadcastConfig { ...@@ -98,8 +108,10 @@ struct BroadcastConfig {
strides_out_tmp[i] = strides_out_tmp[i - 1] * out_dims[i - 1]; strides_out_tmp[i] = strides_out_tmp[i - 1] * out_dims[i - 1];
} }
int numel_out = 1;
for (int i = 0; i < dim_size; i++) { for (int i = 0; i < dim_size; i++) {
dim_tmp[i] = in_dims[i]; dim_tmp[i] = in_dims[i];
numel_out = out_dims[i] * numel_out;
} }
kDims = dim_size; kDims = dim_size;
memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int)); memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int));
...@@ -108,13 +120,25 @@ struct BroadcastConfig { ...@@ -108,13 +120,25 @@ struct BroadcastConfig {
cmp_res = get_mnk_for_broadcast_ops(in_dims, y_in_dims); cmp_res = get_mnk_for_broadcast_ops(in_dims, y_in_dims);
get_opt_type(); get_opt_type();
buf_len = get_buf_len(); buf_len = get_buf_len(numel_out);
int numel_x = 1;
int numel_y = 1;
for (int i = 0; i < dim_size; i++) {
numel_x = in_dims[i] * numel_x;
numel_y = y_in_dims[i] * numel_y;
}
if (numel_out == numel_x && numel_out == numel_y) {
buf_len = GetXpuReadLens(numel_out, 8, 64);
}
} }
int get_buf_len() { int get_buf_len(int numel) {
if (cmp_type == OptType::CanNotOptimize) { if (cmp_type == OptType::CanNotOptimize) {
return 256; return 256;
} }
if (cmp_type == OptType::N_1) {
return kps::details::GetXpuReadLens(numel, 8, 64);
}
int max_buf_len = 512; int max_buf_len = 512;
int buf_len = m / 16 * 16; int buf_len = m / 16 * 16;
if (buf_len == 0) { if (buf_len == 0) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册