提交 cbe42925 编写于 作者: Y Yibing Liu

Add sequence unpad op

test=develop
上级 5f2e8378
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_unpad_op.h"
namespace paddle {
namespace operators {
class SequenceUnpadOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceUnpadOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Length"),
"Input(Length) of SequenceUnpadOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceUnpadOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
"The rank of Input(X) can't be less than 2.");
auto len_dims = ctx->GetInputDim("Length");
PADDLE_ENFORCE(len_dims.size() == 2 && len_dims[1] == 1,
"The shape of Input(Length) should be [batch_size, 1].");
PADDLE_ENFORCE(
len_dims[0] == x_dims[0],
"Input(X) and Input(Length) should have the same first dimension.");
int64_t out_dim_0 = -1;
if (ctx->IsRuntime()) {
out_dim_0 = x_dims[0] * x_dims[1];
}
std::vector<int64_t> out_dims_vec{out_dim_0};
if (x_dims.size() == 2) {
out_dims_vec.push_back(1);
} else {
for (size_t i = 2; i < x_dims.size(); ++i) {
out_dims_vec.push_back(x_dims[i]);
}
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class SequenceUnpadOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(LoDTensor, default LoDTensor<float>) Input tensor which "
"contains the padded sequences with equal length.");
AddInput("Length",
"(LoDTensor) The input tensor which specifies the actual ength of "
"sequences after unpadding.");
AddOutput(
"Out",
"(LoDTensor) The output tensor which contains unpadded sequences.");
AddComment(R"DOC(
Sequence Unpad Operator
This operator removes the padding data in the input sequences and convert
them into sequences with actual length as output, identitied by lod
information.
Example:
Given input tensor Input(X):
X.data = [[ 1.0, 2.0, 3.0, 4.0, 5.0],
[ 6.0, 7.0, 8.0, 9.0, 10.0],
[11.0, 12.0, 13.0, 14.0, 15.0]],
`
in which there are 3 sequences padded to length 5, and the acutal length
specified by Input(Length):
Length.data = [[2], [3], [4]],
after unpadding, Output(Out) will be:
Out.data = [[1.0, 2.0, 6.0, 7.0, 8.0, 11.0, 12.0, 13.0, 14.0]]
Out.lod = [[0, 2, 5, 9]]
)DOC");
}
};
class SequenceUnpadGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceUnpadGradOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) of SequenceUnpadGradOp should not be null.");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_unpad, ops::SequenceUnpadOp,
ops::SequenceUnpadOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(sequence_unpad_grad, ops::SequenceUnpadGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_unpad,
ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
sequence_unpad_grad,
ops::SequenceUnpadGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceUnpadGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequenceUnpadGradOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceUnpadGradOpKernel<paddle::platform::CPUDeviceContext,
int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_unpad_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sequence_unpad,
ops::SequenceUnpadOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceUnpadOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequenceUnpadOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceUnpadOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
sequence_unpad_grad,
ops::SequenceUnpadGradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceUnpadGradOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequenceUnpadGradOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceUnpadGradOpKernel<paddle::platform::CUDADeviceContext,
int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_padding.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
template <typename DeviceContext, typename T>
class SequenceUnpadOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x_t = ctx.Input<LoDTensor>("X");
auto* len_t = ctx.Input<LoDTensor>("Length");
auto* out_t = ctx.Output<LoDTensor>("Out");
out_t->mutable_data<T>(ctx.GetPlace());
const int64_t* seq_len_ptr = nullptr;
if (platform::is_gpu_place(ctx.GetPlace())) {
LoDTensor seq_len_cpu;
seq_len_cpu.Resize(len_t->dims());
seq_len_ptr = seq_len_cpu.mutable_data<int64_t>(platform::CPUPlace());
framework::TensorCopy(*len_t, platform::CPUPlace(),
ctx.template device_context<DeviceContext>(),
&seq_len_cpu);
} else {
seq_len_ptr = len_t->data<int64_t>();
}
size_t batch_size = x_t->dims()[0];
std::vector<size_t> out_lod0(batch_size + 1, 0);
for (size_t i = 0; i < batch_size; ++i) {
out_lod0[i + 1] = out_lod0[i] + seq_len_ptr[i];
}
framework::LoD out_lod;
out_lod.push_back(out_lod0);
out_t->set_lod(out_lod);
std::vector<int64_t> out_dims_vec{static_cast<int64_t>(out_lod0.back())};
if (x_t->dims().size() == 2) {
out_dims_vec.push_back(1);
} else {
for (size_t i = 2; i < x_t->dims().size(); ++i) {
out_dims_vec.push_back(x_t->dims()[i]);
}
}
out_t->Resize(framework::make_ddim(out_dims_vec));
int64_t padded_length = x_t->dims()[1];
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *x_t, out_t,
padded_length, 0, false, math::kBatchLengthWidth);
}
};
template <typename DeviceContext, typename T>
class SequenceUnpadGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_x = ctx.Output<LoDTensor>(framework::GradVarName("X"));
if (d_x) {
const auto* d_out = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
const auto* x_t = ctx.Input<LoDTensor>("X");
d_x->mutable_data<T>(ctx.GetPlace());
int padded_length = x_t->dims()[1];
LoDTensor zero_pads;
zero_pads.Resize({1, 1});
zero_pads.mutable_data<T>(ctx.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
set_zero(dev_ctx, &zero_pads, static_cast<T>(0));
math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *d_out, d_x, zero_pads,
padded_length, 0, false, math::kBatchLengthWidth);
}
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import six
import numpy as np
from op_test import OpTest
class TestSequenceUnpadOp(OpTest):
def init(self):
self.length = [2, 3, 4]
self.x_shape = (3, 5)
self.dtype = "float32"
def compute(self):
assert len(self.length) == self.x_shape[0]
x = np.random.random(self.x_shape).astype(self.dtype)
out_lod = [self.length]
out = x[0, 0:self.length[0]]
for i in six.moves.xrange(1, x.shape[0]):
out = np.append(out, x[i, 0:self.length[i]], axis=0)
out_shape = (sum(self.length), )
if len(self.x_shape) == 2:
out_shape = out_shape + (1, )
else:
out_shape = out_shape + self.x_shape[2:]
self.inputs = {
'X': x,
'Length': np.array(self.length).astype('int64').reshape(-1, 1)
}
self.outputs = {'Out': (out.reshape(out_shape), out_lod)}
def setUp(self):
self.op_type = 'sequence_unpad'
self.init()
self.compute()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestSequenceUnpadOp2(TestSequenceUnpadOp):
def init(self):
self.length = [2, 3, 4]
self.x_shape = (3, 5, 4, 3)
self.dtype = "float32"
class TestSequenceUnpadOp3(TestSequenceUnpadOp):
def init(self):
self.length = [5, 2, 3, 4]
self.x_shape = (4, 5, 3, 3, 6)
self.dtype = "float64"
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册