Unverified commit 93e1bb98 authored by Asthestarsfalll, committed by GitHub

optimize logsumexp for small data scales (#52952)

* optimize logsumexp for small data scales

* fix

* fix

* add #pragma once

* switch to use aligned_vector and support arbitrary shapes

* fix store

* fix store

* refine for special cases

* try

* fix

* update

* fix

* fix all_reduce

* try

* fix rocm bug

* fix rocm bug

* fix rocm bug

* fix rocm bug

* fix rocm bug

* fix rocm bug

* fix rocm bug

* fix rocm bug
Parent c32a3002
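The warp kernels added in this patch evaluate logsumexp in its numerically stable form,

$$\operatorname{logsumexp}(x) = m + \log\sum_i e^{x_i - m}, \qquad m = \max_i x_i,$$

fusing the row maximum, the exp-sum, and the final log into a single pass when the reduced size is small; larger reductions keep the previous multi-kernel path.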
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <assert.h>

#include <algorithm>
#include <type_traits>

#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#define CUDART_INF __longlong_as_double(0x7ff0000000000000ULL)
#define CUDART_INF_F __int_as_float(0x7f800000)
namespace phi {
namespace funcs {
constexpr int kWarpSize = 32;
template <typename T>
__inline__ __device__ T Inf();
template <>
__inline__ __device__ float Inf<float>() {
return CUDART_INF_F;
}
template <>
__inline__ __device__ double Inf<double>() {
return CUDART_INF;
}
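// Butterfly (XOR-shuffle) all-reduce across a thread group of
// ThreadGroupWidth lanes; every lane ends up holding the reduced value.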
template <typename T,
template <typename>
class Functor,
int ThreadGroupWidth = kWarpSize>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = ThreadGroupWidth / 2; mask > 0; mask /= 2) {
#if PADDLE_WITH_HIP
    val = Functor<T>()(val, __shfl_xor(val, mask));
#else
val = Functor<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
#endif
}
return val;
}
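// Choose a grid size that covers up to max_blocks thread blocks but is capped
// at roughly `waves` resident waves per SM; the kernel loops over rows when
// the grid is smaller than the amount of work.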
#if PADDLE_WITH_HIP
inline void GetNumBlocks(int64_t block_size,
int64_t max_blocks,
int64_t waves,
int* num_blocks) {
int dev;
PADDLE_ENFORCE_GPU_SUCCESS(hipGetDevice(&dev));
int sm_count;
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceGetAttribute(
&sm_count, hipDeviceAttributeMultiprocessorCount, dev));
int tpm;
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceGetAttribute(
&tpm, hipDeviceAttributeMaxThreadsPerMultiProcessor, dev));
*num_blocks = std::max<int>(
1, std::min<int64_t>(max_blocks, sm_count * tpm / block_size * waves));
}
#else
inline void GetNumBlocks(int64_t block_size,
int64_t max_blocks,
int64_t waves,
int* num_blocks) {
int dev;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&dev));
int sm_count;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev));
int tpm;
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute(
&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev));
*num_blocks = std::max<int>(
1, std::min<int64_t>(max_blocks, sm_count * tpm / block_size * waves));
}
#endif
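// Warp-level logsumexp over the innermost (reduced) dimension. Each thread
// group (ThreadGroupWidth lanes of a warp) handles RowsPerThread rows at a
// time; every lane keeps ColsPerThread elements of each row in registers,
// loaded VecSize elements at a time. NeedPadding marks the case where num_col
// does not fill ColsPerThread * ThreadGroupWidth, so out-of-range slots are
// filled with -inf.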
template <typename T,
typename SourceType,
typename Context,
int VecSize,
int ColsPerThread,
int RowsPerThread,
int ThreadGroupWidth,
bool NeedPadding>
__global__ void LogsumexpWarpImpl(const Context& dev_ctx,
const int64_t num_row,
const int64_t num_col,
const SourceType* in,
SourceType* out) {
static_assert(ColsPerThread % VecSize == 0, "");
static_assert(ThreadGroupWidth <= kWarpSize, "");
static_assert(kWarpSize % ThreadGroupWidth == 0, "");
constexpr int num_read = ColsPerThread / VecSize;
assert(num_col <= ColsPerThread * ThreadGroupWidth);
const int group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int num_thread_group = gridDim.x * blockDim.y;
const int thread_id = threadIdx.x;
const int step = num_thread_group * RowsPerThread;
using LoadType = phi::AlignedVector<SourceType, VecSize>;
using StoreType = phi::AlignedVector<SourceType, RowsPerThread>;
LoadType load_vec;
StoreType store_vec;
T buffer[RowsPerThread][ColsPerThread];
for (int64_t cur_row = group_id * RowsPerThread; cur_row < num_row;
cur_row += step) {
T thread_max[RowsPerThread];
// Read data
#pragma unroll
for (int row_id = 0; row_id < RowsPerThread; row_id++) {
thread_max[row_id] = -Inf<T>();
T* row_buffer = buffer[row_id];
#pragma unroll
for (int read_id = 0; read_id < num_read; read_id++) {
const int offset = read_id * VecSize;
const int cur_col = (read_id * ThreadGroupWidth + thread_id) * VecSize;
if (!NeedPadding || cur_col < num_col) {
int64_t load_offset = ((cur_row + row_id) * num_col + cur_col);
phi::Load<SourceType, VecSize>(in + load_offset, &load_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
row_buffer[offset + i] = static_cast<T>(load_vec[i]);
thread_max[row_id] =
max(thread_max[row_id], row_buffer[offset + i]);
}
} else {
#pragma unroll
for (int i = 0; i < VecSize; i++) {
row_buffer[offset + i] = -Inf<T>();
}
}
}
}
T warp_max[RowsPerThread];
// Get warp max
#pragma unroll
for (int row_id = 0; row_id < RowsPerThread; row_id++) {
warp_max[row_id] = WarpAllReduce<T, kps::MaxFunctor, ThreadGroupWidth>(
thread_max[row_id]);
}
T thread_sum[RowsPerThread];
// Calculate
#pragma unroll
for (int row_id = 0; row_id < RowsPerThread; row_id++) {
thread_sum[row_id] = 0;
T* row_buffer = buffer[row_id];
#pragma unroll
for (int i = 0; i < ColsPerThread; i++) {
thread_sum[row_id] += exp(row_buffer[i] - warp_max[row_id]);
}
}
// Get warp sum and write
#pragma unroll
for (int row_id = 0; row_id < RowsPerThread; row_id++) {
T res = log(WarpAllReduce<T, kps::AddFunctor, ThreadGroupWidth>(
thread_sum[row_id]));
store_vec[row_id] = static_cast<SourceType>(res + warp_max[row_id]);
}
if (thread_id == 0 && cur_row < num_row) {
      phi::Store<SourceType, RowsPerThread>(store_vec, out + cur_row);
}
}
}
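// Launch LogsumexpWarpImpl with 128-thread blocks arranged as
// (ThreadGroupWidth x thread_groups_per_block) and a grid size picked by
// GetNumBlocks.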
template <typename T,
typename SourceType,
typename Context,
int VecSize,
int ColsPerThread,
int RowsPerThread,
int ThreadGroupWidth,
bool NeedPadding>
#if PADDLE_WITH_HIP
inline hipError_t LaunchLogsumexpWarp(const Context& dev_ctx,
const int64_t num_row,
const int64_t num_col,
const SourceType* in,
SourceType* out) {
#else
inline cudaError_t LaunchLogsumexpWarp(const Context& dev_ctx,
const int64_t num_row,
const int64_t num_col,
const SourceType* in,
SourceType* out) {
#endif
constexpr int block_size = 128;
constexpr int waves = 32;
static_assert(block_size % ThreadGroupWidth == 0, "");
constexpr int thread_groups_per_block = block_size / ThreadGroupWidth;
dim3 block_dim(ThreadGroupWidth, thread_groups_per_block);
const int64_t num_blocks =
(num_row / RowsPerThread + thread_groups_per_block - 1) /
thread_groups_per_block;
int grid_dim_x;
  GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x);
LogsumexpWarpImpl<T,
SourceType,
Context,
VecSize,
ColsPerThread,
RowsPerThread,
ThreadGroupWidth,
NeedPadding>
<<<grid_dim_x, block_dim, 0, dev_ctx.stream()>>>(
dev_ctx, num_row, num_col, in, out);
#if PADDLE_WITH_HIP
return hipPeekAtLastError();
#else
return cudaPeekAtLastError();
#endif
}
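// Pick the non-padded kernel when num_col exactly fills
// ColsPerThread * ThreadGroupWidth, and the padded one otherwise.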
template <typename T,
typename SourceType,
typename Context,
int VecSize,
int ColsPerThread,
int RowsPerThread,
int ThreadGroupWidth>
#if PADDLE_WITH_HIP
inline hipError_t DispatchLogsumexpWarpWithPadding(const Context& dev_ctx,
const int64_t num_row,
const int64_t num_col,
const SourceType* in,
SourceType* out) {
#else
inline cudaError_t DispatchLogsumexpWarpWithPadding(const Context& dev_ctx,
const int64_t num_row,
const int64_t num_col,
const SourceType* in,
SourceType* out) {
#endif
if (num_col == ColsPerThread * ThreadGroupWidth) {
return LaunchLogsumexpWarp<T,
SourceType,
Context,
VecSize,
ColsPerThread,
RowsPerThread,
ThreadGroupWidth,
false>(dev_ctx, num_row, num_col, in, out);
} else {
return LaunchLogsumexpWarp<T,
SourceType,
Context,
VecSize,
ColsPerThread,
RowsPerThread,
ThreadGroupWidth,
true>(dev_ctx, num_row, num_col, in, out);
}
}
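// VecSize == 1 dispatch (odd num_col): shrink the thread group while
// num_col <= 32, otherwise keep a full warp and grow ColsPerThread; two rows
// per thread group are processed when num_row is even.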
template <typename T, typename SourceType, typename Context, int VecSize>
#if PADDLE_WITH_HIP
typename std::enable_if<VecSize == 1, hipError_t>::type
DispatchLogsumexpWarpCols(const Context& dev_ctx,
const int64_t num_row,
const int64_t num_col,
const SourceType* in,
SourceType* out) {
#else
typename std::enable_if<VecSize == 1, cudaError_t>::type
DispatchLogsumexpWarpCols(const Context& dev_ctx,
const int64_t num_row,
const int64_t num_col,
const SourceType* in,
SourceType* out) {
#endif
if (num_col <= 0) {
#if PADDLE_WITH_HIP
return hipErrorInvalidValue;
#else
return cudaErrorInvalidValue;
#endif
}
#define HANDLE_THREAD_GROUP(thread_group_width) \
if (num_col <= (thread_group_width)*VecSize) { \
if (num_row % 2 == 0) { \
return DispatchLogsumexpWarpWithPadding<T, \
SourceType, \
Context, \
VecSize, \
VecSize, \
2, \
thread_group_width>( \
dev_ctx, num_row, num_col, in, out); \
} else { \
return DispatchLogsumexpWarpWithPadding<T, \
SourceType, \
Context, \
VecSize, \
VecSize, \
1, \
thread_group_width>( \
dev_ctx, num_row, num_col, in, out); \
} \
}
HANDLE_THREAD_GROUP(1)
HANDLE_THREAD_GROUP(2)
HANDLE_THREAD_GROUP(4)
HANDLE_THREAD_GROUP(8)
HANDLE_THREAD_GROUP(16)
HANDLE_THREAD_GROUP(32)
#undef HANDLE_THREAD_GROUP
// if num_col > 32
#define HANDLE_COL(col) \
if (num_col <= (col)*kWarpSize) { \
return DispatchLogsumexpWarpWithPadding<T, \
SourceType, \
Context, \
VecSize, \
col, \
1, \
kWarpSize>( \
dev_ctx, num_row, num_col, in, out); \
}
HANDLE_COL(2)
HANDLE_COL(3)
HANDLE_COL(4)
HANDLE_COL(5)
HANDLE_COL(6)
HANDLE_COL(7)
HANDLE_COL(8)
HANDLE_COL(9)
HANDLE_COL(10)
HANDLE_COL(11)
HANDLE_COL(12)
HANDLE_COL(13)
HANDLE_COL(14)
HANDLE_COL(15)
HANDLE_COL(16)
HANDLE_COL(17)
HANDLE_COL(18)
HANDLE_COL(19)
HANDLE_COL(20)
HANDLE_COL(21)
HANDLE_COL(22)
HANDLE_COL(23)
HANDLE_COL(24)
HANDLE_COL(25)
HANDLE_COL(26)
HANDLE_COL(27)
HANDLE_COL(28)
HANDLE_COL(29)
HANDLE_COL(30)
HANDLE_COL(31)
HANDLE_COL(32)
#undef HANDLE_COL
#if PADDLE_WITH_HIP
return hipErrorInvalidValue;
#else
return cudaErrorInvalidValue;
#endif
}
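// VecSize == 2 dispatch (even num_col): same strategy as above, but with
// two-element vectorized loads, so ColsPerThread stays a multiple of two and
// the thread-group path covers num_col <= 64.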
template <typename T, typename SourceType, typename Context, int VecSize>
#if PADDLE_WITH_HIP
typename std::enable_if<VecSize == 2, hipError_t>::type
DispatchLogsumexpWarpCols(const Context& dev_ctx,
const int64_t num_row,
const int64_t num_col,
const SourceType* in,
SourceType* out) {
#else
typename std::enable_if<VecSize == 2, cudaError_t>::type
DispatchLogsumexpWarpCols(const Context& dev_ctx,
const int64_t num_row,
const int64_t num_col,
const SourceType* in,
SourceType* out) {
#endif
if (num_col <= 0) {
#if PADDLE_WITH_HIP
return hipErrorInvalidValue;
#else
return cudaErrorInvalidValue;
#endif
}
#define HANDLE_THREAD_GROUP(thread_group_width) \
if (num_col <= (thread_group_width)*VecSize) { \
if (num_row % 2 == 0) { \
return DispatchLogsumexpWarpWithPadding<T, \
SourceType, \
Context, \
VecSize, \
VecSize, \
2, \
thread_group_width>( \
dev_ctx, num_row, num_col, in, out); \
} else { \
return DispatchLogsumexpWarpWithPadding<T, \
SourceType, \
Context, \
VecSize, \
VecSize, \
1, \
thread_group_width>( \
dev_ctx, num_row, num_col, in, out); \
} \
}
HANDLE_THREAD_GROUP(1)
HANDLE_THREAD_GROUP(2)
HANDLE_THREAD_GROUP(4)
HANDLE_THREAD_GROUP(8)
HANDLE_THREAD_GROUP(16)
HANDLE_THREAD_GROUP(32)
#undef HANDLE_THREAD_GROUP
// if num_col > 64
#define HANDLE_COL(col) \
if (num_col <= (col)*kWarpSize) { \
return DispatchLogsumexpWarpWithPadding<T, \
SourceType, \
Context, \
VecSize, \
col, \
1, \
kWarpSize>( \
dev_ctx, num_row, num_col, in, out); \
}
HANDLE_COL(4)
HANDLE_COL(6)
HANDLE_COL(8)
HANDLE_COL(10)
HANDLE_COL(12)
HANDLE_COL(14)
HANDLE_COL(16)
HANDLE_COL(18)
HANDLE_COL(20)
HANDLE_COL(22)
HANDLE_COL(24)
HANDLE_COL(26)
HANDLE_COL(28)
HANDLE_COL(30)
HANDLE_COL(32)
#undef HANDLE_COL
#if PADDLE_WITH_HIP
return hipErrorInvalidValue;
#else
return cudaErrorInvalidValue;
#endif
}
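// Entry point of the warp-level path: choose VecSize from the parity of
// num_col and forward to the column dispatcher.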
template <typename T, typename SourceType, typename Context>
#if PADDLE_WITH_HIP
inline hipError_t DispatchLogsumexpWarp(const Context& dev_ctx,
const int64_t num_row,
const int64_t num_col,
const SourceType* in,
SourceType* out) {
#else
inline cudaError_t DispatchLogsumexpWarp(const Context& dev_ctx,
const int64_t num_row,
const int64_t num_col,
const SourceType* in,
SourceType* out) {
#endif
// dispatch logsumexp warp with vecsize
if (num_col % 2 == 0) {
return DispatchLogsumexpWarpCols<T, SourceType, Context, 2>(
dev_ctx, num_row, num_col, in, out);
} else {
return DispatchLogsumexpWarpCols<T, SourceType, Context, 1>(
dev_ctx, num_row, num_col, in, out);
}
}
} // namespace funcs
} // namespace phi
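To make the HANDLE_THREAD_GROUP / HANDLE_COL dispatch ladder above easier to follow, here is a small standalone sketch that mirrors the same selection rules in plain host C++. It is illustrative only and not part of the patch; WarpConfig and SelectWarpConfig are hypothetical names, and the sketch covers only the column dimension (RowsPerThread is chosen separately: 2 when num_row is even, else 1).

#include <cstdint>
#include <cstdio>

struct WarpConfig {
  int vec_size;
  int cols_per_thread;
  int thread_group_width;
};

// Mirrors DispatchLogsumexpWarp -> DispatchLogsumexpWarpCols ->
// DispatchLogsumexpWarpWithPadding for a single reduced-row length.
WarpConfig SelectWarpConfig(int64_t num_col) {
  const int kWarpSize = 32;
  const int vec_size = (num_col % 2 == 0) ? 2 : 1;
  // Small rows: one vector per lane, shrink the thread group instead.
  for (int width = 1; width <= kWarpSize; width *= 2) {
    if (num_col <= static_cast<int64_t>(width) * vec_size) {
      return {vec_size, vec_size, width};
    }
  }
  // Larger rows: keep a full warp and grow ColsPerThread in VecSize steps.
  for (int cols = 2 * vec_size; cols <= kWarpSize; cols += vec_size) {
    if (num_col <= static_cast<int64_t>(cols) * kWarpSize) {
      return {vec_size, cols, kWarpSize};
    }
  }
  // num_col > 1024: LogsumexpKernel falls back to the generic reduce path.
  return {0, 0, 0};
}

int main() {
  const int64_t samples[] = {1, 7, 33, 100, 1000, 1024, 2048};
  for (int64_t n : samples) {
    const WarpConfig c = SelectWarpConfig(n);
    std::printf("num_col=%lld -> VecSize=%d ColsPerThread=%d Width=%d\n",
                static_cast<long long>(n), c.vec_size, c.cols_per_thread,
                c.thread_group_width);
  }
  return 0;
}

For example, num_col = 100 selects VecSize = 2, ColsPerThread = 4 and a full 32-lane group, so each row is covered by 32 x 4 = 128 register slots with the tail padded by -inf.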
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/logsumexp_kernel.h"
#include "paddle/phi/kernels/gpu/logsumexp_function.cu.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
@@ -21,10 +22,26 @@
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/gpu/reduce.h"
namespace phi {
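// float16/bfloat16 inputs are accumulated in float inside the warp kernel so
// that exp/log stay accurate.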
template <typename T>
struct ComputeType {
using type = T;
};
template <>
struct ComputeType<phi::dtype::float16> {
using type = float;
};
template <>
struct ComputeType<phi::dtype::bfloat16> {
using type = float;
};
template <typename T>
struct LogCUDAFunctor {
HOSTDEVICE inline T operator()(const T x) const { return std::log(x); }
@@ -46,6 +63,44 @@ struct LogCUDAFunctor<bfloat16> {
}
};
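// Generic fallback used when the reduced size is too large for the warp
// kernel: reduce-max, subtract, exp-sum-reduce, log, then add the max back.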
template <typename T, typename Context>
void LogsumexpFallbackKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis_vec,
const std::vector<int64_t>& outdim_vec,
const std::vector<int64_t>& keeped_outdim_vec,
bool keepdim,
bool reduce_all,
DenseTensor* out) {
auto* in_x = &x;
auto* out_y = out;
auto outdim = phi::make_ddim(outdim_vec);
auto keeped_outdim = phi::make_ddim(keeped_outdim_vec);
out->Resize(outdim);
dev_ctx.template Alloc<T>(out_y);
DenseTensor max_x;
max_x.Resize(outdim);
dev_ctx.template Alloc<T>(&max_x);
phi::funcs::ReduceKernel<T, T, kps::MaxFunctor, kps::IdentityFunctor<T>>(
dev_ctx, *in_x, &max_x, kps::IdentityFunctor<T>(), axis_vec);
max_x.Resize(keeped_outdim);
DenseTensor temp_x = Subtract<T, Context>(dev_ctx, *in_x, max_x);
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::ExpFunctor<T>>(
dev_ctx, temp_x, out_y, kps::ExpFunctor<T>(), axis_vec);
const std::vector<const DenseTensor*> inputs = {out_y};
std::vector<DenseTensor*> outputs = {&temp_x};
phi::funcs::ElementwiseKernel<T>(
dev_ctx, inputs, &outputs, LogCUDAFunctor<T>());
temp_x.Resize(outdim);
out->Resize(outdim);
phi::AddKernel<T, Context>(dev_ctx, temp_x, max_x, out);
}
template <typename T, typename Context>
void LogsumexpKernel(const Context& dev_ctx,
const DenseTensor& x,
@@ -53,9 +108,7 @@ void LogsumexpKernel(const Context& dev_ctx,
bool keepdim,
bool reduce_all,
DenseTensor* out) {
auto* in_x = &x;
auto* out_y = out;
auto xdim = in_x->dims();
auto xdim = x.dims();
for (size_t i = 0; i < xdim.size(); i++)
PADDLE_ENFORCE_LT(0,
xdim[i],
@@ -63,13 +116,15 @@ void LogsumexpKernel(const Context& dev_ctx,
"The dims of Input(X) should be greater than 0."));
reduce_all = recompute_reduce_all(x, axis, reduce_all);
std::vector<int64_t> outdim_vec, keeped_outdim_vec;
std::vector<int> axis_vec;
std::vector<int64_t> outdim_vec, keeped_outdim_vec, transpose_shape;
std::vector<int> axis_vec, perm;
int64_t compute_size = 1, other_size = 1;
for (auto i : axis) {
auto v = i >= 0 ? i : i + xdim.size();
axis_vec.push_back(v);
}
if (axis.size() == 0 || reduce_all) {
axis_vec.clear();
for (size_t i = 0; i < xdim.size(); i++) {
axis_vec.push_back(i);
}
@@ -83,38 +138,48 @@ void LogsumexpKernel(const Context& dev_ctx,
}
}
if (flag) {
compute_size *= xdim[i];
keeped_outdim_vec.push_back(1);
if (keepdim) outdim_vec.push_back(1);
} else {
other_size *= xdim[i];
transpose_shape.push_back(xdim[i]);
perm.push_back(i);
outdim_vec.push_back(xdim[i]);
keeped_outdim_vec.push_back(xdim[i]);
}
}
auto outdim = phi::make_ddim(outdim_vec);
auto keeped_outdim = phi::make_ddim(keeped_outdim_vec);
out->Resize(outdim);
dev_ctx.template Alloc<T>(out_y);
DenseTensor max_x;
max_x.Resize(outdim);
dev_ctx.template Alloc<T>(&max_x);
phi::funcs::ReduceKernel<T, T, kps::MaxFunctor, kps::IdentityFunctor<T>>(
dev_ctx, *in_x, &max_x, kps::IdentityFunctor<T>(), axis_vec);
max_x.Resize(keeped_outdim);
DenseTensor temp_x = Subtract<T, Context>(dev_ctx, *in_x, max_x);
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::ExpFunctor<T>>(
dev_ctx, temp_x, out_y, kps::ExpFunctor<T>(), axis_vec);
const std::vector<const DenseTensor*> inputs = {out_y};
std::vector<DenseTensor*> outputs = {&temp_x};
phi::funcs::ElementwiseKernel<T>(
dev_ctx, inputs, &outputs, LogCUDAFunctor<T>());
temp_x.Resize(outdim);
out->Resize(outdim);
phi::AddKernel<T, Context>(dev_ctx, temp_x, max_x, out);
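  // Rows with at most 1024 reduced elements take the warp kernel; the reduced
  // dimensions are first moved to the innermost position, transposing the
  // input when they are not already contiguous at the end.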
if (compute_size <= 1024) {
if (perm.size() != xdim.size())
perm.insert(perm.end(), axis_vec.begin(), axis_vec.end());
for (auto i : axis_vec) transpose_shape.push_back(xdim[i]);
DenseTensor transpose_x;
    if (xdim.size() == 0 ||
        (axis_vec.size() == 1 && axis_vec[0] == xdim.size() - 1)) {
transpose_x = x;
} else {
transpose_x.Resize(make_ddim(transpose_shape));
dev_ctx.template Alloc<T>(&transpose_x);
phi::funcs::TransposeGPUKernelDriver<T>(dev_ctx, x, perm, &transpose_x);
}
dev_ctx.template Alloc<T>(out);
using compute_type = typename ComputeType<T>::type;
const int64_t num_col = compute_size, num_row = other_size;
funcs::DispatchLogsumexpWarp<compute_type, T, Context>(
dev_ctx, num_row, num_col, transpose_x.data<T>(), out->data<T>());
out->Resize(outdim);
} else {
LogsumexpFallbackKernel<T, Context>(dev_ctx,
x,
axis_vec,
outdim_vec,
keeped_outdim_vec,
keepdim,
reduce_all,
out);
}
}
} // namespace phi