未验证 提交 f5a3b427 编写于 作者: shaojie_wang 提交者: GitHub

softmax fwd: force vec size to 1 when dtype is float (#54183)

* softmax fwd: force vec size to 1 when dtype is float

* use 1024 as threshold to use cudnn
上级 44bd5927
......@@ -798,6 +798,7 @@ void SwitchWarpSoftmaxForward(const IndexType blocks,
SOFTMAX_WARP_FORWARD_CASE(7, AccT);
SOFTMAX_WARP_FORWARD_CASE(8, AccT);
SOFTMAX_WARP_FORWARD_CASE(9, AccT);
SOFTMAX_WARP_FORWARD_CASE(10, AccT);
default:
break;
}
......@@ -836,6 +837,7 @@ void SwitchWarpSoftmaxBackward(const int blocks,
SOFTMAX_WARP_BACKWARD_CASE(7, AccT);
SOFTMAX_WARP_BACKWARD_CASE(8, AccT);
SOFTMAX_WARP_BACKWARD_CASE(9, AccT);
SOFTMAX_WARP_BACKWARD_CASE(10, AccT);
default:
break;
}
......@@ -1262,7 +1264,7 @@ bool UseCudnnSoftmax(const GPUContext& ctx,
#endif
}
}
constexpr int max_dim = 512;
constexpr int max_dim = 1024;
if (!cudnn_available || !last_dim ||
(softmax_dim <= max_dim && sizeof(T) <= 4) ||
softmax_dim >= MATRIX_SOFTMAX_THREAHOLD) {
......@@ -1311,27 +1313,7 @@ void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
using T4 = typename VecT4<T>::Type;
using T2 = typename VecT2<T>::Type;
if (dim % 4 == 0) {
SwitchWarpSoftmaxForward<T, T4, IndexType, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxForward<T, T2, IndexType, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else {
if (std::is_same<T, float>::value) {
SwitchWarpSoftmaxForward<T, T, IndexType, LogMode>(blocks,
threads,
dev_ctx,
......@@ -1341,6 +1323,38 @@ void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
dim,
dim,
dim_log2);
} else {
if (dim % 4 == 0) {
SwitchWarpSoftmaxForward<T, T4, IndexType, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxForward<T, T2, IndexType, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else {
SwitchWarpSoftmaxForward<T, T, IndexType, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
}
}
} else {
LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册