提交 10779460，作者: Yibing Liu

Simplify calc in test_sequence_erase_op

上级 7b9d5b32
...@@ -50,7 +50,7 @@ class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -50,7 +50,7 @@ class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
Sequence Erase Operator. Sequence Erase Operator.
Sequence erase operator erases tokens specified by Attr(tokens) in the input Sequence erase operator erases tokens specified by Attr(tokens) from the input
sequences Input(X), and outputs the remaining data and modifies the LoD sequences Input(X), and outputs the remaining data and modifies the LoD
information at the same time. For example, given a 2-D LoDTensor information at the same time. For example, given a 2-D LoDTensor
......
...@@ -70,6 +70,8 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> { ...@@ -70,6 +70,8 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
auto lod = in->lod(); auto lod = in->lod();
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
"The actual size mismatches with the LoD information.");
auto tokens = ctx.Attr<std::vector<T>>("tokens"); auto tokens = ctx.Attr<std::vector<T>>("tokens");
auto tokens_len = tokens.size(); auto tokens_len = tokens.size();
auto in_len = in->numel(); auto in_len = in->numel();
......
...@@ -28,22 +28,27 @@ class SequenceEraseKernel : public framework::OpKernel<T> { ...@@ -28,22 +28,27 @@ class SequenceEraseKernel : public framework::OpKernel<T> {
auto lod = in->lod(); auto lod = in->lod();
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
"The actual size mismatches with the LoD information.");
auto tokens = ctx.Attr<std::vector<int>>("tokens"); auto tokens = ctx.Attr<std::vector<int>>("tokens");
auto in_len = in->numel(); auto in_len = in->numel();
auto in_dat = in->data<T>(); auto in_dat = in->data<T>();
auto lod0 = lod[0]; auto lod0 = lod[0];
std::vector<size_t> num_erased(in_len + 1, 0); std::vector<size_t> num_erased(in_len + 1, 0);
for (int64_t i = 1; i < in_len + 1; ++i) { std::vector<size_t> out_lod0(1, 0);
num_erased[i] = num_erased[i - 1]; for (size_t i = 0; i < lod0.size() - 1; ++i) {
if (std::find(tokens.begin(), tokens.end(), in_dat[i - 1]) != size_t num_out = 0;
tokens.end()) { for (auto j = lod0[i] + 1; j <= lod0[i + 1]; ++j) {
num_erased[i] += 1; num_erased[j] = num_erased[j - 1];
if (std::find(tokens.begin(), tokens.end(), in_dat[j - 1]) !=
tokens.end()) {
num_erased[j] += 1;
} else {
num_out += 1;
}
} }
} out_lod0.push_back(out_lod0.back() + num_out);
std::vector<size_t> out_lod0(lod0.size(), 0);
for (size_t i = 1; i < lod0.size(); ++i) {
out_lod0[i] = lod0[i] - num_erased[lod0[i]];
} }
auto out_len = in_len - num_erased[in_len]; auto out_len = in_len - num_erased[in_len];
......
...@@ -4,32 +4,23 @@ from op_test import OpTest ...@@ -4,32 +4,23 @@ from op_test import OpTest
def sequence_erase(in_seq, lod0, tokens): def sequence_erase(in_seq, lod0, tokens):
# num_erased[i]: the number of elments to be removed before #i elements new_lod0 = [0]
num_erased = [0] * (len(in_seq) + 1) out_seq = []
for i in range(1, len(in_seq) + 1): for i in range(0, len(lod0) - 1):
num_erased[i] = num_erased[i - 1] num_out = 0
if in_seq[i - 1] in tokens: for dat in in_seq[lod0[i]:lod0[i + 1]]:
num_erased[i] += 1 if dat not in tokens:
out_seq.append(dat)
# recalculate lod information num_out += 1
new_lod0 = [0] * len(lod0) new_lod0.append(new_lod0[-1] + num_out)
for i in range(1, len(lod0)): return np.array(out_seq).astype("int32"), new_lod0
new_lod0[i] = lod0[i] - num_erased[lod0[i]]
out_seq = np.zeros(
(len(in_seq) - num_erased[len(in_seq)], 1)).astype("int32")
for i in range(0, len(in_seq)):
if num_erased[i] == num_erased[i + 1]:
out_seq[i - num_erased[i]] = in_seq[i]
# else in_seq[i] needs to be removed
return out_seq, new_lod0
class TestSequenceEraseOp(OpTest): class TestSequenceEraseOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "sequence_erase" self.op_type = "sequence_erase"
in_seq = np.random.randint(0, 10, (10, 1)).astype("int32") in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
lod = [[0, 3, 6, 10]] lod = [[0, 9, 13, 24, 30]]
tokens = [2, 3, 5] tokens = [2, 3, 5]
out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens) out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens)
self.attrs = {'tokens': tokens} self.attrs = {'tokens': tokens}
...@@ -41,17 +32,4 @@ class TestSequenceEraseOp(OpTest): ...@@ -41,17 +32,4 @@ class TestSequenceEraseOp(OpTest):
if __name__ == '__main__': if __name__ == '__main__':
"""
in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
lod0 = [0, 5, 15, 30]
tokens = [2, 5]
out_seq, new_lod = sequence_erase(in_seq, lod0, tokens)
print lod0, new_lod
print("compare")
for i in range(0, len(lod0)-1):
print(np.transpose(in_seq[lod0[i] : lod0[i+1]]))
print(np.transpose(out_seq[new_lod[i] : new_lod[i+1]]))
print("\n")
"""
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册