未验证 提交 3ff5cc2d 编写于 作者: Z zhaoyuchen2018 提交者: GitHub

Fix topk compile failed on windows (#21243)

* Fix topk compile failed on windows
* Use explicit cast for assign data
上级 2e2f92a5
@@ -336,12 +336,13 @@ struct ColumnIndexIter {
int num_cols_; int num_cols_;
}; };
// Fills `indices` with per-row column numbers: indices[j * num_cols + i] = i
// for every row j in [0, num_rows) and column i in [0, num_cols).
//
// Launch layout: 1-D grid over rows, 1-D block over columns. Both loops are
// stride loops (gridDim.x / blockDim.x), so any launch configuration covers
// the whole num_rows x num_cols index matrix.
//
// Sizes and loop counters are int64_t so that num_rows, num_cols, and the
// flattened offset j * num_cols + i may exceed INT_MAX without overflow.
__global__ void InitIndex(int64_t* indices, int64_t num_rows,
                          int64_t num_cols) {
  int col_id = threadIdx.x;
  int row_id = blockIdx.x;
  for (int64_t j = row_id; j < num_rows; j += gridDim.x) {
    for (int64_t i = col_id; i < num_cols; i += blockDim.x) {
      indices[j * num_cols + i] = i;
    }
  }
}
@@ -349,14 +350,14 @@ __global__ void InitIndex(int64_t* indices, int num_rows, int num_cols) {
template <typename T> template <typename T>
bool SortTopk(const platform::CUDADeviceContext& ctx, bool SortTopk(const platform::CUDADeviceContext& ctx,
const framework::Tensor* input_tensor, const size_t num_cols, const framework::Tensor* input_tensor, const int64_t num_cols,
const size_t num_rows, size_t k, framework::Tensor* out_tensor, const int64_t num_rows, const int k,
framework::Tensor* out_tensor,
framework::Tensor* indices_tensor) { framework::Tensor* indices_tensor) {
auto cu_stream = ctx.stream(); auto cu_stream = ctx.stream();
Tensor input_indices; Tensor input_indices;
const std::vector<int64_t> dims = {static_cast<int64_t>(num_rows), const std::vector<int64_t> dims = {num_rows, num_cols};
static_cast<int64_t>(num_cols)};
auto dim = framework::make_ddim(dims); auto dim = framework::make_ddim(dims);
input_indices.Resize(dim); input_indices.Resize(dim);
// input_indices.Resize(num_rows*num_cols); // input_indices.Resize(num_rows*num_cols);
@@ -378,18 +379,20 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
int block_size = ComputeBlockSize(num_cols); int block_size = ComputeBlockSize(num_cols);
int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x; unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
// actually, int num_rows < max_grid_size // actually, int num_rows < max_grid_size
int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX; unsigned int grid_size = num_rows < maxGridDimX
? static_cast<unsigned int>(num_rows)
: maxGridDimX;
// Init a index array // Init a index array
InitIndex<<<grid_size, block_size, 0, cu_stream>>>( InitIndex<<<grid_size, block_size, 0, cu_stream>>>(
input_indices.data<int64_t>(), num_rows, num_cols); input_indices.data<int64_t>(), num_rows, num_cols);
// create iter for counting input // create iter for counting input
cub::CountingInputIterator<int> counting_iter(0); cub::CountingInputIterator<int64_t> counting_iter(0);
// segment_offset is used for move to next row // segment_offset is used for move to next row
cub::TransformInputIterator<int, SegmentOffsetIter, cub::TransformInputIterator<int64_t, SegmentOffsetIter,
cub::CountingInputIterator<int>> cub::CountingInputIterator<int64_t>>
segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols)); segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
T* sorted_values_ptr; T* sorted_values_ptr;
@@ -484,7 +487,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("X"); auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out"); auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices"); auto* indices = ctx.Output<Tensor>("Indices");
size_t k = static_cast<int>(ctx.Attr<int>("k")); int k = static_cast<int>(ctx.Attr<int>("k"));
auto* k_t = ctx.Input<Tensor>("K"); auto* k_t = ctx.Input<Tensor>("K");
if (k_t) { if (k_t) {
@@ -502,9 +505,9 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
// FIXME(typhoonzero): data is always converted to type T? // FIXME(typhoonzero): data is always converted to type T?
framework::DDim inputdims = input->dims(); framework::DDim inputdims = input->dims();
const size_t input_height = framework::product( const int64_t input_height = framework::product(
framework::slice_ddim(inputdims, 0, inputdims.size() - 1)); framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
const size_t input_width = inputdims[inputdims.size() - 1]; const int64_t input_width = inputdims[inputdims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context(); const auto& dev_ctx = ctx.cuda_device_context();
if ((input_width <= 1024 || k >= 128 || k == input_width)) { if ((input_width <= 1024 || k >= 128 || k == input_width)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册