未验证 提交 46dd1d4a 编写于 作者: Q Qi Li 提交者: GitHub

[ROCM] fix reduce_sum nan in ROCM platform, test=develop (#31780)

上级 f72d197e
...@@ -161,7 +161,11 @@ static inline std::vector<int> GetStrides(const std::vector<int>& dims, ...@@ -161,7 +161,11 @@ static inline std::vector<int> GetStrides(const std::vector<int>& dims,
return strides; return strides;
} }
#ifdef __HIPCC__
constexpr int kMaxBlockDim = 256;
#else
constexpr int kMaxBlockDim = 512; constexpr int kMaxBlockDim = 512;
#endif
static inline int GetDesiredBlockDim(int block_dim) { static inline int GetDesiredBlockDim(int block_dim) {
return block_dim >= kMaxBlockDim return block_dim >= kMaxBlockDim
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册