From d1d614b9f8ea054692c119fa107db2deb8963a40 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Wed, 17 Jan 2018 01:51:59 -0800
Subject: [PATCH] Refine the GPU kernel for sequence_erase_op

---
 paddle/operators/sequence_erase_op.cu | 69 ++++++++++++++++-----------
 1 file changed, 41 insertions(+), 28 deletions(-)

diff --git a/paddle/operators/sequence_erase_op.cu b/paddle/operators/sequence_erase_op.cu
index daf5b29863c..5c0576dc5e3 100644
--- a/paddle/operators/sequence_erase_op.cu
+++ b/paddle/operators/sequence_erase_op.cu
@@ -42,8 +42,8 @@ __global__ void LabelErasedIdx(const T* in_dat, const int in_len,
 }
 
 template <typename T>
-__global__ void GetOutLod(const T* num_erased, const int* in_lod,
-                          const int lod_len, int* out_lod0) {
+__global__ void GetOutLod(const T* num_erased, const size_t* in_lod,
+                          const int lod_len, size_t* out_lod0) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   if (index < lod_len) {
     out_lod0[index] = in_lod[index] - num_erased[in_lod[index]];
@@ -61,6 +61,26 @@ __global__ void SetOutput(const T* in_dat, const int in_len,
   }
 }
 
+template <typename T, typename Vector>
+thrust::device_vector<T> set_device_vector(Vector& vector) {
+  thrust::host_vector<T> host_vec(vector.size());
+  for (size_t i = 0; i < vector.size(); ++i) {
+    host_vec[i] = vector[i];
+  }
+  thrust::device_vector<T> dev_vec = host_vec;
+  return dev_vec;
+}
+
+template <typename T>
+std::vector<T> get_std_vector(thrust::device_vector<T>& dev_vec) {
+  thrust::host_vector<T> host_vec = dev_vec;
+  std::vector<T> std_vec(host_vec.size(), 0);
+  for (size_t i = 0; i < host_vec.size(); ++i) {
+    std_vec[i] = host_vec[i];
+  }
+  return std_vec;
+}
+
 template <typename T>
 class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
  public:
@@ -73,52 +93,45 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
                       "The actual size mismatches with the LoD information.");
     auto tokens = ctx.Attr<std::vector<int>>("tokens");
-    auto tokens_len = tokens.size();
     auto in_len = in->numel();
     auto in_dat = in->data<T>();
-    auto lod0 = lod[0];
+    // Copy tokens to GPU
+    thrust::device_vector<T> dev_tokens =
+        set_device_vector<T, std::vector<int>>(tokens);
+    T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
 
-    thrust::host_vector<T> host_tokens(tokens_len);
-    for (size_t i = 0; i < tokens.size(); ++i) {
-      host_tokens[i] = tokens[i];
-    }
-    thrust::device_vector<T> dev_tokens = host_tokens;
+    // Count number of elements to be erased
     thrust::device_vector<int> num_erased(in_len + 1);
-
-    T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
     int* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
-
     auto stream = ctx.cuda_device_context().stream();
     LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                      PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
-        in_dat, in_len, dev_tokens_ptr, tokens_len, num_erased_ptr);
+        in_dat, in_len, dev_tokens_ptr, tokens.size(), num_erased_ptr);
     thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(),
                            num_erased.begin() + 1);
 
-    // Calc LoD
+    // Copy LoD to GPU
+    auto lod0 = lod[0];
     auto lod_len = lod0.size();
-    thrust::host_vector<int> host_lod(lod_len);
-    for (size_t i = 0; i < lod_len; ++i) {
-      host_lod[i] = lod0[i];
-    }
-    thrust::device_vector<int> dev_in_lod = host_lod;
-    thrust::device_vector<int> dev_out_lod(lod_len);
-    int* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
-    int* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
+    thrust::device_vector<size_t> dev_in_lod =
+        set_device_vector<size_t, paddle::framework::Vector<size_t>>(lod0);
+    size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
+
+    // Calc output LoD
+    thrust::device_vector<size_t> dev_out_lod(lod_len);
+    size_t* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
     GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                 PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
         num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
-    thrust::host_vector<int> host_out_lod = dev_out_lod;
-    std::vector<int> out_lod0(lod_len, 0);
-    for (size_t i = 0; i < lod_len; i++) {
-      out_lod0[i] = host_out_lod[i];
-    }
+
+    // Set LoD for output
+    std::vector<size_t> out_lod0 = get_std_vector(dev_out_lod);
     framework::LoD out_lod;
     out_lod.push_back(out_lod0);
     out->set_lod(out_lod);
 
     // Set output
-    out->Resize({out_lod0.back(), 1});
+    out->Resize({static_cast<int64_t>(out_lod0.back()), 1});
     auto out_dat = out->mutable_data<T>(ctx.GetPlace());
     SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                 PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len,
-- 
GitLab
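
A note for reviewers on the technique itself: the op erases tokens with a
classic stream-compaction pattern. LabelErasedIdx marks each position to be
erased, an inclusive scan turns the marks into num_erased[i], the number of
erased elements before position i, and SetOutput scatters every kept element
left by that offset; GetOutLod applies the same offsets to the LoD
boundaries. The standalone sketch below reproduces that pattern so it can be
compiled and run with plain nvcc. It is an illustration, not the op: the
file name, kNumThreads, and the int-only payload are assumptions standing in
for PADDLE_CUDA_NUM_THREADS and the templated kernels, and all LoD
bookkeeping is omitted.

// sequence_erase_sketch.cu: minimal reproduction of the erase-by-prefix-sum
// pattern used by sequence_erase_op (illustration only).
#include <cstdio>
#include <vector>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/scan.h>

constexpr int kNumThreads = 256;  // stand-in for PADDLE_CUDA_NUM_THREADS

// Mark erased positions: num_erased[i + 1] = 1 iff in[i] is one of the tokens.
__global__ void LabelErasedIdx(const int* in, int in_len, const int* tokens,
                               int tokens_len, int* num_erased) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= in_len) return;
  for (int k = 0; k < tokens_len; ++k) {
    if (in[i] == tokens[k]) {
      num_erased[i + 1] = 1;
      break;
    }
  }
}

// After the inclusive scan, num_erased[i] counts erased elements before
// position i; every kept element therefore shifts left by num_erased[i].
__global__ void SetOutput(const int* in, int in_len, const int* num_erased,
                          int* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < in_len && num_erased[i] == num_erased[i + 1]) {
    out[i - num_erased[i]] = in[i];
  }
}

int main() {
  std::vector<int> h_in = {2, 2, 6, 1, 3, 9, 6, 1, 0, 1};
  std::vector<int> h_tokens = {2, 3, 5};
  thrust::device_vector<int> in(h_in.begin(), h_in.end());
  thrust::device_vector<int> tokens(h_tokens.begin(), h_tokens.end());
  int in_len = static_cast<int>(in.size());

  // One extra leading slot so the inclusive scan yields an exclusive count.
  thrust::device_vector<int> num_erased(in_len + 1, 0);
  int blocks = (in_len - 1) / kNumThreads + 1;
  LabelErasedIdx<<<blocks, kNumThreads>>>(
      thrust::raw_pointer_cast(in.data()), in_len,
      thrust::raw_pointer_cast(tokens.data()),
      static_cast<int>(tokens.size()),
      thrust::raw_pointer_cast(num_erased.data()));
  thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(),
                         num_erased.begin() + 1);

  int erased_total = num_erased.back();  // implicit device-to-host copy
  thrust::device_vector<int> out(in_len - erased_total);
  SetOutput<<<blocks, kNumThreads>>>(
      thrust::raw_pointer_cast(in.data()), in_len,
      thrust::raw_pointer_cast(num_erased.data()),
      thrust::raw_pointer_cast(out.data()));

  thrust::host_vector<int> h_out = out;  // synchronizes with the kernel
  for (int v : h_out) printf("%d ", v);  // prints: 6 1 9 6 1 0 1
  printf("\n");
  return 0;
}

The extra leading slot of num_erased is what lets a single inclusive scan
act as an exclusive count at every position, and it is also why the op's
GetOutLod can index num_erased directly with the input LoD offsets.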