提交 8809d43a 编写于 作者: Y Yibing Liu

Remove unnecessary dtype conversion & register int64 kernels

上级 7a2aa486
...@@ -86,4 +86,5 @@ REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp, ...@@ -86,4 +86,5 @@ REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp,
ops::SequenceEraseOpMaker); ops::SequenceEraseOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_erase, sequence_erase,
ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int32_t>); ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -28,16 +28,12 @@ __global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len, ...@@ -28,16 +28,12 @@ __global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len,
size_t* num_erased) { size_t* num_erased) {
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < in_len) { if (index < in_len) {
int erased = 0;
for (size_t i = 0; i < tokens_len; ++i) { for (size_t i = 0; i < tokens_len; ++i) {
if (in_dat[index] == tokens[i]) { if (in_dat[index] == tokens[i]) {
erased = 1; num_erased[index + 1] = 1;
break;
} }
} }
num_erased[index + 1] = erased;
if (index == 0) {
num_erased[0] = 0;
}
} }
} }
...@@ -60,26 +56,6 @@ __global__ void SetOutput(const T* in_dat, const int64_t in_len, ...@@ -60,26 +56,6 @@ __global__ void SetOutput(const T* in_dat, const int64_t in_len,
} }
} }
template <typename T, typename Vector>
thrust::device_vector<T> set_device_vector(Vector& vector) {
thrust::host_vector<T> host_vec(vector.size());
for (size_t i = 0; i < vector.size(); ++i) {
host_vec[i] = vector[i];
}
thrust::device_vector<T> dev_vec = host_vec;
return dev_vec;
}
template <typename T>
std::vector<T> get_std_vector(thrust::device_vector<T>& dev_vec) {
thrust::host_vector<T> host_vec = dev_vec;
std::vector<T> std_vec(host_vec.size(), 0);
for (size_t i = 0; i < host_vec.size(); ++i) {
std_vec[i] = host_vec[i];
}
return std_vec;
}
template <typename T> template <typename T>
class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> { class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
public: public:
...@@ -95,12 +71,11 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> { ...@@ -95,12 +71,11 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
auto in_len = in->numel(); auto in_len = in->numel();
auto in_dat = in->data<T>(); auto in_dat = in->data<T>();
// Copy tokens to GPU // Copy tokens to GPU
thrust::device_vector<int> dev_tokens = thrust::device_vector<int> dev_tokens(tokens.begin(), tokens.end());
set_device_vector<int, std::vector<int>>(tokens);
int* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data()); int* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
// Count number of elements to be erased // Count number of elements to be erased
thrust::device_vector<size_t> num_erased(in_len + 1); thrust::device_vector<size_t> num_erased(in_len + 1, 0);
size_t* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data()); size_t* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
...@@ -112,8 +87,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> { ...@@ -112,8 +87,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
// Copy LoD to GPU // Copy LoD to GPU
auto lod0 = lod[0]; auto lod0 = lod[0];
auto lod_len = lod0.size(); auto lod_len = lod0.size();
thrust::device_vector<size_t> dev_in_lod = thrust::device_vector<size_t> dev_in_lod = lod0;
set_device_vector<size_t, paddle::framework::Vector<size_t>>(lod0);
size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data()); size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
// Calc output LoD // Calc output LoD
...@@ -124,7 +98,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> { ...@@ -124,7 +98,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr); num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
// Set LoD for output // Set LoD for output
std::vector<size_t> out_lod0 = get_std_vector<size_t>(dev_out_lod); thrust::host_vector<size_t> out_lod0 = dev_out_lod;
framework::LoD out_lod; framework::LoD out_lod;
out_lod.push_back(out_lod0); out_lod.push_back(out_lod0);
out->set_lod(out_lod); out->set_lod(out_lod);
...@@ -142,4 +116,5 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> { ...@@ -142,4 +116,5 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
REGISTER_OP_CUDA_KERNEL(sequence_erase, REGISTER_OP_CUDA_KERNEL(sequence_erase,
paddle::operators::SequenceEraseOpCUDAKernel<int32_t>); paddle::operators::SequenceEraseOpCUDAKernel<int32_t>,
paddle::operators::SequenceEraseOpCUDAKernel<int64_t>);
...@@ -29,7 +29,7 @@ def sequence_erase(in_seq, lod0, tokens): ...@@ -29,7 +29,7 @@ def sequence_erase(in_seq, lod0, tokens):
return np.array(out_seq).astype("int32"), new_lod0 return np.array(out_seq).astype("int32"), new_lod0
class TestSequenceEraseOp(OpTest): class TestSequenceEraseOpInt32(OpTest):
def setUp(self): def setUp(self):
self.op_type = "sequence_erase" self.op_type = "sequence_erase"
in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
...@@ -44,6 +44,21 @@ class TestSequenceEraseOp(OpTest): ...@@ -44,6 +44,21 @@ class TestSequenceEraseOp(OpTest):
self.check_output() self.check_output()
class TestSequenceEraseOpInt64(OpTest):
def setUp(self):
self.op_type = "sequence_erase"
in_seq = np.random.randint(0, 10, (30, 1)).astype("int64")
lod = [[0, 9, 13, 24, 30]]
tokens = [2, 3, 5]
out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens)
self.attrs = {'tokens': tokens}
self.inputs = {'X': (in_seq, lod)}
self.outputs = {'Out': (out_seq, [new_lod0])}
def test_check_output(self):
self.check_output()
class TestSequenceEraseOpEmpty(OpTest): class TestSequenceEraseOpEmpty(OpTest):
def setUp(self): def setUp(self):
self.op_type = "sequence_erase" self.op_type = "sequence_erase"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册