Unverified commit 250e254f authored by Zhang Zheng, committed by GitHub

Optimize performance of log_softmax (#38992)

* Optimize performance of log_softmax

* delete unity build

* modify to phi

* fix

* fixfixfixfix

* fix

* fix

* fix

* fix

* simplify

* fix

* fix enforce
Parent 02e80f59
@@ -351,8 +351,17 @@ __global__ void WarpSoftmaxForward(T* softmax,
     VecT* softmax_v =
         reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
     VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
-    kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
-        &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
+    if (LogMode) {
+      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnaryLogFunctor<AccT>>(
+          &srcdata[i][0][0], &srcdata[i][0][0], UnaryLogFunctor<AccT>());
+      kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnarySubFunctor<AccT>>(
+          &out_tmp[i][0][0],
+          &srcdata[i][0][0],
+          UnarySubFunctor<AccT>(std::log(sum[i])));
+    } else {
+      kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
+          &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
+    }
     kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
         &softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
   }
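For reference, the new LogMode branch computes log-softmax directly in registers instead of dividing by the sum. A minimal scalar sketch of the per-element math, assuming (as in the usual warp-softmax layout, not shown in this hunk) that srcdata holds exp(x - max) from the kernel's earlier max/exp pass; names here are illustrative, not the kernel's:

    #include <cmath>
    // Applying UnaryLogFunctor and then UnarySubFunctor(std::log(sum)) yields
    //   log_softmax(x) = log(exp(x - max)) - log(sum) = (x - max) - log(sum).
    inline float LogSoftmaxElem(float exp_shifted, float sum) {
      return std::log(exp_shifted) - std::log(sum);
    }

Working as (x - max) - log(sum) keeps the computation in the numerically stable shifted domain and lets the already-optimized softmax kernel serve log_softmax with only an extra log and subtract.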
@@ -434,15 +443,25 @@ __global__ void WarpSoftmaxBackward(T* dst,
   AccT sum_tmp[kBatchSize][kLoopsV][kVSize];
   AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[0][0][0]);
   AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[0][0][0]);
-  kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
-      &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
-  kps::Reduce<AccT,
-              kVItem,
-              kBatchSize,
-              1,
-              kps::AddFunctor<AccT>,
-              kps::details::ReduceMode::kLocalMode>(
-      &sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
+  if (LogMode) {
+    kps::Reduce<AccT,
+                kVItem,
+                kBatchSize,
+                1,
+                kps::AddFunctor<AccT>,
+                kps::details::ReduceMode::kLocalMode>(
+        &sum[0], &grad_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
+  } else {
+    kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
+        &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
+    kps::Reduce<AccT,
+                kVItem,
+                kBatchSize,
+                1,
+                kps::AddFunctor<AccT>,
+                kps::details::ReduceMode::kLocalMode>(
+        &sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
+  }
   WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
   // write result to global memory
@@ -453,10 +472,23 @@ __global__ void WarpSoftmaxBackward(T* dst,
     if (i >= local_batches) break;
     AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[i][0][0]);
     AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[i][0][0]);
-    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
-        &out[i][0][0], &gradptr[0], UnarySubFunctor<AccT>(sum[i]));
-    kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::MulFunctor<AccT>>(
-        &out_tmp[i][0][0], &srcptr[0], &out[i][0][0], kps::MulFunctor<AccT>());
+    if (LogMode) {
+      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpMulFunctor<AccT>>(
+          &out[i][0][0], &srcptr[0], ExpMulFunctor<AccT>(sum[i]));
+      kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::SubFunctor<AccT>>(
+          &out_tmp[i][0][0],
+          &gradptr[0],
+          &out[i][0][0],
+          kps::SubFunctor<AccT>());
+    } else {
+      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
+          &out[i][0][0], &gradptr[0], UnarySubFunctor<AccT>(sum[i]));
+      kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::MulFunctor<AccT>>(
+          &out_tmp[i][0][0],
+          &srcptr[0],
+          &out[i][0][0],
+          kps::MulFunctor<AccT>());
+    }
     VecT* dst_v = reinterpret_cast<VecT*>(&dst[(first_batch + i) * stride]);
     VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
     kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
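The two branches above select between the standard gradients of log-softmax and softmax, which is also why the reduction hunk earlier sums plain dy in LogMode but dy * y otherwise. A scalar sketch with illustrative names, where sum is the warp-reduced quantity computed above:

    #include <cmath>
    // LogMode: with y = log_softmax(x) and upstream gradient dy,
    //   dx = dy - exp(y) * sum(dy)         (ExpMulFunctor then SubFunctor)
    // otherwise, with y = softmax(x),
    //   dx = y * (dy - sum(dy * y))        (UnarySubFunctor then MulFunctor)
    inline float LogSoftmaxGradElem(float dy, float y, float sum_dy) {
      return dy - std::exp(y) * sum_dy;
    }
    inline float SoftmaxGradElem(float dy, float y, float sum_dy_y) {
      return y * (dy - sum_dy_y);
    }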
@@ -639,7 +671,8 @@ __global__ void NormalSoftmaxForward(
 template <typename T,
           typename AccT,
-          template <typename, typename> class Functor>
+          template <typename, typename> class Functor,
+          bool LogMode>
 __global__ void NormalSoftmaxBackward(T* input_grad,
                                       const T* output_grad,
                                       const T* output,
@@ -656,10 +689,17 @@ __global__ void NormalSoftmaxBackward(T* input_grad,
   // 1. reduce sum
   AccT sum = 0;
-  for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
-    int data_offset = grad_offset + mid_id * mid_stride;
-    sum += static_cast<AccT>(output_grad[data_offset]) *
-           static_cast<AccT>(output[data_offset]);
+  if (LogMode) {
+    for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
+      int data_offset = grad_offset + mid_id * mid_stride;
+      sum += static_cast<AccT>(output_grad[data_offset]);
+    }
+  } else {
+    for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
+      int data_offset = grad_offset + mid_id * mid_stride;
+      sum += static_cast<AccT>(output_grad[data_offset]) *
+             static_cast<AccT>(output[data_offset]);
+    }
   }
   if (blockDim.y > 1) {
     kps::Reduce<AccT, 1, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
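NormalSoftmaxBackward now branches the same way as the warp kernel: LogMode reduces output_grad alone, while the softmax path reduces output_grad * output, matching the two gradient formulas above. A host-side sketch of the reduced quantity, where kLogMode, dy, y, and n are stand-ins for the kernel's LogMode, output_grad, output, and mid_dim:

    // Illustrative only; the real kernel strides by mid_stride and reduces
    // cooperatively across threadIdx.y.
    inline float BackwardSum(bool kLogMode, const float* dy, const float* y, int n) {
      float sum = 0.0f;
      for (int j = 0; j < n; ++j) sum += kLogMode ? dy[j] : dy[j] * y[j];
      return sum;
    }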
@@ -715,10 +755,10 @@ void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
   dim3 grid, block;
   GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
   if (LogMode) {
-    NormalSoftmaxBackward<
-        T,
-        AccT,
-        LogSoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
+    NormalSoftmaxBackward<T,
+                          AccT,
+                          LogSoftmaxBackwardFunctor,
+                          LogMode><<<grid, block, 0, dev_ctx.stream()>>>(
         input_grad_data,
         output_grad_data,
         output_data,
@@ -726,10 +766,10 @@ void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
         mid_dim,
         low_dim);
   } else {
-    NormalSoftmaxBackward<
-        T,
-        AccT,
-        SoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
+    NormalSoftmaxBackward<T,
+                          AccT,
+                          SoftmaxBackwardFunctor,
+                          LogMode><<<grid, block, 0, dev_ctx.stream()>>>(
         input_grad_data,
         output_grad_data,
         output_data,
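Because LogMode is forwarded as a template argument here (the launcher itself receives it as a template parameter), each instantiation compiles only one side of the kernel's if (LogMode) branch, with no per-element runtime check. A simplified sketch of the dispatch pattern with hypothetical names:

    template <bool LogMode>
    __global__ void Backward() {
      if (LogMode) {
        // log-softmax path; dead code eliminated when LogMode == false
      } else {
        // softmax path
      }
    }

    void Launch(bool log_mode, dim3 grid, dim3 block, cudaStream_t stream) {
      if (log_mode) {
        Backward<true><<<grid, block, 0, stream>>>();
      } else {
        Backward<false><<<grid, block, 0, stream>>>();
      }
    }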
@@ -864,6 +904,32 @@ static bool CanUseCudnnSoftmax(const GPUContext& dev_ctx) {
   return false;
 }
+#if CUDNN_VERSION < 8100
+template <>
+inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
+    const GPUContext& dev_ctx,
+    const DenseTensor& x,
+    const int axis,
+    const bool log_mode,
+    DenseTensor* out) {
+  PADDLE_THROW(errors::Unavailable(
+      "This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
+      "8100."));
+}
+
+template <>
+inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
+    const GPUContext& dev_ctx,
+    const DenseTensor& out,
+    const DenseTensor& dout,
+    const int axis,
+    const bool log_mode,
+    DenseTensor* dx) {
+  PADDLE_THROW(errors::Unavailable(
+      "This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
+      "8100."));
+}
+#endif
+
 template <typename T, bool LogMode = false>
 void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
                                     const DenseTensor& x,
...
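These bfloat16 specializations exist because, per their error message, the cuDNN softmax path does not accept bf16 before cuDNN 8.1 (CUDNN_VERSION 8100); on older builds the instantiations now fail fast with errors::Unavailable instead of reaching an unsupported cuDNN call.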