未验证 提交 b5dd12fb 编写于 作者: Yanxing Shi 提交者: GitHub

Fix softmax max_dim: raise the limit from 320 to 512 in the forward and backward CUDA kernel drivers (#37901)

上级 a8f009e4
...@@ -449,7 +449,7 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, ...@@ -449,7 +449,7 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
const int N = SizeToAxis(axis, dims); const int N = SizeToAxis(axis, dims);
const int D = SizeOutAxis(axis, dims); const int D = SizeOutAxis(axis, dims);
constexpr int max_dim = 320; constexpr int max_dim = 512;
constexpr int warps_per_block = 4; constexpr int warps_per_block = 4;
if (D == 1 && dim <= max_dim && sizeof(T) <= 4) { if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
...@@ -540,7 +540,7 @@ void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, ...@@ -540,7 +540,7 @@ void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
const int N = SizeToAxis(axis, dims); const int N = SizeToAxis(axis, dims);
const int D = SizeOutAxis(axis, dims); const int D = SizeOutAxis(axis, dims);
constexpr int max_dim = 320; constexpr int max_dim = 512;
constexpr int warps_per_block = 4; constexpr int warps_per_block = 4;
if (D == 1 && dim <= max_dim && sizeof(T) <= 4) { if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册