提交 763e8fdf 编写于 作者: Q Qiao Longfei

fix compile error

上级 fcde2b27
......@@ -117,17 +117,18 @@ __global__ static void ForRangeInElemwiseOp(Function func, T* vector,
template <>
struct ForRangeIn<CUDADeviceContext> {
ForRange(const CUDADeviceContext& dev_ctx, std::vector<int64_t> range)
ForRangeIn(const CUDADeviceContext& dev_ctx, std::vector<int64_t> range)
: dev_ctx_(dev_ctx), range_(range) {}
template <typename Function>
inline void operator()(Function func) const {
constexpr int num_threads = 1024;
int block_size = range_.size() <= num_threads ? limit_ : num_threads;
int range_size = range_.size();
int block_size = range_size <= num_threads ? range_size : num_threads;
int grid_size = (range_.size() + num_threads - 1) / num_threads;
ForRangeInElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
func, range_.data(), range_.size());
func, range_.data(), range_size);
}
const CUDADeviceContext& dev_ctx_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册