未验证 提交 250e254f 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize performance of log_softmax (#38992)

* Optimize performance of log_softmax

* delete unity build

* modify to phi

* fix

* fixfixfixfix

* fix

* fix

* fix

* fix

* simplify

* fix

* fix enforce
上级 02e80f59
......@@ -351,8 +351,17 @@ __global__ void WarpSoftmaxForward(T* softmax,
VecT* softmax_v =
reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
if (LogMode) {
kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnaryLogFunctor<AccT>>(
&srcdata[i][0][0], &srcdata[i][0][0], UnaryLogFunctor<AccT>());
kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnarySubFunctor<AccT>>(
&out_tmp[i][0][0],
&srcdata[i][0][0],
UnarySubFunctor<AccT>(std::log(sum[i])));
} else {
kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
&out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
}
kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
&softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
}
......@@ -434,6 +443,15 @@ __global__ void WarpSoftmaxBackward(T* dst,
AccT sum_tmp[kBatchSize][kLoopsV][kVSize];
AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[0][0][0]);
AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[0][0][0]);
if (LogMode) {
kps::Reduce<AccT,
kVItem,
kBatchSize,
1,
kps::AddFunctor<AccT>,
kps::details::ReduceMode::kLocalMode>(
&sum[0], &grad_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
} else {
kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
&sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
kps::Reduce<AccT,
......@@ -443,6 +461,7 @@ __global__ void WarpSoftmaxBackward(T* dst,
kps::AddFunctor<AccT>,
kps::details::ReduceMode::kLocalMode>(
&sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
}
WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
// write result to global memory
......@@ -453,10 +472,23 @@ __global__ void WarpSoftmaxBackward(T* dst,
if (i >= local_batches) break;
AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[i][0][0]);
AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[i][0][0]);
if (LogMode) {
kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpMulFunctor<AccT>>(
&out[i][0][0], &srcptr[0], ExpMulFunctor<AccT>(sum[i]));
kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::SubFunctor<AccT>>(
&out_tmp[i][0][0],
&gradptr[0],
&out[i][0][0],
kps::SubFunctor<AccT>());
} else {
kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
&out[i][0][0], &gradptr[0], UnarySubFunctor<AccT>(sum[i]));
kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::MulFunctor<AccT>>(
&out_tmp[i][0][0], &srcptr[0], &out[i][0][0], kps::MulFunctor<AccT>());
&out_tmp[i][0][0],
&srcptr[0],
&out[i][0][0],
kps::MulFunctor<AccT>());
}
VecT* dst_v = reinterpret_cast<VecT*>(&dst[(first_batch + i) * stride]);
VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
......@@ -639,7 +671,8 @@ __global__ void NormalSoftmaxForward(
template <typename T,
typename AccT,
template <typename, typename> class Functor>
template <typename, typename> class Functor,
bool LogMode>
__global__ void NormalSoftmaxBackward(T* input_grad,
const T* output_grad,
const T* output,
......@@ -656,11 +689,18 @@ __global__ void NormalSoftmaxBackward(T* input_grad,
// 1. reduce sum
AccT sum = 0;
if (LogMode) {
for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
int data_offset = grad_offset + mid_id * mid_stride;
sum += static_cast<AccT>(output_grad[data_offset]);
}
} else {
for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
int data_offset = grad_offset + mid_id * mid_stride;
sum += static_cast<AccT>(output_grad[data_offset]) *
static_cast<AccT>(output[data_offset]);
}
}
if (blockDim.y > 1) {
kps::Reduce<AccT, 1, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
&sum, &sum, kps::AddFunctor<AccT>(), false);
......@@ -715,10 +755,10 @@ void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
dim3 grid, block;
GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
if (LogMode) {
NormalSoftmaxBackward<
T,
NormalSoftmaxBackward<T,
AccT,
LogSoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
LogSoftmaxBackwardFunctor,
LogMode><<<grid, block, 0, dev_ctx.stream()>>>(
input_grad_data,
output_grad_data,
output_data,
......@@ -726,10 +766,10 @@ void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
mid_dim,
low_dim);
} else {
NormalSoftmaxBackward<
T,
NormalSoftmaxBackward<T,
AccT,
SoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
SoftmaxBackwardFunctor,
LogMode><<<grid, block, 0, dev_ctx.stream()>>>(
input_grad_data,
output_grad_data,
output_data,
......@@ -864,6 +904,32 @@ static bool CanUseCudnnSoftmax(const GPUContext& dev_ctx) {
return false;
}
#if CUDNN_VERSION < 8100
template <>
inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
const GPUContext& dev_ctx,
const DenseTensor& x,
const int axis,
const bool log_mode,
DenseTensor* out) {
PADDLE_THROW(errors::Unavailable(
"This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
"8100."));
}
template <>
inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
const GPUContext& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
const int axis,
const bool log_mode,
DenseTensor* dx) {
PADDLE_THROW(errors::Unavailable(
"This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
"8100."));
}
#endif
template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
const DenseTensor& x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册