提交 763e8fdf 编写于 作者: Qiao Longfei

fix compile error

上级 fcde2b27
@@ -117,17 +117,18 @@ __global__ static void ForRangeInElemwiseOp(Function func, T* vector,
 template <>
 struct ForRangeIn<CUDADeviceContext> {
-  ForRange(const CUDADeviceContext& dev_ctx, std::vector<int64_t> range)
+  ForRangeIn(const CUDADeviceContext& dev_ctx, std::vector<int64_t> range)
       : dev_ctx_(dev_ctx), range_(range) {}
   template <typename Function>
   inline void operator()(Function func) const {
     constexpr int num_threads = 1024;
-    int block_size = range_.size() <= num_threads ? limit_ : num_threads;
+    int range_size = range_.size();
+    int block_size = range_size <= num_threads ? range_size : num_threads;
     int grid_size = (range_.size() + num_threads - 1) / num_threads;
     ForRangeInElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
-        func, range_.data(), range_.size());
+        func, range_.data(), range_size);
   }
   const CUDADeviceContext& dev_ctx_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册