diff --git a/paddle/operators/sequence_erase_op.cc b/paddle/operators/sequence_erase_op.cc index d17b2686238b2d2f872331edfdbb095fb8693b87..aa0c00aa6f7854ee5e34aef78970971b78df6514 100644 --- a/paddle/operators/sequence_erase_op.cc +++ b/paddle/operators/sequence_erase_op.cc @@ -86,4 +86,5 @@ REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp, ops::SequenceEraseOpMaker); REGISTER_OP_CPU_KERNEL( sequence_erase, - ops::SequenceEraseKernel); + ops::SequenceEraseKernel, + ops::SequenceEraseKernel); diff --git a/paddle/operators/sequence_erase_op.cu b/paddle/operators/sequence_erase_op.cu index c1e8bc2090de98f78ae21eec40a7d324a03b7eff..f1e3b96acd0259de2b3ca1348834bd17e1e174a2 100644 --- a/paddle/operators/sequence_erase_op.cu +++ b/paddle/operators/sequence_erase_op.cu @@ -28,16 +28,12 @@ __global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len, size_t* num_erased) { int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < in_len) { - int erased = 0; for (size_t i = 0; i < tokens_len; ++i) { if (in_dat[index] == tokens[i]) { - erased = 1; + num_erased[index + 1] = 1; + break; } } - num_erased[index + 1] = erased; - if (index == 0) { - num_erased[0] = 0; - } } } @@ -60,26 +56,6 @@ __global__ void SetOutput(const T* in_dat, const int64_t in_len, } } -template -thrust::device_vector set_device_vector(Vector& vector) { - thrust::host_vector host_vec(vector.size()); - for (size_t i = 0; i < vector.size(); ++i) { - host_vec[i] = vector[i]; - } - thrust::device_vector dev_vec = host_vec; - return dev_vec; -} - -template -std::vector get_std_vector(thrust::device_vector& dev_vec) { - thrust::host_vector host_vec = dev_vec; - std::vector std_vec(host_vec.size(), 0); - for (size_t i = 0; i < host_vec.size(); ++i) { - std_vec[i] = host_vec[i]; - } - return std_vec; -} - template class SequenceEraseOpCUDAKernel : public framework::OpKernel { public: @@ -95,12 +71,11 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel { auto in_len = in->numel(); auto in_dat = in->data(); // Copy tokens to GPU - thrust::device_vector dev_tokens = - set_device_vector>(tokens); + thrust::device_vector dev_tokens(tokens.begin(), tokens.end()); int* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data()); // Count number of elements to be erased - thrust::device_vector num_erased(in_len + 1); + thrust::device_vector num_erased(in_len + 1, 0); size_t* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data()); auto stream = ctx.cuda_device_context().stream(); LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, @@ -112,8 +87,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel { // Copy LoD to GPU auto lod0 = lod[0]; auto lod_len = lod0.size(); - thrust::device_vector dev_in_lod = - set_device_vector>(lod0); + thrust::device_vector dev_in_lod = lod0; size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data()); // Calc output LoD @@ -124,7 +98,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel { num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr); // Set LoD for output - std::vector out_lod0 = get_std_vector(dev_out_lod); + thrust::host_vector out_lod0 = dev_out_lod; framework::LoD out_lod; out_lod.push_back(out_lod0); out->set_lod(out_lod); @@ -142,4 +116,5 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel { } // namespace paddle REGISTER_OP_CUDA_KERNEL(sequence_erase, - paddle::operators::SequenceEraseOpCUDAKernel); + paddle::operators::SequenceEraseOpCUDAKernel, + paddle::operators::SequenceEraseOpCUDAKernel); diff --git a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py index d8aa4f7e9414b59b9b029decdf55b8108ef0047b..4cc2613cf9c26845cef988160405b632706c4b11 100644 --- a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py +++ b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py @@ -29,7 +29,7 @@ def sequence_erase(in_seq, lod0, tokens): return np.array(out_seq).astype("int32"), new_lod0 -class TestSequenceEraseOp(OpTest): +class TestSequenceEraseOpInt32(OpTest): def setUp(self): self.op_type = "sequence_erase" in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") @@ -44,6 +44,21 @@ class TestSequenceEraseOp(OpTest): self.check_output() +class TestSequenceEraseOpInt64(OpTest): + def setUp(self): + self.op_type = "sequence_erase" + in_seq = np.random.randint(0, 10, (30, 1)).astype("int64") + lod = [[0, 9, 13, 24, 30]] + tokens = [2, 3, 5] + out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens) + self.attrs = {'tokens': tokens} + self.inputs = {'X': (in_seq, lod)} + self.outputs = {'Out': (out_seq, [new_lod0])} + + def test_check_output(self): + self.check_output() + + class TestSequenceEraseOpEmpty(OpTest): def setUp(self): self.op_type = "sequence_erase"