Unverified commit f5a3b427, authored by shaojie_wang, committed by GitHub

softmax fwd: force vec size to 1 when dtype is float (#54183)

* softmax fwd: force vec size to 1 when dtype is float

* use 1024 as threshold to use cudnn
Parent 44bd5927
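Note on the change: the warp softmax forward path previously chose its vector load width purely from the divisibility of the softmax dimension; this commit pins float to vector size 1 and raises the dimension threshold below which the hand-written kernels are preferred over cuDNN from 512 to 1024. A minimal sketch of the resulting policy, using a hypothetical helper name (ChooseVecSize does not appear in the diff):

// Sketch only: summarizes the vec-size policy after this commit.
// ChooseVecSize is a hypothetical name for illustration.
int ChooseVecSize(bool is_float, int dim) {
  if (is_float) return 1;      // this commit: float always uses scalar loads
  if (dim % 4 == 0) return 4;  // pack four elements per load (VecT4)
  if (dim % 2 == 0) return 2;  // pack two elements per load (VecT2)
  return 1;                    // fall back to scalar loads
}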
@@ -798,6 +798,7 @@ void SwitchWarpSoftmaxForward(const IndexType blocks,
     SOFTMAX_WARP_FORWARD_CASE(7, AccT);
     SOFTMAX_WARP_FORWARD_CASE(8, AccT);
     SOFTMAX_WARP_FORWARD_CASE(9, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(10, AccT);
     default:
       break;
   }
@@ -836,6 +837,7 @@ void SwitchWarpSoftmaxBackward(const int blocks,
     SOFTMAX_WARP_BACKWARD_CASE(7, AccT);
     SOFTMAX_WARP_BACKWARD_CASE(8, AccT);
     SOFTMAX_WARP_BACKWARD_CASE(9, AccT);
+    SOFTMAX_WARP_BACKWARD_CASE(10, AccT);
     default:
       break;
   }
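The two new CASE lines extend both the forward and backward dispatch switches to log2_elements = 10, i.e. softmax dimensions up to 2^10 = 1024, matching the raised max_dim below. A hedged sketch of the dispatch shape these macros expand into (Log2Ceil and the stub are simplified stand-ins, not the real macro body, which launches a kernel templated on Log2Elements):

// Hypothetical illustration of the switch-on-log2 dispatch pattern.
int Log2Ceil(int n) {  // smallest k with (1 << k) >= n
  int k = 0;
  while ((1 << k) < n) ++k;
  return k;
}

template <int Log2Elements>
void WarpSoftmaxForwardStub() { /* kernel launch would go here */ }

void Dispatch(int dim) {
  switch (Log2Ceil(dim)) {
    case 9:  WarpSoftmaxForwardStub<9>();  break;  // dim in (256, 512]
    case 10: WarpSoftmaxForwardStub<10>(); break;  // new: dim in (512, 1024]
    default: break;                                // handled elsewhere
  }
}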
@@ -1262,7 +1264,7 @@ bool UseCudnnSoftmax(const GPUContext& ctx,
 #endif
   }
 }
-  constexpr int max_dim = 512;
+  constexpr int max_dim = 1024;
   if (!cudnn_available || !last_dim ||
       (softmax_dim <= max_dim && sizeof(T) <= 4) ||
       softmax_dim >= MATRIX_SOFTMAX_THREAHOLD) {
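With max_dim raised to 1024, tensors whose last dimension is at most 1024 (for dtypes of 4 bytes or less) now stay on the hand-written kernels instead of cuDNN, which is why the warp dispatch above needed case 10. A simplified sketch of the decision (not the full UseCudnnSoftmax: dtype_size stands in for sizeof(T), and a plain parameter stands in for the MATRIX_SOFTMAX_THREAHOLD macro):

// Simplified sketch of the cuDNN-vs-in-house decision this hunk changes.
bool UseCudnnSoftmaxSketch(bool cudnn_available, bool last_dim,
                           int softmax_dim, int dtype_size,
                           int matrix_softmax_threshold) {
  constexpr int max_dim = 1024;  // raised from 512 by this commit
  if (!cudnn_available || !last_dim ||
      (softmax_dim <= max_dim && dtype_size <= 4) ||
      softmax_dim >= matrix_softmax_threshold) {
    return false;  // take the in-house kernel path
  }
  return true;     // otherwise prefer cuDNN
}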
@@ -1311,27 +1313,7 @@ void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
   using T4 = typename VecT4<T>::Type;
   using T2 = typename VecT2<T>::Type;
-  if (dim % 4 == 0) {
-    SwitchWarpSoftmaxForward<T, T4, IndexType, LogMode>(blocks,
-                                                        threads,
-                                                        dev_ctx,
-                                                        out_data,
-                                                        x.data<T>(),
-                                                        N,
-                                                        dim,
-                                                        dim,
-                                                        dim_log2);
-  } else if (dim % 2 == 0) {
-    SwitchWarpSoftmaxForward<T, T2, IndexType, LogMode>(blocks,
-                                                        threads,
-                                                        dev_ctx,
-                                                        out_data,
-                                                        x.data<T>(),
-                                                        N,
-                                                        dim,
-                                                        dim,
-                                                        dim_log2);
-  } else {
+  if (std::is_same<T, float>::value) {
     SwitchWarpSoftmaxForward<T, T, IndexType, LogMode>(blocks,
                                                        threads,
                                                        dev_ctx,
@@ -1341,6 +1323,38 @@ void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
                                                        dim,
                                                        dim,
                                                        dim_log2);
+  } else {
+    if (dim % 4 == 0) {
+      SwitchWarpSoftmaxForward<T, T4, IndexType, LogMode>(blocks,
+                                                          threads,
+                                                          dev_ctx,
+                                                          out_data,
+                                                          x.data<T>(),
+                                                          N,
+                                                          dim,
+                                                          dim,
+                                                          dim_log2);
+    } else if (dim % 2 == 0) {
+      SwitchWarpSoftmaxForward<T, T2, IndexType, LogMode>(blocks,
+                                                          threads,
+                                                          dev_ctx,
+                                                          out_data,
+                                                          x.data<T>(),
+                                                          N,
+                                                          dim,
+                                                          dim,
+                                                          dim_log2);
+    } else {
+      SwitchWarpSoftmaxForward<T, T, IndexType, LogMode>(blocks,
+                                                         threads,
+                                                         dev_ctx,
+                                                         out_data,
+                                                         x.data<T>(),
+                                                         N,
+                                                         dim,
+                                                         dim,
+                                                         dim_log2);
+    }
   }
 } else {
   LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
......
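A note on the new branch: std::is_same<T, float>::value is a compile-time constant, so the check costs nothing at runtime in each template instantiation even though it is written as a plain if. A hypothetical C++17 restatement of the same dispatch using if constexpr (not part of the diff; comments mark where the SwitchWarpSoftmaxForward calls would go):

#include <type_traits>

// Hypothetical if-constexpr form of the dispatch added above.
template <typename T>
void DispatchVecWidth(int dim) {
  if constexpr (std::is_same<T, float>::value) {
    // vec size 1: scalar loads only (this commit's change for float)
  } else {
    if (dim % 4 == 0) {
      // vec size 4: launch with VecT4<T>::Type
    } else if (dim % 2 == 0) {
      // vec size 2: launch with VecT2<T>::Type
    } else {
      // vec size 1: scalar fallback
    }
  }
}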