diff --git a/paddle/operators/sequence_erase_op.cc b/paddle/operators/sequence_erase_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d17b2686238b2d2f872331edfdbb095fb8693b87 --- /dev/null +++ b/paddle/operators/sequence_erase_op.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/sequence_erase_op.h" + +namespace paddle { +namespace operators { + +class SequenceEraseOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceEraseOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceEraseOp should not be null."); + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE(x_dims.size() == 2 && x_dims[1] == 1, + "Input(X) of SequenceEraseOp should be a 2-D LoDTensor " + "with the 2nd dimension equal to 1."); + ctx->SetOutputDim("Out", x_dims); + } +}; + +class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SequenceEraseOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(2-D LoDTensor with the 2nd dim. equal to 1) " + "Input LoDTensor of SequenceEraseOp."); + AddOutput("Out", + "(2-D LoDTensor with the 2nd dim. equal to 1) " + "Output LoDTensor of SequenceEraseOp."); + AddAttr>("tokens", + "(vector) Tokens need to be erased from " + "input sequences."); + AddComment(R"DOC( +Sequence Erase Operator. + +Sequence erase operator erases tokens specified by Attr(tokens) from the input +sequences Input(X), and outputs the remaining data and modifies the LoD +information at the same time. For example, given a 2-D LoDTensor + + X = [[2, 2, 6, 1, 3, 9, 6, 1, 0, 1]]^T + +with lod = [[0, 3, 6, 10]], there are three sequences in the input: + + X1 = [[2, 2, 6]]^T, X2 = [[1, 3, 9]]^T and X3 = [[6, 1, 0, 1]]^T. + +If the tokens to be erased are Attr(tokens) = [2, 3, 5], after the erasing +operation, the three sequences become + + X1' = [[6]]^T, X2' = [[1, 9]]^T and X3' = [[6, 1, 0, 1]]^T. + +Hence the LoDTensor Output(Out) should be + + Out = [[6, 1, 9, 6, 1, 0, 1]]^T, + +with lod = [[0, 1, 3, 7]]. + +An example usage for this operator is to remove the special tokens when +computing the edit distance between two strings, such as blank, start token, +and end token. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp, + ops::SequenceEraseOpMaker); +REGISTER_OP_CPU_KERNEL( + sequence_erase, + ops::SequenceEraseKernel); diff --git a/paddle/operators/sequence_erase_op.cu b/paddle/operators/sequence_erase_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..5da8eba3e1ac1fb85dfc65c2fd801574599e02d9 --- /dev/null +++ b/paddle/operators/sequence_erase_op.cu @@ -0,0 +1,133 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/operators/sequence_erase_op.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +using platform::PADDLE_CUDA_NUM_THREADS; +using LoDTensor = framework::LoDTensor; + +template +__global__ void LabelErasedIdx(const T* in_dat, const int in_len, + const T* tokens, const int tokens_len, + int* num_erased) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < in_len) { + int erased = 0; + for (int i = 0; i < tokens_len; ++i) { + if (in_dat[index] == tokens[i]) { + erased = 1; + } + } + num_erased[index + 1] = erased; + if (index == 0) { + num_erased[0] = 0; + } + } +} + +template +__global__ void GetOutLod(const T* num_erased, const int* in_lod, + const int lod_len, int* out_lod0) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < lod_len) { + out_lod0[index] = in_lod[index] - num_erased[in_lod[index]]; + } +} + +template +__global__ void SetOutput(const T* in_dat, const int in_len, + const int* num_erased, T* out_dat) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < in_len) { + if (in_dat[index] != in_dat[index + 1]) { + out_dat[index - num_erased[index]] = in_dat[index]; + } + } +} + +template +class SequenceEraseOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + auto lod = in->lod(); + PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(), + "The actual size mismatches with the LoD information."); + auto tokens = ctx.Attr>("tokens"); + auto tokens_len = tokens.size(); + auto in_len = in->numel(); + auto in_dat = in->data(); + auto lod0 = lod[0]; + + thrust::host_vector host_tokens(tokens_len); + for (size_t i = 0; i < tokens.size(); ++i) { + host_tokens[i] = tokens[i]; + } + thrust::device_vector dev_tokens = host_tokens; + thrust::device_vector num_erased(in_len + 1); + + T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data()); + int* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data()); + + auto stream = ctx.cuda_device_context().stream(); + LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + in_dat, in_len, dev_tokens_ptr, tokens_len, num_erased_ptr); + thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(), + num_erased.begin() + 1); + + // Calc LoD + auto lod_len = lod0.size(); + thrust::host_vector host_lod(lod_len); + for (size_t i = 0; i < lod_len; ++i) { + host_lod[i] = lod0[i]; + } + thrust::device_vector dev_in_lod = host_lod; + thrust::device_vector dev_out_lod(lod_len); + int* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data()); + int* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data()); + GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr); + thrust::host_vector host_out_lod = dev_out_lod; + std::vector out_lod0(lod_len, 0); + for (size_t i = 0; i < lod_len; i++) { + out_lod0[i] = host_out_lod[i]; + } + framework::LoD out_lod; + out_lod.push_back(out_lod0); + out->set_lod(out_lod); + + // Set output + out->Resize({out_lod0.back(), 1}); + auto out_dat = out->mutable_data(ctx.GetPlace()); + SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len, + num_erased_ptr, out_dat); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL(sequence_erase, + paddle::operators::SequenceEraseOpCUDAKernel); diff --git a/paddle/operators/sequence_erase_op.h b/paddle/operators/sequence_erase_op.h new file mode 100644 index 0000000000000000000000000000000000000000..cb2d7be009dcbe0138818457249e95fbdd27fc0a --- /dev/null +++ b/paddle/operators/sequence_erase_op.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class SequenceEraseKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + auto lod = in->lod(); + PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(), + "The actual size mismatches with the LoD information."); + auto tokens = ctx.Attr>("tokens"); + auto in_len = in->numel(); + auto in_dat = in->data(); + auto lod0 = lod[0]; + + std::vector num_erased(in_len + 1, 0); + std::vector out_lod0(1, 0); + for (size_t i = 0; i < lod0.size() - 1; ++i) { + size_t num_out = 0; + for (auto j = lod0[i] + 1; j <= lod0[i + 1]; ++j) { + num_erased[j] = num_erased[j - 1]; + if (std::find(tokens.begin(), tokens.end(), in_dat[j - 1]) != + tokens.end()) { + num_erased[j] += 1; + } else { + num_out += 1; + } + } + out_lod0.push_back(out_lod0.back() + num_out); + } + + auto out_len = in_len - num_erased[in_len]; + out->Resize({static_cast(out_len), 1}); + auto out_dat = out->mutable_data(ctx.GetPlace()); + + for (int64_t i = 0; i < in_len; ++i) { + if (num_erased[i] == num_erased[i + 1]) { + out_dat[i - num_erased[i]] = in_dat[i]; + } + } + framework::LoD out_lod; + out_lod.push_back(out_lod0); + out->set_lod(out_lod); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py new file mode 100644 index 0000000000000000000000000000000000000000..bf257fefea0d98c6f4d9860dbac4ccedf59bcdd9 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py @@ -0,0 +1,35 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def sequence_erase(in_seq, lod0, tokens): + new_lod0 = [0] + out_seq = [] + for i in range(0, len(lod0) - 1): + num_out = 0 + for dat in in_seq[lod0[i]:lod0[i + 1]]: + if dat not in tokens: + out_seq.append(dat) + num_out += 1 + new_lod0.append(new_lod0[-1] + num_out) + return np.array(out_seq).astype("int32"), new_lod0 + + +class TestSequenceEraseOp(OpTest): + def setUp(self): + self.op_type = "sequence_erase" + in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + lod = [[0, 9, 13, 24, 30]] + tokens = [2, 3, 5] + out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens) + self.attrs = {'tokens': tokens} + self.inputs = {'X': (in_seq, lod)} + self.outputs = {'Out': (out_seq, [new_lod0])} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main()