未验证 提交 416e47ed 编写于 作者: Q Qi Li 提交者: GitHub

[ROCM] fix softmax with loss nan in HIP platform, test=develop (#31491)

上级 f57739be
......@@ -398,7 +398,12 @@ static void HardLabelSoftmaxWithCrossEntropy(
const platform::CUDADeviceContext& ctx, const T* logits_data,
const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n,
int64_t d, int axis_dim, int ignore_idx) {
#ifdef __HIPCC__
// HIP platform will have loss nan if dim size > 256
constexpr int kMaxBlockDim = 256;
#else
constexpr int kMaxBlockDim = 512;
#endif
int64_t block_dim = axis_dim >= kMaxBlockDim
? kMaxBlockDim
: (1 << static_cast<int>(std::log2(axis_dim)));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册