Unverified commit d3cc7ac3, authored by fengjiayi, committed by GitHub

Fix top k op GPU code (#5221)

* Fix Type error

* Fix error

* Fix top_k_op GPU code data type
Parent: b9056bb0
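
For background on the data-type fix described in the commit message, the sketch below is hypothetical code, not taken from this commit or from the PaddlePaddle codebase (the kernel name RowArgMax and all variables are invented for illustration). It shows the consistency rule the change enforces: the index type a CUDA kernel writes must match the element type used to allocate the indices buffer, here int64_t on both sides.

// Minimal sketch (assumed example): one block per row finds the argmax of
// that row and records it as an int64_t index. The indices buffer is
// allocated with the same element type the kernel writes; mixing int and
// int64_t here would mis-size every element and scramble the results.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void RowArgMax(const float* src, int width, int64_t* indices) {
  int row = blockIdx.x;                    // one block handles one row
  const float* row_ptr = src + row * width;
  float best = row_ptr[0];
  int64_t best_id = 0;
  for (int i = 1; i < width; ++i) {        // serial scan, enough for a demo
    if (row_ptr[i] > best) {
      best = row_ptr[i];
      best_id = i;
    }
  }
  if (threadIdx.x == 0) indices[row] = best_id;
}

int main() {
  const int height = 2, width = 4;
  float h_src[height * width] = {0.1f, 0.9f, 0.3f, 0.2f,
                                 0.5f, 0.4f, 0.8f, 0.7f};
  float* d_src = nullptr;
  int64_t* d_idx = nullptr;
  cudaMalloc(&d_src, sizeof(h_src));
  cudaMalloc(&d_idx, height * sizeof(int64_t));  // int64_t, matching the kernel
  cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);

  RowArgMax<<<height, 1>>>(d_src, width, d_idx);

  int64_t h_idx[height];
  cudaMemcpy(h_idx, d_idx, sizeof(h_idx), cudaMemcpyDeviceToHost);
  printf("argmax per row: %lld %lld\n",    // expected: 1 2
         static_cast<long long>(h_idx[0]), static_cast<long long>(h_idx[1]));

  cudaFree(d_src);
  cudaFree(d_idx);
  return 0;
}

The diff below applies the same rule throughout the top-k GPU code: Pair::id, the topIds pointer in BlockReduce, the indices argument of KeMatrixTopK, and the buffer obtained from mutable_data are all moved from int to int64_t together.
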
@@ -23,9 +23,9 @@ using Tensor = framework::Tensor;
 template <typename T>
 struct Pair {
   __device__ __forceinline__ Pair() {}
-  __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
+  __device__ __forceinline__ Pair(T value, int64_t id) : v(value), id(id) {}
-  __device__ __forceinline__ void set(T value, int id) {
+  __device__ __forceinline__ void set(T value, int64_t id) {
     v = value;
     id = id;
   }
@@ -48,7 +48,7 @@ struct Pair {
   }
 
   T v;
-  int id;
+  int64_t id;
 };
 
 template <typename T>
@@ -197,7 +197,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int& beam,
 template <typename T, int MaxLength, int BlockSize>
 __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
                                             Pair<T> topk[], T** topVal,
-                                            int** topIds, int& beam, int& k,
+                                            int64_t** topIds, int& beam, int& k,
                                             const int tid, const int warp) {
   while (true) {
     __syncthreads();
@@ -249,7 +249,7 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
  * 4. go to the first setp, until get the topk value.
  */
 template <typename T, int MaxLength, int BlockSize>
-__global__ void KeMatrixTopK(T* output, int output_stride, int* indices,
+__global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
                              const T* src, int lds, int dim, int k) {
   __shared__ Pair<T> sh_topk[BlockSize];
   __shared__ int maxid[BlockSize / 2];
@@ -293,7 +293,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
     // FIXME(typhoonzero): data is always converted to type T?
-    int* indices_data = indices->mutable_data<int>(ctx.GetPlace());
+    int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
     size_t input_height = input->dims()[0];
     size_t input_width = input->dims()[1];
...