diff --git a/paddle/operators/sequence_erase_op.cu b/paddle/operators/sequence_erase_op.cu index 5c0576dc5e329ef70330a4e571a7c547dfb05dc7..c1e8bc2090de98f78ae21eec40a7d324a03b7eff 100644 --- a/paddle/operators/sequence_erase_op.cu +++ b/paddle/operators/sequence_erase_op.cu @@ -23,13 +23,13 @@ using platform::PADDLE_CUDA_NUM_THREADS; using LoDTensor = framework::LoDTensor; template -__global__ void LabelErasedIdx(const T* in_dat, const int in_len, - const T* tokens, const int tokens_len, - int* num_erased) { +__global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len, + const int* tokens, const size_t tokens_len, + size_t* num_erased) { int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < in_len) { int erased = 0; - for (int i = 0; i < tokens_len; ++i) { + for (size_t i = 0; i < tokens_len; ++i) { if (in_dat[index] == tokens[i]) { erased = 1; } @@ -41,9 +41,8 @@ __global__ void LabelErasedIdx(const T* in_dat, const int in_len, } } -template -__global__ void GetOutLod(const T* num_erased, const size_t* in_lod, - const int lod_len, size_t* out_lod0) { +__global__ void GetOutLod(const size_t* num_erased, const size_t* in_lod, + const size_t lod_len, size_t* out_lod0) { int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < lod_len) { out_lod0[index] = in_lod[index] - num_erased[in_lod[index]]; @@ -51,8 +50,8 @@ __global__ void GetOutLod(const T* num_erased, const size_t* in_lod, } template -__global__ void SetOutput(const T* in_dat, const int in_len, - const int* num_erased, T* out_dat) { +__global__ void SetOutput(const T* in_dat, const int64_t in_len, + const size_t* num_erased, T* out_dat) { int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < in_len) { if (num_erased[index] == num_erased[index + 1]) { @@ -92,17 +91,17 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(), "The actual size mismatches with the LoD information."); - auto tokens = ctx.Attr>("tokens"); + auto tokens = ctx.Attr>("tokens"); auto in_len = in->numel(); auto in_dat = in->data(); // Copy tokens to GPU - thrust::device_vector dev_tokens = - set_device_vector>(tokens); - T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data()); + thrust::device_vector dev_tokens = + set_device_vector>(tokens); + int* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data()); // Count number of elements to be erased - thrust::device_vector num_erased(in_len + 1); - int* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data()); + thrust::device_vector num_erased(in_len + 1); + size_t* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data()); auto stream = ctx.cuda_device_context().stream(); LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(