Unverified commit 71748805, authored by carryyu, committed by GitHub

fix softmax memory align (#46902)

Parent cf9ca61d
@@ -346,28 +346,41 @@ template <template <typename, typename> class Reduction,
           typename AccT,
           int VecSize>
 __device__ __forceinline__ AccT
-ThreadVecReduce(const T* data,
+ThreadVecReduce(T* data,
                 int dim_size,
+                const int shift,
                 const Reduction<T, AccT>& functor,
                 AccT default_value) {
   using VecT = phi::AlignedVector<T, VecSize>;
   AccT thread_val = default_value;
+
+  // for memory align, handle the unaligned data in first block.
+  int offset = threadIdx.x;
+  if (shift > 0) {
+    data -= shift;
+    dim_size += shift;
+    if (offset >= shift) {
+      thread_val = functor(thread_val, data[offset]);
+    }
+    dim_size -= blockDim.x;
+    data += blockDim.x;
+  }
+
   const int last = dim_size % (VecSize * blockDim.x);
 
   T v[VecSize];
   VecT* value = reinterpret_cast<VecT*>(&v);
 
-  for (int offset = threadIdx.x; offset * VecSize < dim_size - last;
-       offset += blockDim.x) {
-    *value = reinterpret_cast<VecT*>(const_cast<T*>(data))[offset];
+  for (; offset * VecSize < dim_size - last; offset += blockDim.x) {
+    *value = reinterpret_cast<VecT*>(data)[offset];
 #pragma unroll
     for (int i = 0; i < VecSize; i++) {
       thread_val = functor(thread_val, v[i]);
     }
   }
 
-  for (int offset = dim_size - last + threadIdx.x; offset < dim_size;
-       offset += blockDim.x) {
+  offset = dim_size - last + threadIdx.x;
+  for (; offset < dim_size; offset += blockDim.x) {
     thread_val = functor(thread_val, data[offset]);
   }
   return thread_val;
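
What the new `shift` parameter buys: `phi::AlignedVector<T, VecSize>` loads are only valid on addresses aligned to the vector width, and a row pointer such as `src + bid * dim_size` need not be aligned. Rather than dropping to scalar loads for the whole row, the patch backs the pointer up to the previous aligned boundary, lets the first block-width of threads read the unaligned head with scalar loads (threads whose index falls before the real row start simply skip), then advances the pointer so every remaining vector load is aligned. Below is a minimal compilable sketch of that prologue; `kAlignBytes` and `AlignedRowSum` are illustrative stand-ins, not names from the patch.

#include <cstdint>

// Assumed alignment target; the real code uses MATRIX_SOFTMAX_ALIGN_BYTES.
constexpr int kAlignBytes = 16;

template <typename T>
__device__ __forceinline__ float AlignedRowSum(const T* data, int dim_size) {
  // How many elements the row start sits past the previous aligned address.
  const int shift =
      static_cast<int>(reinterpret_cast<uintptr_t>(data) % kAlignBytes) /
      static_cast<int>(sizeof(T));
  float thread_val = 0.f;
  int offset = threadIdx.x;
  if (shift > 0) {
    data -= shift;          // back up to the aligned boundary
    dim_size += shift;      // indices are now relative to the aligned base
    if (offset >= shift) {  // the first `shift` slots lie before the real row
      thread_val += static_cast<float>(data[offset]);
    }
    // One block-width is consumed; this keeps `data` aligned provided
    // blockDim.x * sizeof(T) is a multiple of kAlignBytes.
    dim_size -= blockDim.x;
    data += blockDim.x;
  }
  // Main body: the real ThreadVecReduce continues with aligned vector loads
  // plus a scalar tail; a plain scalar loop stands in for both here.
  for (; offset < dim_size; offset += blockDim.x) {
    thread_val += static_cast<float>(data[offset]);
  }
  return thread_val;
}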
@@ -377,12 +390,27 @@ template <template <typename, typename> class Reduction,
           typename T,
           typename AccT,
           int VecSize>
-__device__ __forceinline__ void ThreadVecWrite(T* out,
-                                               const T* input,
-                                               int dim_size,
-                                               Reduction<AccT, T> functor) {
+__device__ __forceinline__ void ThreadVecWriteVec(T* out,
+                                                  T* input,
+                                                  int dim_size,
+                                                  const int shift,
+                                                  Reduction<AccT, T> functor) {
   using VecT = phi::AlignedVector<T, VecSize>;
+
+  // for memory align, handle the unaligned data in first block.
+  int offset = threadIdx.x;
+  if (shift > 0) {
+    input -= shift;
+    out -= shift;
+    dim_size += shift;
+    if (offset >= shift) {
+      out[offset] = functor(static_cast<AccT>(input[offset]));
+    }
+    dim_size -= blockDim.x;
+    input += blockDim.x;
+    out += blockDim.x;
+  }
 
   const int last = dim_size % (VecSize * blockDim.x);
 
   T in_v[VecSize];
@@ -391,9 +419,8 @@ __device__ __forceinline__ void ThreadVecWrite(T* out,
   T out_v[VecSize];
   VecT* out_value = reinterpret_cast<VecT*>(&out_v);
 
-  for (int offset = threadIdx.x; offset * VecSize < dim_size - last;
-       offset += blockDim.x) {
-    *in_value = reinterpret_cast<VecT*>(const_cast<T*>(input))[offset];
+  for (; offset * VecSize < dim_size - last; offset += blockDim.x) {
+    *in_value = reinterpret_cast<VecT*>(input)[offset];
 #pragma unroll
     for (int i = 0; i < VecSize; i++) {
       out_v[i] = functor(static_cast<AccT>(in_v[i]));
@@ -401,6 +428,33 @@ __device__ __forceinline__ void ThreadVecWrite(T* out,
     reinterpret_cast<VecT*>(out)[offset] = *out_value;
   }
 
+  offset = dim_size - last + threadIdx.x;
+  // the tail
+  for (; offset < dim_size; offset += blockDim.x) {
+    out[offset] = functor(static_cast<AccT>(input[offset]));
+  }
+}
+
+template <template <typename, typename> class Reduction,
+          typename T,
+          typename AccT,
+          int VecSize>
+__device__ __forceinline__ void ThreadVecWrite(T* out,
+                                               T* input,
+                                               int dim_size,
+                                               Reduction<AccT, T> functor) {
+  const int last = dim_size % (VecSize * blockDim.x);
+
+  for (int offset = threadIdx.x; offset < dim_size - last;
+       offset += blockDim.x * VecSize) {
+#pragma unroll
+    for (int i = 0; i < VecSize; i++) {
+      out[offset + i * blockDim.x] =
+          functor(static_cast<AccT>(input[offset + i * blockDim.x]));
+    }
+  }
+
+  // the tail
   for (int offset = dim_size - last + threadIdx.x; offset < dim_size;
        offset += blockDim.x) {
     out[offset] = functor(static_cast<AccT>(input[offset]));
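
The epilogue writer is now split in two: `ThreadVecWriteVec` keeps the vectorized loads and stores and applies the same shift prologue to both pointers, while the scalar `ThreadVecWrite` above makes no alignment assumption and never forms vector types. Its access pattern, each thread handling `VecSize` block-strided scalars per outer step plus a shared tail loop, touches each of the `dim_size` elements exactly once. Here is a toy standalone program that checks that coverage property; the kernel name and the sizes in `main` are illustrative, not part of the patch.

#include <cstdio>
#include <cuda_runtime.h>

// Mirrors the index pattern of the scalar ThreadVecWrite and counts how many
// times each element index is visited.
__global__ void CoverageKernel(int* hits, int dim_size, int vec_size) {
  const int last = dim_size % (vec_size * blockDim.x);
  for (int offset = threadIdx.x; offset < dim_size - last;
       offset += blockDim.x * vec_size) {
    for (int i = 0; i < vec_size; i++) {
      atomicAdd(&hits[offset + i * blockDim.x], 1);
    }
  }
  for (int offset = dim_size - last + threadIdx.x; offset < dim_size;
       offset += blockDim.x) {
    atomicAdd(&hits[offset], 1);
  }
}

int main() {
  const int dim_size = 1000, threads = 128, vec_size = 4;
  int* hits = nullptr;
  cudaMallocManaged(&hits, dim_size * sizeof(int));
  for (int i = 0; i < dim_size; i++) hits[i] = 0;
  CoverageKernel<<<1, threads>>>(hits, dim_size, vec_size);
  cudaDeviceSynchronize();
  int bad = 0;
  for (int i = 0; i < dim_size; i++) bad += (hits[i] != 1);
  printf("indices not visited exactly once: %d\n", bad);  // expect 0
  cudaFree(hits);
  return 0;
}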
@@ -417,13 +471,19 @@ __global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
   using VecT = phi::AlignedVector<T, VecSize>;
 
   int bid = blockIdx.x;
-  const T* batch_input = src + bid * dim_size;
+  T* batch_input = const_cast<T*>(src) + bid * dim_size;
   T* batch_output = softmax + bid * dim_size;
+  const int input_align_shift =
+      ((uint64_t)batch_input) % MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T);
+  const int output_align_shift =
+      ((uint64_t)batch_output) % MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T);
 
   // get max value
   AccT thread_max = ThreadVecReduce<MaxFunctor, T, AccT, VecSize>(
       batch_input,
       dim_size,
+      input_align_shift,
       MaxFunctor<T, AccT>(),
       std::numeric_limits<AccT>::min());
   BlockReduceMax<AccT>(&thread_max);
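
The two shifts measure, in elements, how far each row pointer sits past the previous `MATRIX_SOFTMAX_ALIGN_BYTES` boundary (the macro is defined elsewhere in this file; its value is not visible in this hunk). A hypothetical helper equivalent to the inline expressions above, using 16 bytes purely as an assumed value:

#include <cstdint>

template <typename T>
__host__ __device__ int AlignShiftOf(const T* p, int align_bytes) {
  // Elements past the previous align_bytes boundary.
  return static_cast<int>(reinterpret_cast<uintptr_t>(p) % align_bytes) /
         static_cast<int>(sizeof(T));
}
// e.g. T = float: a pointer with p % 16 == 8 yields AlignShiftOf(p, 16) == 2,
// so vector accesses must start two floats earlier to hit a 16-byte boundary.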
@@ -432,6 +492,7 @@ __global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
   AccT thread_exp = ThreadVecReduce<SumExpFunctor, T, AccT, VecSize>(
       batch_input,
       dim_size,
+      input_align_shift,
       SumExpFunctor<T, AccT>(thread_max),
       static_cast<AccT>(0.));
   BlockReduceSum<AccT>(&thread_exp);
@@ -440,12 +501,22 @@ __global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
   if (LogMode) {
     LogSoftmaxForwardFunctor<AccT, T> reduction(thread_max,
                                                 std::log(thread_exp));
-    ThreadVecWrite<LogSoftmaxForwardFunctor, T, AccT, VecSize>(
-        batch_output, batch_input, dim_size, reduction);
+    if (input_align_shift == output_align_shift) {
+      ThreadVecWriteVec<LogSoftmaxForwardFunctor, T, AccT, VecSize>(
+          batch_output, batch_input, dim_size, input_align_shift, reduction);
+    } else {
+      ThreadVecWrite<LogSoftmaxForwardFunctor, T, AccT, VecSize>(
+          batch_output, batch_input, dim_size, reduction);
+    }
   } else {
     SoftmaxForwardFunctor<AccT, T> reduction(thread_max, thread_exp);
-    ThreadVecWrite<SoftmaxForwardFunctor, T, AccT, VecSize>(
-        batch_output, batch_input, dim_size, reduction);
+    if (input_align_shift == output_align_shift) {
+      ThreadVecWriteVec<SoftmaxForwardFunctor, T, AccT, VecSize>(
+          batch_output, batch_input, dim_size, input_align_shift, reduction);
+    } else {
+      ThreadVecWrite<SoftmaxForwardFunctor, T, AccT, VecSize>(
+          batch_output, batch_input, dim_size, reduction);
+    }
   }
 }
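
Note on the dispatch above: `ThreadVecWriteVec` backs `input` and `out` up by the same `shift`, so it is only correct when the two misalignments agree, hence the `input_align_shift == output_align_shift` guard with the scalar `ThreadVecWrite` as the fallback. Since both row pointers add the same `bid * dim_size` offset, the shifts agree exactly when `src` and `softmax` themselves are equally misaligned.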
@@ -1371,5 +1442,9 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
         dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
   }
 }
+#undef FIXED_BLOCK_DIM_BASE
+#undef FIXED_BLOCK_DIM
+#undef FIXED_VEC_SIZE_BASE
+#undef FIXED_VEC_SIZE
 
 }  // namespace phi