未验证 提交 18aca3f5 作者: Feng Xing 提交者: GitHub

format softmax forward (#37927)

上级 fdf62e1e
...@@ -222,15 +222,27 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src, ...@@ -222,15 +222,27 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
idx_max_v[i] = idx_max / kVSize; idx_max_v[i] = idx_max / kVSize;
} }
// read data from global memory // data src
AccT srcdata[kBatchSize][kLoopsV][kVSize]; AccT srcdata[kBatchSize][kLoopsV][kVSize];
kps::Init<AccT, kStep>(&srcdata[0][0][0], kLowInf);
T src_tmp[kBatchSize][kLoopsV][kVSize]; T src_tmp[kBatchSize][kLoopsV][kVSize];
kps::Init<AccT, kStep>(&srcdata[0][0][0], kLowInf);
kps::Init<T, kStep>(&src_tmp[0][0][0], -std::numeric_limits<T>::infinity()); kps::Init<T, kStep>(&src_tmp[0][0][0], -std::numeric_limits<T>::infinity());
// data dst
T out_tmp[kBatchSize][kLoopsV][kVSize];
// max value
AccT max[kBatchSize];
kps::Init<AccT, kBatchSize>(&max[0], kLowInf);
// sum value
AccT sum[kBatchSize] = {0};
// read data from global memory
#pragma unroll #pragma unroll
for (int i = 0; i < kBatchSize; ++i) { for (int i = 0; i < kBatchSize; ++i) {
int ptr = (first_batch + i) * stride; const VecT* src_v =
const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]); reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
VecT* reg_v = reinterpret_cast<VecT*>(&src_tmp[i][0][0]); VecT* reg_v = reinterpret_cast<VecT*>(&src_tmp[i][0][0]);
kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>( kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
&reg_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1); &reg_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1);
...@@ -239,15 +251,12 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src, ...@@ -239,15 +251,12 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
} }
// compute max // compute max
AccT max[kBatchSize];
kps::Init<AccT, kBatchSize>(&max[0], kLowInf);
kps::Reduce<AccT, kVItem, kBatchSize, 1, ReduceMaxFunctor<AccT>, kps::Reduce<AccT, kVItem, kBatchSize, 1, ReduceMaxFunctor<AccT>,
kMode::kLocalMode>(&max[0], &srcdata[0][0][0], kMode::kLocalMode>(&max[0], &srcdata[0][0][0],
ReduceMaxFunctor<AccT>(), true); ReduceMaxFunctor<AccT>(), true);
WarpReduceMax<AccT, kBatchSize, kWarpSize>(max); WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);
// compute sum // compute sum
AccT sum[kBatchSize] = {0};
for (int i = 0; i < kBatchSize; ++i) { for (int i = 0; i < kBatchSize; ++i) {
kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>( kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>(
&srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i])); &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i]));
...@@ -257,15 +266,14 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src, ...@@ -257,15 +266,14 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
kps::AddFunctor<AccT>(), true); kps::AddFunctor<AccT>(), true);
WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum); WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
// write result to global memory // write data to global memory
T out_tmp[kBatchSize][kLoopsV][kVSize];
#pragma unroll #pragma unroll
for (int i = 0; i < kBatchSize; ++i) { for (int i = 0; i < kBatchSize; ++i) {
VecT* softmax_v =
reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>( kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
&out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i])); &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
int softmax_ptr = (first_batch + i) * stride;
VecT* softmax_v = reinterpret_cast<VecT*>(&softmax[softmax_ptr]);
VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>( kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
&softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1); &softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册