未验证 提交 59bf85d9 编写于 作者: W whs 提交者: GitHub

Merge pull request #7325 from kuke/sequence_erase_op

Add sequence erase op
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_erase_op.h"
namespace paddle {
namespace operators {
class SequenceEraseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceEraseOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceEraseOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(x_dims.size() == 2 && x_dims[1] == 1,
"Input(X) of SequenceEraseOp should be a 2-D LoDTensor "
"with the 2nd dimension equal to 1.");
ctx->SetOutputDim("Out", x_dims);
}
};
class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceEraseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(2-D LoDTensor with the 2nd dim. equal to 1) "
"Input LoDTensor of SequenceEraseOp.");
AddOutput("Out",
"(2-D LoDTensor with the 2nd dim. equal to 1) "
"Output LoDTensor of SequenceEraseOp.");
AddAttr<std::vector<int>>("tokens",
"(vector<int>) Tokens need to be erased from "
"input sequences.");
AddComment(R"DOC(
Sequence Erase Operator.
Sequence erase operator erases tokens specified by Attr(tokens) from the input
sequences Input(X), and outputs the remaining data and modifies the LoD
information at the same time. For example, given a 2-D LoDTensor
X = [[2, 2, 6, 1, 3, 9, 6, 1, 0, 1]]^T
with lod = [[0, 3, 6, 10]], there are three sequences in the input:
X1 = [[2, 2, 6]]^T, X2 = [[1, 3, 9]]^T and X3 = [[6, 1, 0, 1]]^T.
If the tokens to be erased are Attr(tokens) = [2, 3, 5], after the erasing
operation, the three sequences become
X1' = [[6]]^T, X2' = [[1, 9]]^T and X3' = [[6, 1, 0, 1]]^T.
Hence the LoDTensor Output(Out) should be
Out = [[6, 1, 9, 6, 1, 0, 1]]^T,
with lod = [[0, 1, 3, 7]].
An example usage for this operator is to remove the special tokens when
computing the edit distance between two strings, such as blank, start token,
and end token.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp,
ops::SequenceEraseOpMaker);
REGISTER_OP_CPU_KERNEL(
sequence_erase,
ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int32_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/operators/sequence_erase_op.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using LoDTensor = framework::LoDTensor;
template <typename T>
__global__ void LabelErasedIdx(const T* in_dat, const int in_len,
const T* tokens, const int tokens_len,
int* num_erased) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < in_len) {
int erased = 0;
for (int i = 0; i < tokens_len; ++i) {
if (in_dat[index] == tokens[i]) {
erased = 1;
}
}
num_erased[index + 1] = erased;
if (index == 0) {
num_erased[0] = 0;
}
}
}
template <typename T>
__global__ void GetOutLod(const T* num_erased, const int* in_lod,
const int lod_len, int* out_lod0) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < lod_len) {
out_lod0[index] = in_lod[index] - num_erased[in_lod[index]];
}
}
template <typename T>
__global__ void SetOutput(const T* in_dat, const int in_len,
const int* num_erased, T* out_dat) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < in_len) {
if (in_dat[index] != in_dat[index + 1]) {
out_dat[index - num_erased[index]] = in_dat[index];
}
}
}
template <typename T>
class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
auto lod = in->lod();
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
"The actual size mismatches with the LoD information.");
auto tokens = ctx.Attr<std::vector<T>>("tokens");
auto tokens_len = tokens.size();
auto in_len = in->numel();
auto in_dat = in->data<T>();
auto lod0 = lod[0];
thrust::host_vector<T> host_tokens(tokens_len);
for (size_t i = 0; i < tokens.size(); ++i) {
host_tokens[i] = tokens[i];
}
thrust::device_vector<T> dev_tokens = host_tokens;
thrust::device_vector<int> num_erased(in_len + 1);
T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
int* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
auto stream = ctx.cuda_device_context().stream();
LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
in_dat, in_len, dev_tokens_ptr, tokens_len, num_erased_ptr);
thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(),
num_erased.begin() + 1);
// Calc LoD
auto lod_len = lod0.size();
thrust::host_vector<int> host_lod(lod_len);
for (size_t i = 0; i < lod_len; ++i) {
host_lod[i] = lod0[i];
}
thrust::device_vector<int> dev_in_lod = host_lod;
thrust::device_vector<int> dev_out_lod(lod_len);
int* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
int* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
thrust::host_vector<int> host_out_lod = dev_out_lod;
std::vector<int> out_lod0(lod_len, 0);
for (size_t i = 0; i < lod_len; i++) {
out_lod0[i] = host_out_lod[i];
}
framework::LoD out_lod;
out_lod.push_back(out_lod0);
out->set_lod(out_lod);
// Set output
out->Resize({out_lod0.back(), 1});
auto out_dat = out->mutable_data<T>(ctx.GetPlace());
SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len,
num_erased_ptr, out_dat);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(sequence_erase,
paddle::operators::SequenceEraseOpCUDAKernel<int32_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SequenceEraseKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::LoDTensor>("X");
auto* out = ctx.Output<framework::LoDTensor>("Out");
auto lod = in->lod();
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
"The actual size mismatches with the LoD information.");
auto tokens = ctx.Attr<std::vector<int>>("tokens");
auto in_len = in->numel();
auto in_dat = in->data<T>();
auto lod0 = lod[0];
std::vector<size_t> num_erased(in_len + 1, 0);
std::vector<size_t> out_lod0(1, 0);
for (size_t i = 0; i < lod0.size() - 1; ++i) {
size_t num_out = 0;
for (auto j = lod0[i] + 1; j <= lod0[i + 1]; ++j) {
num_erased[j] = num_erased[j - 1];
if (std::find(tokens.begin(), tokens.end(), in_dat[j - 1]) !=
tokens.end()) {
num_erased[j] += 1;
} else {
num_out += 1;
}
}
out_lod0.push_back(out_lod0.back() + num_out);
}
auto out_len = in_len - num_erased[in_len];
out->Resize({static_cast<int64_t>(out_len), 1});
auto out_dat = out->mutable_data<T>(ctx.GetPlace());
for (int64_t i = 0; i < in_len; ++i) {
if (num_erased[i] == num_erased[i + 1]) {
out_dat[i - num_erased[i]] = in_dat[i];
}
}
framework::LoD out_lod;
out_lod.push_back(out_lod0);
out->set_lod(out_lod);
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
def sequence_erase(in_seq, lod0, tokens):
new_lod0 = [0]
out_seq = []
for i in range(0, len(lod0) - 1):
num_out = 0
for dat in in_seq[lod0[i]:lod0[i + 1]]:
if dat not in tokens:
out_seq.append(dat)
num_out += 1
new_lod0.append(new_lod0[-1] + num_out)
return np.array(out_seq).astype("int32"), new_lod0
class TestSequenceEraseOp(OpTest):
def setUp(self):
self.op_type = "sequence_erase"
in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
lod = [[0, 9, 13, 24, 30]]
tokens = [2, 3, 5]
out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens)
self.attrs = {'tokens': tokens}
self.inputs = {'X': (in_seq, lod)}
self.outputs = {'Out': (out_seq, [new_lod0])}
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册