提交 d1d614b9 编写于 作者: Y Yibing Liu

Refine the GPU kernel for sequence_erase_op

上级 7d3b2e4b
...@@ -42,8 +42,8 @@ __global__ void LabelErasedIdx(const T* in_dat, const int in_len, ...@@ -42,8 +42,8 @@ __global__ void LabelErasedIdx(const T* in_dat, const int in_len,
} }
template <typename T> template <typename T>
__global__ void GetOutLod(const T* num_erased, const int* in_lod, __global__ void GetOutLod(const T* num_erased, const size_t* in_lod,
const int lod_len, int* out_lod0) { const int lod_len, size_t* out_lod0) {
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < lod_len) { if (index < lod_len) {
out_lod0[index] = in_lod[index] - num_erased[in_lod[index]]; out_lod0[index] = in_lod[index] - num_erased[in_lod[index]];
...@@ -61,6 +61,26 @@ __global__ void SetOutput(const T* in_dat, const int in_len, ...@@ -61,6 +61,26 @@ __global__ void SetOutput(const T* in_dat, const int in_len,
} }
} }
template <typename T, typename Vector>
thrust::device_vector<T> set_device_vector(Vector& vector) {
thrust::host_vector<T> host_vec(vector.size());
for (size_t i = 0; i < vector.size(); ++i) {
host_vec[i] = vector[i];
}
thrust::device_vector<T> dev_vec = host_vec;
return dev_vec;
}
template <typename T>
std::vector<T> get_std_vector(thrust::device_vector<T>& dev_vec) {
thrust::host_vector<T> host_vec = dev_vec;
std::vector<T> std_vec(host_vec.size(), 0);
for (size_t i = 0; i < host_vec.size(); ++i) {
std_vec[i] = host_vec[i];
}
return std_vec;
}
template <typename T> template <typename T>
class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> { class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
public: public:
...@@ -73,52 +93,45 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> { ...@@ -73,52 +93,45 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(), PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
"The actual size mismatches with the LoD information."); "The actual size mismatches with the LoD information.");
auto tokens = ctx.Attr<std::vector<T>>("tokens"); auto tokens = ctx.Attr<std::vector<T>>("tokens");
auto tokens_len = tokens.size();
auto in_len = in->numel(); auto in_len = in->numel();
auto in_dat = in->data<T>(); auto in_dat = in->data<T>();
auto lod0 = lod[0]; // Copy tokens to GPU
thrust::device_vector<T> dev_tokens =
set_device_vector<T, std::vector<T>>(tokens);
T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
thrust::host_vector<T> host_tokens(tokens_len); // Count number of elements to be erased
for (size_t i = 0; i < tokens.size(); ++i) {
host_tokens[i] = tokens[i];
}
thrust::device_vector<T> dev_tokens = host_tokens;
thrust::device_vector<int> num_erased(in_len + 1); thrust::device_vector<int> num_erased(in_len + 1);
T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
int* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data()); int* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>( PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
in_dat, in_len, dev_tokens_ptr, tokens_len, num_erased_ptr); in_dat, in_len, dev_tokens_ptr, tokens.size(), num_erased_ptr);
thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(), thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(),
num_erased.begin() + 1); num_erased.begin() + 1);
// Calc LoD // Copy LoD to GPU
auto lod0 = lod[0];
auto lod_len = lod0.size(); auto lod_len = lod0.size();
thrust::host_vector<int> host_lod(lod_len); thrust::device_vector<size_t> dev_in_lod =
for (size_t i = 0; i < lod_len; ++i) { set_device_vector<size_t, paddle::framework::Vector<size_t>>(lod0);
host_lod[i] = lod0[i]; size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
}
thrust::device_vector<int> dev_in_lod = host_lod; // Calc output LoD
thrust::device_vector<int> dev_out_lod(lod_len); thrust::device_vector<size_t> dev_out_lod(lod_len);
int* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data()); size_t* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
int* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>( PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr); num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
thrust::host_vector<int> host_out_lod = dev_out_lod;
std::vector<int> out_lod0(lod_len, 0); // Set LoD for output
for (size_t i = 0; i < lod_len; i++) { std::vector<size_t> out_lod0 = get_std_vector<size_t>(dev_out_lod);
out_lod0[i] = host_out_lod[i];
}
framework::LoD out_lod; framework::LoD out_lod;
out_lod.push_back(out_lod0); out_lod.push_back(out_lod0);
out->set_lod(out_lod); out->set_lod(out_lod);
// Set output // Set output
out->Resize({out_lod0.back(), 1}); out->Resize({static_cast<int64_t>(out_lod0.back()), 1});
auto out_dat = out->mutable_data<T>(ctx.GetPlace()); auto out_dat = out->mutable_data<T>(ctx.GetPlace());
SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册