未验证 提交 3ff5cc2d 编写于 作者: Z zhaoyuchen2018 提交者: GitHub

Fix topk compilation failure on Windows (#21243)

* Fix topk compilation failure on Windows
* Use explicit casts when assigning data
上级 2e2f92a5
......@@ -336,12 +336,13 @@ struct ColumnIndexIter {
int num_cols_;
};
__global__ void InitIndex(int64_t* indices, int num_rows, int num_cols) {
__global__ void InitIndex(int64_t* indices, int64_t num_rows,
int64_t num_cols) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
for (int j = row_id; j < num_rows; j += gridDim.x) {
for (int i = col_id; i < num_cols; i += blockDim.x) {
for (int64_t j = row_id; j < num_rows; j += gridDim.x) {
for (int64_t i = col_id; i < num_cols; i += blockDim.x) {
indices[j * num_cols + i] = i;
}
}
......@@ -349,14 +350,14 @@ __global__ void InitIndex(int64_t* indices, int num_rows, int num_cols) {
template <typename T>
bool SortTopk(const platform::CUDADeviceContext& ctx,
const framework::Tensor* input_tensor, const size_t num_cols,
const size_t num_rows, size_t k, framework::Tensor* out_tensor,
const framework::Tensor* input_tensor, const int64_t num_cols,
const int64_t num_rows, const int k,
framework::Tensor* out_tensor,
framework::Tensor* indices_tensor) {
auto cu_stream = ctx.stream();
Tensor input_indices;
const std::vector<int64_t> dims = {static_cast<int64_t>(num_rows),
static_cast<int64_t>(num_cols)};
const std::vector<int64_t> dims = {num_rows, num_cols};
auto dim = framework::make_ddim(dims);
input_indices.Resize(dim);
// input_indices.Resize(num_rows*num_cols);
......@@ -378,18 +379,20 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
int block_size = ComputeBlockSize(num_cols);
int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
// actually, int num_rows < max_grid_size
int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
unsigned int grid_size = num_rows < maxGridDimX
? static_cast<unsigned int>(num_rows)
: maxGridDimX;
// Init an index array
InitIndex<<<grid_size, block_size, 0, cu_stream>>>(
input_indices.data<int64_t>(), num_rows, num_cols);
// create iter for counting input
cub::CountingInputIterator<int> counting_iter(0);
cub::CountingInputIterator<int64_t> counting_iter(0);
// segment_offset is used for move to next row
cub::TransformInputIterator<int, SegmentOffsetIter,
cub::CountingInputIterator<int>>
cub::TransformInputIterator<int64_t, SegmentOffsetIter,
cub::CountingInputIterator<int64_t>>
segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
T* sorted_values_ptr;
......@@ -484,7 +487,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices");
size_t k = static_cast<int>(ctx.Attr<int>("k"));
int k = static_cast<int>(ctx.Attr<int>("k"));
auto* k_t = ctx.Input<Tensor>("K");
if (k_t) {
......@@ -502,9 +505,9 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
// FIXME(typhoonzero): data is always converted to type T?
framework::DDim inputdims = input->dims();
const size_t input_height = framework::product(
const int64_t input_height = framework::product(
framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
const size_t input_width = inputdims[inputdims.size() - 1];
const int64_t input_width = inputdims[inputdims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context();
if ((input_width <= 1024 || k >= 128 || k == input_width)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册