未验证 提交 2dec25db 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize the computation of log_softmax (#40612)

* Optimize the computation of log_softmax

* Modify the variable names
上级 a09a93a1
...@@ -121,17 +121,10 @@ struct ReduceMaxFunctor { ...@@ -121,17 +121,10 @@ struct ReduceMaxFunctor {
}; };
template <typename Tx, typename Ty = Tx> template <typename Tx, typename Ty = Tx>
struct ExpSubFunctor { struct ExpFunctor {
HOSTDEVICE inline ExpSubFunctor() { y = static_cast<Tx>(0.0f); }
HOSTDEVICE explicit inline ExpSubFunctor(Tx y) : y((Tx)(y)) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const { HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(std::exp(x - y)); return static_cast<Ty>(std::exp(x));
} }
private:
Tx y;
}; };
template <typename Tx, typename Ty = Tx> template <typename Tx, typename Ty = Tx>
...@@ -293,10 +286,14 @@ __global__ void WarpSoftmaxForward(T* softmax, ...@@ -293,10 +286,14 @@ __global__ void WarpSoftmaxForward(T* softmax,
} }
// data src // data src
AccT srcdata[kBatchSize][kLoopsV][kVSize]; // src_data: the raw data read from global memory
T src_tmp[kBatchSize][kLoopsV][kVSize]; // sub_data: stores the data obtained by (src_data - max), used by log_softmax
kps::Init<AccT, kStep>(&srcdata[0][0][0], kLowInf); // exp_data: stores the data obtained by (exp(sub_data)), used by softmax
kps::Init<T, kStep>(&src_tmp[0][0][0], -std::numeric_limits<T>::infinity()); T src_data[kBatchSize][kLoopsV][kVSize];
AccT sub_data[kBatchSize][kLoopsV][kVSize];
AccT exp_data[kBatchSize][kLoopsV][kVSize];
kps::Init<AccT, kStep>(&sub_data[0][0][0], kLowInf);
kps::Init<T, kStep>(&src_data[0][0][0], -std::numeric_limits<T>::infinity());
// data dst // data dst
T out_tmp[kBatchSize][kLoopsV][kVSize]; T out_tmp[kBatchSize][kLoopsV][kVSize];
...@@ -313,11 +310,11 @@ __global__ void WarpSoftmaxForward(T* softmax, ...@@ -313,11 +310,11 @@ __global__ void WarpSoftmaxForward(T* softmax,
for (int i = 0; i < kBatchSize; ++i) { for (int i = 0; i < kBatchSize; ++i) {
const VecT* src_v = const VecT* src_v =
reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]); reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
VecT* reg_v = reinterpret_cast<VecT*>(&src_tmp[i][0][0]); VecT* reg_v = reinterpret_cast<VecT*>(&src_data[i][0][0]);
kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>( kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
&reg_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1); &reg_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1);
kps::ElementwiseUnary<T, AccT, kVItem, 1, 1, DataTransFunctor<T, AccT>>( kps::ElementwiseUnary<T, AccT, kVItem, 1, 1, DataTransFunctor<T, AccT>>(
&srcdata[i][0][0], &src_tmp[i][0][0], DataTransFunctor<T, AccT>()); &sub_data[i][0][0], &src_data[i][0][0], DataTransFunctor<T, AccT>());
} }
// compute max // compute max
...@@ -327,14 +324,16 @@ __global__ void WarpSoftmaxForward(T* softmax, ...@@ -327,14 +324,16 @@ __global__ void WarpSoftmaxForward(T* softmax,
1, 1,
ReduceMaxFunctor<AccT>, ReduceMaxFunctor<AccT>,
kMode::kLocalMode>( kMode::kLocalMode>(
&max[0], &srcdata[0][0][0], ReduceMaxFunctor<AccT>(), true); &max[0], &sub_data[0][0][0], ReduceMaxFunctor<AccT>(), true);
WarpReduceMax<AccT, kBatchSize, kWarpSize>(max); WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);
// compute sum // compute sum
#pragma unroll #pragma unroll
for (int i = 0; i < kBatchSize; ++i) { for (int i = 0; i < kBatchSize; ++i) {
kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>( kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
&srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i])); &sub_data[i][0][0], &sub_data[i][0][0], UnarySubFunctor<AccT>(max[i]));
kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpFunctor<AccT>>(
&exp_data[i][0][0], &sub_data[i][0][0], ExpFunctor<AccT>());
} }
kps::Reduce<AccT, kps::Reduce<AccT,
kVItem, kVItem,
...@@ -342,7 +341,7 @@ __global__ void WarpSoftmaxForward(T* softmax, ...@@ -342,7 +341,7 @@ __global__ void WarpSoftmaxForward(T* softmax,
1, 1,
kps::AddFunctor<AccT>, kps::AddFunctor<AccT>,
kMode::kLocalMode>( kMode::kLocalMode>(
&sum[0], &srcdata[0][0][0], kps::AddFunctor<AccT>(), true); &sum[0], &exp_data[0][0][0], kps::AddFunctor<AccT>(), true);
WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum); WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
// write data to global memory // write data to global memory
...@@ -352,15 +351,13 @@ __global__ void WarpSoftmaxForward(T* softmax, ...@@ -352,15 +351,13 @@ __global__ void WarpSoftmaxForward(T* softmax,
reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]); reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]); VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
if (LogMode) { if (LogMode) {
kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnaryLogFunctor<AccT>>(
&srcdata[i][0][0], &srcdata[i][0][0], UnaryLogFunctor<AccT>());
kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnarySubFunctor<AccT>>( kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnarySubFunctor<AccT>>(
&out_tmp[i][0][0], &out_tmp[i][0][0],
&srcdata[i][0][0], &sub_data[i][0][0],
UnarySubFunctor<AccT>(std::log(sum[i]))); UnarySubFunctor<AccT>(std::log(sum[i])));
} else { } else {
kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>( kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
&out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i])); &out_tmp[i][0][0], &exp_data[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
} }
kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>( kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
&softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1); &softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册