diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu
index c5cc741820502b72bd332e9b2d28defc312bfba0..0375611dfb5b9c129e9cfa26808ec40185eac47d 100644
--- a/paddle/fluid/operators/top_k_op.cu
+++ b/paddle/fluid/operators/top_k_op.cu
@@ -336,12 +336,13 @@ struct ColumnIndexIter {
   int num_cols_;
 };
 
-__global__ void InitIndex(int64_t* indices, int num_rows, int num_cols) {
+__global__ void InitIndex(int64_t* indices, int64_t num_rows,
+                          int64_t num_cols) {
   int col_id = threadIdx.x;
   int row_id = blockIdx.x;
 
-  for (int j = row_id; j < num_rows; j += gridDim.x) {
-    for (int i = col_id; i < num_cols; i += blockDim.x) {
+  for (int64_t j = row_id; j < num_rows; j += gridDim.x) {
+    for (int64_t i = col_id; i < num_cols; i += blockDim.x) {
       indices[j * num_cols + i] = i;
     }
   }
@@ -349,14 +350,14 @@ __global__ void InitIndex(int64_t* indices, int num_rows, int num_cols) {
 
 template <typename T>
 bool SortTopk(const platform::CUDADeviceContext& ctx,
-              const framework::Tensor* input_tensor, const size_t num_cols,
-              const size_t num_rows, size_t k, framework::Tensor* out_tensor,
+              const framework::Tensor* input_tensor, const int64_t num_cols,
+              const int64_t num_rows, const int k,
+              framework::Tensor* out_tensor,
               framework::Tensor* indices_tensor) {
   auto cu_stream = ctx.stream();
 
   Tensor input_indices;
-  const std::vector<int64_t> dims = {static_cast<int64_t>(num_rows),
-                                     static_cast<int64_t>(num_cols)};
+  const std::vector<int64_t> dims = {num_rows, num_cols};
   auto dim = framework::make_ddim(dims);
   input_indices.Resize(dim);
   // input_indices.Resize(num_rows*num_cols);
@@ -378,18 +379,20 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
 
   int block_size = ComputeBlockSize(num_cols);
 
-  int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
+  unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
   // actually, int num_rows < max_grid_size
-  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
+  unsigned int grid_size = num_rows < maxGridDimX
+                               ? static_cast<unsigned int>(num_rows)
+                               : maxGridDimX;
   // Init a index array
   InitIndex<<<grid_size, block_size, 0, cu_stream>>>(
       input_indices.data<int64_t>(), num_rows, num_cols);
 
   // create iter for counting input
-  cub::CountingInputIterator<int> counting_iter(0);
+  cub::CountingInputIterator<int64_t> counting_iter(0);
   // segment_offset is used for move to next row
-  cub::TransformInputIterator<int, SegmentOffsetIter,
-                              cub::CountingInputIterator<int>>
+  cub::TransformInputIterator<int64_t, SegmentOffsetIter,
+                              cub::CountingInputIterator<int64_t>>
       segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
 
   T* sorted_values_ptr;
@@ -484,7 +487,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
     auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
     auto* indices = ctx.Output<Tensor>("Indices");
-    size_t k = static_cast<int>(ctx.Attr<int>("k"));
+    int k = static_cast<int>(ctx.Attr<int>("k"));
 
     auto* k_t = ctx.Input<Tensor>("K");
     if (k_t) {
@@ -502,9 +505,9 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
 
     // FIXME(typhoonzero): data is always converted to type T?
     framework::DDim inputdims = input->dims();
-    const size_t input_height = framework::product(
+    const int64_t input_height = framework::product(
         framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
-    const size_t input_width = inputdims[inputdims.size() - 1];
+    const int64_t input_width = inputdims[inputdims.size() - 1];
     const auto& dev_ctx = ctx.cuda_device_context();
 
     if ((input_width <= 1024 || k >= 128 || k == input_width)) {
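
Below is an illustrative sketch, not part of the patch, of why the index arithmetic is widened to int64_t: with 32-bit loop variables the flattened offset j * num_cols + i computed in InitIndex overflows once num_rows * num_cols exceeds INT_MAX, while 64-bit arithmetic keeps the offset valid. The standalone kernel name, the demo shape, and the launch configuration here are hypothetical and only mirror the patched logic.

// Sketch only: mirrors the patched InitIndex kernel with int64_t extents.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void InitIndexSketch(int64_t* indices, int64_t num_rows,
                                int64_t num_cols) {
  int col_id = threadIdx.x;
  int row_id = blockIdx.x;
  for (int64_t j = row_id; j < num_rows; j += gridDim.x) {
    for (int64_t i = col_id; i < num_cols; i += blockDim.x) {
      // 64-bit product: safe even when num_rows * num_cols > INT_MAX.
      indices[j * num_cols + i] = i;
    }
  }
}

int main() {
  // Small hypothetical shape so the demo runs anywhere; the int64_t math is
  // what matters for genuinely large tensors.
  const int64_t num_rows = 4, num_cols = 8;
  int64_t* d_indices = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&d_indices),
             num_rows * num_cols * sizeof(int64_t));

  // One block per row, one thread per column, echoing the patched launch.
  InitIndexSketch<<<static_cast<unsigned int>(num_rows),
                    static_cast<unsigned int>(num_cols)>>>(d_indices, num_rows,
                                                           num_cols);

  int64_t h_indices[32];
  cudaMemcpy(h_indices, d_indices, sizeof(h_indices), cudaMemcpyDeviceToHost);
  // Each row holds the column indices 0..7.
  printf("row 3: %lld ... %lld\n", static_cast<long long>(h_indices[3 * 8]),
         static_cast<long long>(h_indices[3 * 8 + 7]));
  cudaFree(d_indices);
  return 0;
}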