Unverified · Commit 71748805 authored by carryyu, committed by GitHub

fix softmax memory align (#46902)

Parent cf9ca61d
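The change below teaches the vectorized softmax path to tolerate batch rows whose base address is not aligned to `MATRIX_SOFTMAX_ALIGN_BYTES`. Each row's misalignment is converted into an element count (`shift`); the pointer is stepped back to the previous aligned boundary, the first `blockDim.x` slots of that re-based row are handled with scalar accesses (threads below `shift` skip the out-of-row padding), and the vectorized `phi::AlignedVector` loop then starts from an aligned address. Below is a plain-sum sketch of that head-peeling pattern, not the patch itself: `AlignedRowSum` is a hypothetical helper, `kAlignBytes` stands in for `MATRIX_SOFTMAX_ALIGN_BYTES` (whose value is not shown in this diff; 32 is a placeholder), and the real code threads a `Reduction` functor through instead of summing.

```cuda
#include <cstdint>

// Sketch only: assumes dim_size >= blockDim.x whenever shift > 0, as is
// the case for the large rows the vectorized kernel is dispatched to.
template <typename T, int kAlignBytes = 32>
__device__ float AlignedRowSum(const T* data, int dim_size) {
  // Elements between the previous kAlignBytes boundary and `data`.
  const int shift =
      (reinterpret_cast<std::uintptr_t>(data) % kAlignBytes) / sizeof(T);
  float acc = 0.f;
  int offset = threadIdx.x;
  if (shift > 0) {
    data -= shift;      // step back to the aligned boundary
    dim_size += shift;  // and grow the logical extent to match
    if (offset >= shift) {
      // Threads [shift, blockDim.x) consume the row's head; threads
      // [0, shift) would read before the row starts, so they skip.
      acc += static_cast<float>(data[offset]);
    }
    dim_size -= blockDim.x;  // the head is done;
    data += blockDim.x;      // `data` stays aligned as long as
                             // blockDim.x * sizeof(T) % kAlignBytes == 0
  }
  // The patch runs the vectorized phi::AlignedVector loop here; a
  // scalar loop keeps the sketch short. Each thread returns a partial
  // result that the kernel block-reduces afterwards.
  for (; offset < dim_size; offset += blockDim.x) {
    acc += static_cast<float>(data[offset]);
  }
  return acc;
}
```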
@@ -346,28 +346,41 @@ template <template <typename, typename> class Reduction,
           typename AccT,
           int VecSize>
 __device__ __forceinline__ AccT
-ThreadVecReduce(const T* data,
+ThreadVecReduce(T* data,
                 int dim_size,
+                const int shift,
                 const Reduction<T, AccT>& functor,
                 AccT default_value) {
   using VecT = phi::AlignedVector<T, VecSize>;
   AccT thread_val = default_value;
+
+  // for memory align, handle the unaligned data in first block.
+  int offset = threadIdx.x;
+  if (shift > 0) {
+    data -= shift;
+    dim_size += shift;
+    if (offset >= shift) {
+      thread_val = functor(thread_val, data[offset]);
+    }
+    dim_size -= blockDim.x;
+    data += blockDim.x;
+  }
+
   const int last = dim_size % (VecSize * blockDim.x);
 
   T v[VecSize];
   VecT* value = reinterpret_cast<VecT*>(&v);
 
-  for (int offset = threadIdx.x; offset * VecSize < dim_size - last;
-       offset += blockDim.x) {
-    *value = reinterpret_cast<VecT*>(const_cast<T*>(data))[offset];
+  for (; offset * VecSize < dim_size - last; offset += blockDim.x) {
+    *value = reinterpret_cast<VecT*>(data)[offset];
 #pragma unroll
     for (int i = 0; i < VecSize; i++) {
       thread_val = functor(thread_val, v[i]);
     }
   }
 
-  for (int offset = dim_size - last + threadIdx.x; offset < dim_size;
-       offset += blockDim.x) {
+  offset = dim_size - last + threadIdx.x;
+  for (; offset < dim_size; offset += blockDim.x) {
     thread_val = functor(thread_val, data[offset]);
   }
   return thread_val;
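Note the single `offset` declared ahead of the peel: the peel consumes the first `blockDim.x` slots of the re-based row, and the vectorized main loop reuses the same `offset` over what remains, so no element is visited twice. Making `data` non-const also retires the `const_cast` the old vector load needed.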
@@ -377,12 +390,27 @@ template <template <typename, typename> class Reduction,
           typename T,
           typename AccT,
           int VecSize>
-__device__ __forceinline__ void ThreadVecWrite(T* out,
-                                               const T* input,
-                                               int dim_size,
-                                               Reduction<AccT, T> functor) {
+__device__ __forceinline__ void ThreadVecWriteVec(T* out,
+                                                  T* input,
+                                                  int dim_size,
+                                                  const int shift,
+                                                  Reduction<AccT, T> functor) {
   using VecT = phi::AlignedVector<T, VecSize>;
 
+  // for memory align, handle the unaligned data in first block.
+  int offset = threadIdx.x;
+  if (shift > 0) {
+    input -= shift;
+    out -= shift;
+    dim_size += shift;
+    if (offset >= shift) {
+      out[offset] = functor(static_cast<AccT>(input[offset]));
+    }
+    dim_size -= blockDim.x;
+    input += blockDim.x;
+    out += blockDim.x;
+  }
+
   const int last = dim_size % (VecSize * blockDim.x);
 
   T in_v[VecSize];
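`ThreadVecWriteVec` rebases `input` and `out` by the same `shift`, so one peel can only repair both pointers when they are misaligned by the same element count. That is why the kernel below gates this path on `input_align_shift == output_align_shift` and otherwise falls back to a scalar write.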
@@ -391,9 +419,8 @@ __device__ __forceinline__ void ThreadVecWrite(T* out,
   T out_v[VecSize];
   VecT* out_value = reinterpret_cast<VecT*>(&out_v);
 
-  for (int offset = threadIdx.x; offset * VecSize < dim_size - last;
-       offset += blockDim.x) {
-    *in_value = reinterpret_cast<VecT*>(const_cast<T*>(input))[offset];
+  for (; offset * VecSize < dim_size - last; offset += blockDim.x) {
+    *in_value = reinterpret_cast<VecT*>(input)[offset];
 #pragma unroll
     for (int i = 0; i < VecSize; i++) {
       out_v[i] = functor(static_cast<AccT>(in_v[i]));
@@ -401,6 +428,33 @@ __device__ __forceinline__ void ThreadVecWrite(T* out,
     reinterpret_cast<VecT*>(out)[offset] = *out_value;
   }
 
+  offset = dim_size - last + threadIdx.x;
+  // the tail
+  for (; offset < dim_size; offset += blockDim.x) {
+    out[offset] = functor(static_cast<AccT>(input[offset]));
+  }
+}
+
+template <template <typename, typename> class Reduction,
+          typename T,
+          typename AccT,
+          int VecSize>
+__device__ __forceinline__ void ThreadVecWrite(T* out,
+                                               T* input,
+                                               int dim_size,
+                                               Reduction<AccT, T> functor) {
+  const int last = dim_size % (VecSize * blockDim.x);
+
+  for (int offset = threadIdx.x; offset < dim_size - last;
+       offset += blockDim.x * VecSize) {
+#pragma unroll
+    for (int i = 0; i < VecSize; i++) {
+      out[offset + i * blockDim.x] =
+          functor(static_cast<AccT>(input[offset + i * blockDim.x]));
+    }
+  }
+
+  // the tail
   for (int offset = dim_size - last + threadIdx.x; offset < dim_size;
        offset += blockDim.x) {
     out[offset] = functor(static_cast<AccT>(input[offset]));
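The re-added `ThreadVecWrite` is that scalar fallback: the unrolled inner loop still covers `VecSize` elements per outer iteration, but every load and store is element-wise with stride `blockDim.x`, so it imposes no alignment requirement on either pointer.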
@@ -417,13 +471,19 @@ __global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
   using VecT = phi::AlignedVector<T, VecSize>;
   int bid = blockIdx.x;
-  const T* batch_input = src + bid * dim_size;
+  T* batch_input = const_cast<T*>(src) + bid * dim_size;
   T* batch_output = softmax + bid * dim_size;
+  const int input_align_shift =
+      ((uint64_t)batch_input) % MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T);
+  const int output_align_shift =
+      ((uint64_t)batch_output) % MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T);
 
   // get max value
   AccT thread_max = ThreadVecReduce<MaxFunctor, T, AccT, VecSize>(
       batch_input,
       dim_size,
+      input_align_shift,
       MaxFunctor<T, AccT>(),
       std::numeric_limits<AccT>::min());
   BlockReduceMax<AccT>(&thread_max);
@@ -432,6 +492,7 @@ __global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
   AccT thread_exp = ThreadVecReduce<SumExpFunctor, T, AccT, VecSize>(
       batch_input,
       dim_size,
+      input_align_shift,
       SumExpFunctor<T, AccT>(thread_max),
       static_cast<AccT>(0.));
   BlockReduceSum<AccT>(&thread_exp);
@@ -440,12 +501,22 @@ __global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
 
   if (LogMode) {
     LogSoftmaxForwardFunctor<AccT, T> reduction(thread_max,
                                                 std::log(thread_exp));
-    ThreadVecWrite<LogSoftmaxForwardFunctor, T, AccT, VecSize>(
-        batch_output, batch_input, dim_size, reduction);
+    if (input_align_shift == output_align_shift) {
+      ThreadVecWriteVec<LogSoftmaxForwardFunctor, T, AccT, VecSize>(
+          batch_output, batch_input, dim_size, input_align_shift, reduction);
+    } else {
+      ThreadVecWrite<LogSoftmaxForwardFunctor, T, AccT, VecSize>(
+          batch_output, batch_input, dim_size, reduction);
+    }
   } else {
     SoftmaxForwardFunctor<AccT, T> reduction(thread_max, thread_exp);
-    ThreadVecWrite<SoftmaxForwardFunctor, T, AccT, VecSize>(
-        batch_output, batch_input, dim_size, reduction);
+    if (input_align_shift == output_align_shift) {
+      ThreadVecWriteVec<SoftmaxForwardFunctor, T, AccT, VecSize>(
+          batch_output, batch_input, dim_size, input_align_shift, reduction);
+    } else {
+      ThreadVecWrite<SoftmaxForwardFunctor, T, AccT, VecSize>(
+          batch_output, batch_input, dim_size, reduction);
+    }
   }
 }
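A worked example of the dispatch, assuming for illustration that `MATRIX_SOFTMAX_ALIGN_BYTES` is 32 and `T` is `float`: an input row at an address ending in `0x24` gives `input_align_shift = (0x24 % 32) / 4 = 1`, while an output row ending in `0x28` gives `output_align_shift = (0x28 % 32) / 4 = 2`; the shifts differ, so that batch takes the scalar `ThreadVecWrite`. When both rows sit the same 4 bytes past a 32-byte boundary, both shifts are 1 and `ThreadVecWriteVec` peels one head block and stays fully vectorized.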
@@ -1371,5 +1442,9 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
         dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
   }
 }
 
+#undef FIXED_BLOCK_DIM_BASE
+#undef FIXED_BLOCK_DIM
+#undef FIXED_VEC_SIZE_BASE
+#undef FIXED_VEC_SIZE
 }  // namespace phi