From e85c51330700a4125f8574e8c0927407c6d9e3d0 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Sun, 7 Jan 2018 13:14:54 +0000 Subject: [PATCH] Add sequencee erase operator --- paddle/operators/sequence_erase_op.cc | 61 ++++++++++++++ paddle/operators/sequence_erase_op.h | 80 +++++++++++++++++++ .../v2/fluid/tests/test_sequence_erase_op.py | 58 ++++++++++++++ 3 files changed, 199 insertions(+) create mode 100644 paddle/operators/sequence_erase_op.cc create mode 100644 paddle/operators/sequence_erase_op.h create mode 100644 python/paddle/v2/fluid/tests/test_sequence_erase_op.py diff --git a/paddle/operators/sequence_erase_op.cc b/paddle/operators/sequence_erase_op.cc new file mode 100644 index 00000000000..e611ef0571d --- /dev/null +++ b/paddle/operators/sequence_erase_op.cc @@ -0,0 +1,61 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/sequence_erase_op.h" + +namespace paddle { +namespace operators { + +class SequenceEraseOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceEraseOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceEraseOp should not be null."); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } +}; + +class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SequenceEraseOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(LoDTensor) 2-D input LoDTensor with the 2-nd dimension " + "of length 1."); + AddOutput("Out", + "(LoDTensor) 2-D output LoDTensor with the 2-nd dimension " + "of length 1."); + AddAttr>("tokens", + "(vector) " + "Tokens to be removed from input."); + AddComment(R"DOC( +Sequence Erase Operator. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp, + ops::SequenceEraseOpMaker); +REGISTER_OP_CPU_KERNEL( + sequence_erase, + ops::SequenceEraseKernel); diff --git a/paddle/operators/sequence_erase_op.h b/paddle/operators/sequence_erase_op.h new file mode 100644 index 00000000000..937b9870aa9 --- /dev/null +++ b/paddle/operators/sequence_erase_op.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/softmax.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class SequenceEraseKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + auto lod = in->lod(); + PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + // auto dims = x->dims(); + /* + const size_t level = lod.size() - 1; + PADDLE_ENFORCE_EQ(dims[0], static_cast(lod[level].back()), + "The first dimension of Input(X) should be equal to the " + "sum of all sequences' lengths."); + PADDLE_ENFORCE_EQ(dims[0], x->numel(), + "The width of each timestep in Input(X) of " + "SequenceEraseOp should be 1."); + out->mutable_data(ctx.GetPlace()); + */ + auto tokens = ctx.Attr>("tokens"); + auto in_len = in->numel(); + auto in_dat = in->data(); + auto lod0 = lod[0]; + std::vector num_erased(in_len + 1, 0); + for (int64_t i = 1; i < in_len + 1; ++i) { + num_erased[i] = num_erased[i - 1]; + if (std::find(tokens.begin(), tokens.end(), in_dat[i - 1]) != + tokens.end()) { + num_erased[i] += 1; + } + } + + std::vector out_lod0(lod0.size(), 0); + for (size_t i = 1; i < lod0.size(); ++i) { + out_lod0[i] = lod0[i] - num_erased[lod0[i]]; + } + + auto out_len = in_len - num_erased[in_len]; + out->Resize({static_cast(out_len), 1}); + auto out_dat = out->mutable_data(ctx.GetPlace()); + + for (size_t i = 0; i < in_len; ++i) { + if (num_erased[i] == num_erased[i + 1]) { + out_dat[i - num_erased[i]] = in_dat[i]; + } + } + framework::LoD out_lod; + out_lod.push_back(out_lod0); + out->set_lod(out_lod); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py new file mode 100644 index 00000000000..74274cf0ad4 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py @@ -0,0 +1,58 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def sequence_erase(in_seq, lod0, tokens): + # num_erased[i]: the number of elments to be removed before #i elements + num_erased = [0] * (len(in_seq) + 1) + for i in range(1, len(in_seq) + 1): + num_erased[i] = num_erased[i - 1] + if in_seq[i - 1] in tokens: + num_erased[i] += 1 + + # recalculate lod information + new_lod0 = [0] * len(lod0) + for i in range(1, len(lod0)): + new_lod0[i] = lod0[i] - num_erased[lod0[i]] + + out_seq = np.zeros( + (len(in_seq) - num_erased[len(in_seq)], 1)).astype("int32") + for i in range(0, len(in_seq)): + if num_erased[i] == num_erased[i + 1]: + out_seq[i - num_erased[i]] = in_seq[i] + # else in_seq[i] needs to be removed + return out_seq, new_lod0 + + +class TestSequenceEraseOp(OpTest): + def setUp(self): + self.op_type = "sequence_erase" + in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + lod = [[0, 5, 15, 30]] + tokens = [2, 5] + out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens) + + self.attrs = {'tokens': tokens} + self.inputs = {'X': (in_seq, lod)} + self.outputs = {'Out': (out_seq, [new_lod0])} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + """ + in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + lod0 = [0, 5, 15, 30] + tokens = [2, 5] + out_seq, new_lod = sequence_erase(in_seq, lod0, tokens) + + print lod0, new_lod + print("compare") + for i in range(0, len(lod0)-1): + print(np.transpose(in_seq[lod0[i] : lod0[i+1]])) + print(np.transpose(out_seq[new_lod[i] : new_lod[i+1]])) + print("\n") + """ + unittest.main() -- GitLab