未验证 提交 d3cc7ac3 编写于 作者: F fengjiayi 提交者: GitHub

Fix top k op GPU code (#5221)

* Fix Type error

* Fix error

* Fix top_k_op GPU code data type
上级 b9056bb0
......@@ -23,9 +23,9 @@ using Tensor = framework::Tensor;
template <typename T>
struct Pair {
__device__ __forceinline__ Pair() {}
__device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
__device__ __forceinline__ Pair(T value, int64_t id) : v(value), id(id) {}
__device__ __forceinline__ void set(T value, int id) {
__device__ __forceinline__ void set(T value, int64_t id) {
v = value;
id = id;
}
......@@ -48,7 +48,7 @@ struct Pair {
}
T v;
int id;
int64_t id;
};
template <typename T>
......@@ -197,7 +197,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int& beam,
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
Pair<T> topk[], T** topVal,
int** topIds, int& beam, int& k,
int64_t** topIds, int& beam, int& k,
const int tid, const int warp) {
while (true) {
__syncthreads();
......@@ -249,7 +249,7 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
* 4. go to the first setp, until get the topk value.
*/
template <typename T, int MaxLength, int BlockSize>
__global__ void KeMatrixTopK(T* output, int output_stride, int* indices,
__global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
const T* src, int lds, int dim, int k) {
__shared__ Pair<T> sh_topk[BlockSize];
__shared__ int maxid[BlockSize / 2];
......@@ -293,7 +293,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
T* output_data = output->mutable_data<T>(ctx.GetPlace());
// FIXME(typhoonzero): data is always converted to type T?
int* indices_data = indices->mutable_data<int>(ctx.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
size_t input_height = input->dims()[0];
size_t input_width = input->dims()[1];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册