提交 bd966900 编写于 作者: W whs 提交者: ceci3

Make sequence_erase op support for input with multi-level LoD. (#15982)

test=develop
上级 1301dc1a
......@@ -64,8 +64,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<LoDTensor>("Out");
auto lod = in->lod();
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
PADDLE_ENFORCE_EQ(lod[lod.size() - 1].back(), (size_t)in->numel(),
"The actual size mismatches with the LoD information.");
auto tokens = ctx.Attr<std::vector<int>>("tokens");
auto in_len = in->numel();
......@@ -85,10 +84,9 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
num_erased.begin() + 1);
// Copy LoD to GPU
auto lod0 = lod[0];
auto lod_len = lod0.size();
const size_t* dev_in_lod_ptr = lod0.CUDAData(ctx.GetPlace());
auto last_lod = lod[lod.size() - 1];
auto lod_len = last_lod.size();
const size_t* dev_in_lod_ptr = last_lod.CUDAData(ctx.GetPlace());
// Calc output LoD
thrust::device_vector<size_t> dev_out_lod(lod_len);
size_t* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
......@@ -96,13 +94,16 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
// Set LoD for output
std::vector<size_t> out_lod0(dev_out_lod.begin(), dev_out_lod.end());
std::vector<size_t> out_last_lod(dev_out_lod.begin(), dev_out_lod.end());
framework::LoD out_lod;
out_lod.push_back(out_lod0);
for (size_t i = 0; i < lod.size() - 1; ++i) {
out_lod.push_back(lod[i]);
}
out_lod.push_back(out_last_lod);
out->set_lod(out_lod);
// Set output
out->Resize({static_cast<int64_t>(out_lod0.back()), 1});
out->Resize({static_cast<int64_t>(out_last_lod.back()), 1});
auto out_dat = out->mutable_data<T>(ctx.GetPlace());
SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len,
......
......@@ -28,19 +28,18 @@ class SequenceEraseKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<framework::LoDTensor>("Out");
auto lod = in->lod();
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
PADDLE_ENFORCE_EQ(lod[lod.size() - 1].back(), (size_t)in->numel(),
"The actual size mismatches with the LoD information.");
auto tokens = ctx.Attr<std::vector<int>>("tokens");
auto in_len = in->numel();
auto in_dat = in->data<T>();
auto lod0 = lod[0];
auto last_lod = lod[lod.size() - 1];
std::vector<size_t> num_erased(in_len + 1, 0);
std::vector<size_t> out_lod0(1, 0);
for (size_t i = 0; i < lod0.size() - 1; ++i) {
std::vector<size_t> out_last_lod(1, 0);
for (size_t i = 0; i < last_lod.size() - 1; ++i) {
size_t num_out = 0;
for (auto j = lod0[i] + 1; j <= lod0[i + 1]; ++j) {
for (auto j = last_lod[i] + 1; j <= last_lod[i + 1]; ++j) {
num_erased[j] = num_erased[j - 1];
if (std::find(tokens.begin(), tokens.end(), in_dat[j - 1]) !=
tokens.end()) {
......@@ -49,7 +48,7 @@ class SequenceEraseKernel : public framework::OpKernel<T> {
num_out += 1;
}
}
out_lod0.push_back(out_lod0.back() + num_out);
out_last_lod.push_back(out_last_lod.back() + num_out);
}
auto out_len = in_len - num_erased[in_len];
......@@ -62,7 +61,10 @@ class SequenceEraseKernel : public framework::OpKernel<T> {
}
}
framework::LoD out_lod;
out_lod.push_back(out_lod0);
for (size_t i = 0; i < lod.size() - 1; ++i) {
out_lod.push_back(lod[i]);
}
out_lod.push_back(out_last_lod);
out->set_lod(out_lod);
}
};
......
......@@ -49,6 +49,21 @@ class TestSequenceEraseOpInt32(OpTest):
self.check_output()
class TestSequenceEraseOpInt32LoD2(OpTest):
def setUp(self):
self.op_type = "sequence_erase"
in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
lod = [[1, 3], [9, 4, 11, 6]]
tokens = [2, 3, 5]
out_seq, new_lod0 = sequence_erase(in_seq, lod[-1], tokens)
self.attrs = {'tokens': tokens}
self.inputs = {'X': (in_seq, lod)}
self.outputs = {'Out': (out_seq, lod[:-1] + [new_lod0])}
def test_check_output(self):
self.check_output()
class TestSequenceEraseOpInt64(OpTest):
def setUp(self):
self.op_type = "sequence_erase"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册