From 10779460c5c127c257203d95d5f4740db4d55cad Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 10 Jan 2018 08:03:08 +0000 Subject: [PATCH] Simplify calc in test_sequence_erase_op --- paddle/operators/sequence_erase_op.cc | 2 +- paddle/operators/sequence_erase_op.cu | 2 + paddle/operators/sequence_erase_op.h | 25 ++++++---- .../v2/fluid/tests/test_sequence_erase_op.py | 46 +++++-------------- 4 files changed, 30 insertions(+), 45 deletions(-) diff --git a/paddle/operators/sequence_erase_op.cc b/paddle/operators/sequence_erase_op.cc index 331970b3f..d17b26862 100644 --- a/paddle/operators/sequence_erase_op.cc +++ b/paddle/operators/sequence_erase_op.cc @@ -50,7 +50,7 @@ class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Sequence Erase Operator. -Sequence erase operator erases tokens specified by Attr(tokens) in the input +Sequence erase operator erases tokens specified by Attr(tokens) from the input sequences Input(X), and outputs the remaining data and modifies the LoD information at the same time. For example, given a 2-D LoDTensor diff --git a/paddle/operators/sequence_erase_op.cu b/paddle/operators/sequence_erase_op.cu index 3695a24cb..5da8eba3e 100644 --- a/paddle/operators/sequence_erase_op.cu +++ b/paddle/operators/sequence_erase_op.cu @@ -70,6 +70,8 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel { auto lod = in->lod(); PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(), + "The actual size mismatches with the LoD information."); auto tokens = ctx.Attr>("tokens"); auto tokens_len = tokens.size(); auto in_len = in->numel(); diff --git a/paddle/operators/sequence_erase_op.h b/paddle/operators/sequence_erase_op.h index 92aa4a82b..cb2d7be00 100644 --- a/paddle/operators/sequence_erase_op.h +++ b/paddle/operators/sequence_erase_op.h @@ -28,22 +28,27 @@ class SequenceEraseKernel : public framework::OpKernel { auto lod = in->lod(); PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(), + "The actual size mismatches with the LoD information."); auto tokens = ctx.Attr>("tokens"); auto in_len = in->numel(); auto in_dat = in->data(); auto lod0 = lod[0]; + std::vector num_erased(in_len + 1, 0); - for (int64_t i = 1; i < in_len + 1; ++i) { - num_erased[i] = num_erased[i - 1]; - if (std::find(tokens.begin(), tokens.end(), in_dat[i - 1]) != - tokens.end()) { - num_erased[i] += 1; + std::vector out_lod0(1, 0); + for (size_t i = 0; i < lod0.size() - 1; ++i) { + size_t num_out = 0; + for (auto j = lod0[i] + 1; j <= lod0[i + 1]; ++j) { + num_erased[j] = num_erased[j - 1]; + if (std::find(tokens.begin(), tokens.end(), in_dat[j - 1]) != + tokens.end()) { + num_erased[j] += 1; + } else { + num_out += 1; + } } - } - - std::vector out_lod0(lod0.size(), 0); - for (size_t i = 1; i < lod0.size(); ++i) { - out_lod0[i] = lod0[i] - num_erased[lod0[i]]; + out_lod0.push_back(out_lod0.back() + num_out); } auto out_len = in_len - num_erased[in_len]; diff --git a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py index 78105334f..bf257fefe 100644 --- a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py +++ b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py @@ -4,32 +4,23 @@ from op_test import OpTest def sequence_erase(in_seq, lod0, tokens): - # num_erased[i]: the number of elments to be removed before #i elements - num_erased = [0] * (len(in_seq) + 1) - for i in range(1, len(in_seq) + 1): - num_erased[i] = num_erased[i - 1] - if in_seq[i - 1] in tokens: - num_erased[i] += 1 - - # recalculate lod information - new_lod0 = [0] * len(lod0) - for i in range(1, len(lod0)): - new_lod0[i] = lod0[i] - num_erased[lod0[i]] - - out_seq = np.zeros( - (len(in_seq) - num_erased[len(in_seq)], 1)).astype("int32") - for i in range(0, len(in_seq)): - if num_erased[i] == num_erased[i + 1]: - out_seq[i - num_erased[i]] = in_seq[i] - # else in_seq[i] needs to be removed - return out_seq, new_lod0 + new_lod0 = [0] + out_seq = [] + for i in range(0, len(lod0) - 1): + num_out = 0 + for dat in in_seq[lod0[i]:lod0[i + 1]]: + if dat not in tokens: + out_seq.append(dat) + num_out += 1 + new_lod0.append(new_lod0[-1] + num_out) + return np.array(out_seq).astype("int32"), new_lod0 class TestSequenceEraseOp(OpTest): def setUp(self): self.op_type = "sequence_erase" - in_seq = np.random.randint(0, 10, (10, 1)).astype("int32") - lod = [[0, 3, 6, 10]] + in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + lod = [[0, 9, 13, 24, 30]] tokens = [2, 3, 5] out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens) self.attrs = {'tokens': tokens} @@ -41,17 +32,4 @@ class TestSequenceEraseOp(OpTest): if __name__ == '__main__': - """ - in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") - lod0 = [0, 5, 15, 30] - tokens = [2, 5] - out_seq, new_lod = sequence_erase(in_seq, lod0, tokens) - - print lod0, new_lod - print("compare") - for i in range(0, len(lod0)-1): - print(np.transpose(in_seq[lod0[i] : lod0[i+1]])) - print(np.transpose(out_seq[new_lod[i] : new_lod[i+1]])) - print("\n") - """ unittest.main() -- GitLab