Commit 7b9d5b32 authored by Yibing Liu

Add document for sequence_erase_op

Parent 37f933b8
@@ -26,7 +26,11 @@ class SequenceEraseOp : public framework::OperatorWithKernel {
                    "Input(X) of SequenceEraseOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of SequenceEraseOp should not be null.");
-    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    auto x_dims = ctx->GetInputDim("X");
+    PADDLE_ENFORCE(x_dims.size() == 2 && x_dims[1] == 1,
+                   "Input(X) of SequenceEraseOp should be a 2-D LoDTensor "
+                   "with the 2nd dimension equal to 1.");
+    ctx->SetOutputDim("Out", x_dims);
   }
 };
@@ -35,17 +39,41 @@ class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker {
   SequenceEraseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
-             "(LoDTensor) 2-D input LoDTensor with the 2-nd dimension "
-             "of length 1.");
+             "(2-D LoDTensor with the 2nd dim. equal to 1) "
+             "Input LoDTensor of SequenceEraseOp.");
     AddOutput("Out",
-              "(LoDTensor) 2-D output LoDTensor with the 2-nd dimension "
-              "of length 1.");
+              "(2-D LoDTensor with the 2nd dim. equal to 1) "
+              "Output LoDTensor of SequenceEraseOp.");
     AddAttr<std::vector<int>>("tokens",
-                              "(vector<int>) "
-                              "Tokens to be removed from input.");
+                              "(vector<int>) Tokens that need to be erased "
+                              "from input sequences.");
     AddComment(R"DOC(
 Sequence Erase Operator.
 
+Sequence erase operator erases the tokens specified by Attr(tokens) from the
+input sequences Input(X), outputs the remaining data, and updates the LoD
+information accordingly. For example, given a 2-D LoDTensor
+
+    X = [[2, 2, 6, 1, 3, 9, 6, 1, 0, 1]]^T
+
+with lod = [[0, 3, 6, 10]], there are three sequences in the input:
+
+    X1 = [[2, 2, 6]]^T, X2 = [[1, 3, 9]]^T and X3 = [[6, 1, 0, 1]]^T.
+
+If the tokens to be erased are Attr(tokens) = [2, 3, 5], after the erasing
+operation the three sequences become
+
+    X1' = [[6]]^T, X2' = [[1, 9]]^T and X3' = [[6, 1, 0, 1]]^T.
+
+Hence the LoDTensor Output(Out) is
+
+    Out = [[6, 1, 9, 6, 1, 0, 1]]^T,
+
+with lod = [[0, 1, 3, 7]].
+
+A typical use of this operator is to remove special tokens, such as the blank,
+start and end tokens, before computing the edit distance between two strings.
+
 )DOC");
   }
 };
......
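For reference, the worked example in the new DOC comment can be reproduced with a short NumPy sketch. The helper erase_ref below is hypothetical, written only for this illustration; it is not part of the operator or of this commit.

    import numpy as np

    def erase_ref(in_seq, lod0, tokens):
        # in_seq: (N, 1) array; lod0: offset-style LoD such as [0, 3, 6, 10]
        flat = in_seq.flatten()
        out, new_lod0 = [], [0]
        for start, end in zip(lod0[:-1], lod0[1:]):
            kept = [int(v) for v in flat[start:end] if int(v) not in tokens]
            out.extend(kept)
            new_lod0.append(new_lod0[-1] + len(kept))
        return np.array(out, dtype=in_seq.dtype).reshape(-1, 1), new_lod0

    x = np.array([[2], [2], [6], [1], [3], [9], [6], [1], [0], [1]], dtype="int32")
    out, new_lod = erase_ref(x, [0, 3, 6, 10], tokens=[2, 3, 5])
    print(out.flatten().tolist())  # [6, 1, 9, 6, 1, 0, 1]
    print(new_lod)                 # [0, 1, 3, 7]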
@@ -13,17 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <thrust/device_vector.h>
-#include <thrust/execution_policy.h>
 #include <thrust/host_vector.h>
-#include <thrust/reduce.h>
 #include "paddle/operators/sequence_erase_op.h"
 #include "paddle/platform/cuda_helper.h"
-#include "paddle/platform/gpu_info.h"
 
 namespace paddle {
 namespace operators {
 using platform::PADDLE_CUDA_NUM_THREADS;
-using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
 
 template <typename T>
@@ -97,7 +93,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
     thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(),
                            num_erased.begin() + 1);
 
-    // Reset LoD
+    // Calc LoD
     auto lod_len = lod0.size();
     thrust::host_vector<int> host_lod(lod_len);
     for (size_t i = 0; i < lod_len; ++i) {
@@ -117,15 +113,14 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
     }
     framework::LoD out_lod;
     out_lod.push_back(out_lod0);
+    out->set_lod(out_lod);
+    out->Resize({out_lod0.back(), 1});
 
     // Set output
-    out->Resize({out_lod0.back(), 1});
     auto out_dat = out->mutable_data<T>(ctx.GetPlace());
     SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                 PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len,
                                                       num_erased_ptr, out_dat);
-
-    // Set LoD
-    out->set_lod(out_lod);
   }
 };
......
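The LoD update in the CUDA kernel above boils down to one relation: each offset in the output LoD equals the corresponding input offset minus the number of erased elements before it, where num_erased is effectively the exclusive prefix sum produced by the inclusive_scan over the shifted erase flags. A NumPy sketch of that arithmetic (illustrative only; the variable names here are made up):

    import numpy as np

    x = np.array([2, 2, 6, 1, 3, 9, 6, 1, 0, 1])
    lod0 = np.array([0, 3, 6, 10])
    tokens = {2, 3, 5}

    # num_erased[i] = how many of x[0:i] are erased (exclusive prefix sum)
    flags = np.array([1 if int(v) in tokens else 0 for v in x])
    num_erased = np.concatenate([[0], np.cumsum(flags)])

    new_lod0 = lod0 - num_erased[lod0]  # shrink each offset by the erased count before it
    print(new_lod0.tolist())            # [0, 1, 3, 7]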
@@ -15,14 +15,10 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/op_registry.h"
-#include "paddle/operators/math/softmax.h"
 
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
-using LoDTensor = framework::LoDTensor;
-
 template <typename DeviceContext, typename T>
 class SequenceEraseKernel : public framework::OpKernel<T> {
  public:
@@ -32,17 +28,6 @@ class SequenceEraseKernel : public framework::OpKernel<T> {
     auto lod = in->lod();
     PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
-    // auto dims = x->dims();
-    /*
-    const size_t level = lod.size() - 1;
-    PADDLE_ENFORCE_EQ(dims[0], static_cast<int64_t>(lod[level].back()),
-                      "The first dimension of Input(X) should be equal to the "
-                      "sum of all sequences' lengths.");
-    PADDLE_ENFORCE_EQ(dims[0], x->numel(),
-                      "The width of each timestep in Input(X) of "
-                      "SequenceEraseOp should be 1.");
-    out->mutable_data<T>(ctx.GetPlace());
-    */
 
     auto tokens = ctx.Attr<std::vector<int>>("tokens");
     auto in_len = in->numel();
     auto in_dat = in->data<T>();
@@ -65,7 +50,7 @@ class SequenceEraseKernel : public framework::OpKernel<T> {
     out->Resize({static_cast<int64_t>(out_len), 1});
     auto out_dat = out->mutable_data<T>(ctx.GetPlace());
-    for (size_t i = 0; i < in_len; ++i) {
+    for (int64_t i = 0; i < in_len; ++i) {
       if (num_erased[i] == num_erased[i + 1]) {
         out_dat[i - num_erased[i]] = in_dat[i];
       }
......
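The copy loop in the last hunk is a stable compaction driven by the same prefix sum: element i survives exactly when num_erased does not increase across it, and its destination index is i - num_erased[i]. A self-contained NumPy sketch of that scatter (illustration only, not the committed code):

    import numpy as np

    x = np.array([2, 2, 6, 1, 3, 9, 6, 1, 0, 1], dtype="int32")
    erased = np.isin(x, [2, 3, 5]).astype(int)
    num_erased = np.concatenate([[0], np.cumsum(erased)])  # exclusive prefix sum

    out = np.empty(len(x) - num_erased[-1], dtype=x.dtype)
    for i in range(len(x)):
        if num_erased[i] == num_erased[i + 1]:  # x[i] is kept
            out[i - num_erased[i]] = x[i]       # mirrors out_dat[i - num_erased[i]] = in_dat[i]
    print(out.tolist())                         # [6, 1, 9, 6, 1, 0, 1]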
@@ -28,9 +28,9 @@ def sequence_erase(in_seq, lod0, tokens):
 class TestSequenceEraseOp(OpTest):
     def setUp(self):
         self.op_type = "sequence_erase"
-        in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
-        lod = [[0, 5, 15, 30]]
-        tokens = [2, 5]
+        in_seq = np.random.randint(0, 10, (10, 1)).astype("int32")
+        lod = [[0, 3, 6, 10]]
+        tokens = [2, 3, 5]
         out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens)
         self.attrs = {'tokens': tokens}
         self.inputs = {'X': (in_seq, lod)}
......