Commit bd966900 authored by whs, committed by ceci3

Make sequence_erase op support input with multi-level LoD. (#15982)

test=develop
Parent 1301dc1a
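Before this change, both kernels rejected nested sequences outright via PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."). The commit drops that check and makes the kernels operate on the last (finest) LoD level, passing all upper levels through unchanged: erasing tokens shortens sub-sequences but never changes how they are grouped.

A minimal NumPy sketch of the operator's semantics. The helper mirrors the sequence_erase call in the test diff below (input rows, length-based finest LoD level, and token list in; filtered rows and new lengths out), but its body is illustrative rather than the repository's code.

    import numpy as np

    def sequence_erase(in_seq, lod0, tokens):
        # lod0 is length-based: one entry per finest-level sub-sequence.
        out_seq, new_lod0, offset = [], [], 0
        for seq_len in lod0:
            kept = [v for v in in_seq[offset:offset + seq_len, 0]
                    if v not in tokens]
            out_seq.extend(kept)
            new_lod0.append(len(kept))
            offset += seq_len
        return np.asarray(out_seq, dtype=in_seq.dtype).reshape(-1, 1), new_lod0

    in_seq = np.array([[2], [1], [5], [3]], dtype="int32")
    lod = [[2], [1, 3]]  # one top-level sequence grouping two sub-sequences
    out_seq, new_lod0 = sequence_erase(in_seq, lod[-1], [2, 3, 5])
    # out_seq == [[1]]; output LoD == lod[:-1] + [new_lod0] == [[2], [0, 1]]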
@@ -64,8 +64,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
     auto* out = ctx.Output<LoDTensor>("Out");
     auto lod = in->lod();
-    PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
-    PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
+    PADDLE_ENFORCE_EQ(lod[lod.size() - 1].back(), (size_t)in->numel(),
                       "The actual size mismatches with the LoD information.");
     auto tokens = ctx.Attr<std::vector<int>>("tokens");
     auto in_len = in->numel();
@@ -85,10 +84,9 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
                            num_erased.begin() + 1);
     // Copy LoD to GPU
-    auto lod0 = lod[0];
-    auto lod_len = lod0.size();
-    const size_t* dev_in_lod_ptr = lod0.CUDAData(ctx.GetPlace());
+    auto last_lod = lod[lod.size() - 1];
+    auto lod_len = last_lod.size();
+    const size_t* dev_in_lod_ptr = last_lod.CUDAData(ctx.GetPlace());
     // Calc output LoD
     thrust::device_vector<size_t> dev_out_lod(lod_len);
     size_t* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
@@ -96,13 +94,16 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
                     PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
         num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
     // Set LoD for output
-    std::vector<size_t> out_lod0(dev_out_lod.begin(), dev_out_lod.end());
+    std::vector<size_t> out_last_lod(dev_out_lod.begin(), dev_out_lod.end());
     framework::LoD out_lod;
-    out_lod.push_back(out_lod0);
+    for (size_t i = 0; i < lod.size() - 1; ++i) {
+      out_lod.push_back(lod[i]);
+    }
+    out_lod.push_back(out_last_lod);
     out->set_lod(out_lod);
     // Set output
-    out->Resize({static_cast<int64_t>(out_lod0.back()), 1});
+    out->Resize({static_cast<int64_t>(out_last_lod.back()), 1});
     auto out_dat = out->mutable_data<T>(ctx.GetPlace());
     SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                 PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len,
......
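The GPU path recomputes only the last LoD level. The thrust::inclusive_scan visible in the second hunk fills num_erased so that entry j holds the count of erased elements among the first j inputs, and the kernel launched in the third hunk (its name sits above the shown context) subtracts that count from each offset of the input's last level. A NumPy sketch of that mapping; the helper name calc_output_lod is hypothetical:

    import numpy as np

    def calc_output_lod(num_erased, in_lod):
        # An input offset p has num_erased[p] elements removed before it,
        # so it lands at output offset p - num_erased[p].
        return [p - num_erased[p] for p in in_lod]

    # Input [2, 1, 5, 3] with tokens {2, 3, 5}: erased flags are [1, 0, 1, 1],
    # so the inclusive prefix sum (with a leading 0) is:
    num_erased = [0, 1, 1, 2, 3]
    print(calc_output_lod(num_erased, [0, 1, 4]))  # offset LoD [0, 1, 4] -> [0, 0, 1]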
@@ -28,19 +28,18 @@ class SequenceEraseKernel : public framework::OpKernel<T> {
     auto* out = ctx.Output<framework::LoDTensor>("Out");
     auto lod = in->lod();
-    PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
-    PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
+    PADDLE_ENFORCE_EQ(lod[lod.size() - 1].back(), (size_t)in->numel(),
                       "The actual size mismatches with the LoD information.");
     auto tokens = ctx.Attr<std::vector<int>>("tokens");
     auto in_len = in->numel();
     auto in_dat = in->data<T>();
-    auto lod0 = lod[0];
+    auto last_lod = lod[lod.size() - 1];
     std::vector<size_t> num_erased(in_len + 1, 0);
-    std::vector<size_t> out_lod0(1, 0);
-    for (size_t i = 0; i < lod0.size() - 1; ++i) {
+    std::vector<size_t> out_last_lod(1, 0);
+    for (size_t i = 0; i < last_lod.size() - 1; ++i) {
       size_t num_out = 0;
-      for (auto j = lod0[i] + 1; j <= lod0[i + 1]; ++j) {
+      for (auto j = last_lod[i] + 1; j <= last_lod[i + 1]; ++j) {
         num_erased[j] = num_erased[j - 1];
         if (std::find(tokens.begin(), tokens.end(), in_dat[j - 1]) !=
             tokens.end()) {
@@ -49,7 +48,7 @@ class SequenceEraseKernel : public framework::OpKernel<T> {
           num_out += 1;
         }
       }
-      out_lod0.push_back(out_lod0.back() + num_out);
+      out_last_lod.push_back(out_last_lod.back() + num_out);
     }
     auto out_len = in_len - num_erased[in_len];
@@ -62,7 +61,10 @@ class SequenceEraseKernel : public framework::OpKernel<T> {
       }
     }
     framework::LoD out_lod;
-    out_lod.push_back(out_lod0);
+    for (size_t i = 0; i < lod.size() - 1; ++i) {
+      out_lod.push_back(lod[i]);
+    }
+    out_lod.push_back(out_last_lod);
     out->set_lod(out_lod);
   }
 };
......
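Why the new loop copies lod[0] .. lod[lod.size() - 2] verbatim: in an offset-based LoD, each upper level indexes entries of the level below it, not the underlying rows. Erasure changes sub-sequence lengths but never the number of sub-sequences, so only the finest level's offsets must shrink. A self-contained illustrative equivalent of the C++ assembly loop, with made-up values:

    lod = [[0, 2], [0, 1, 4]]     # offset-based input LoD (illustrative)
    out_last_lod = [0, 0, 1]      # finest level after erasure
    out_lod = list(lod[:-1])      # upper levels pass through untouched
    out_lod.append(out_last_lod)  # only the finest level is recomputed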
@@ -49,6 +49,21 @@ class TestSequenceEraseOpInt32(OpTest):
         self.check_output()
 
 
+class TestSequenceEraseOpInt32LoD2(OpTest):
+    def setUp(self):
+        self.op_type = "sequence_erase"
+        in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
+        lod = [[1, 3], [9, 4, 11, 6]]
+        tokens = [2, 3, 5]
+        out_seq, new_lod0 = sequence_erase(in_seq, lod[-1], tokens)
+        self.attrs = {'tokens': tokens}
+        self.inputs = {'X': (in_seq, lod)}
+        self.outputs = {'Out': (out_seq, lod[:-1] + [new_lod0])}
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestSequenceEraseOpInt64(OpTest):
     def setUp(self):
         self.op_type = "sequence_erase"
......
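A quick sanity check on the new test's fixture: lod = [[1, 3], [9, 4, 11, 6]] is length-based on the Python side, so the top level must group all finest-level sub-sequences and the finest level must cover every input row (the kernels enforce the offset-based analogue of the second condition via the PADDLE_ENFORCE_EQ against in->numel() above).

    lod = [[1, 3], [9, 4, 11, 6]]
    assert sum(lod[0]) == len(lod[1])  # 1 + 3 groups all four sub-sequences
    assert sum(lod[1]) == 30           # 9 + 4 + 11 + 6 covers all 30 rows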