未验证 提交 b5dd12fb 编写于 作者: Yanxing Shi 提交者: GitHub

fix softmax max dim (#37901)

上级 a8f009e4
......@@ -449,7 +449,7 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
const int N = SizeToAxis(axis, dims);
const int D = SizeOutAxis(axis, dims);
- constexpr int max_dim = 320;
+ constexpr int max_dim = 512;
constexpr int warps_per_block = 4;
if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
......@@ -540,7 +540,7 @@ void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
const int N = SizeToAxis(axis, dims);
const int D = SizeOutAxis(axis, dims);
- constexpr int max_dim = 320;
+ constexpr int max_dim = 512;
constexpr int warps_per_block = 4;
if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册