未验证 提交 46b0b790 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #13856 from kuke/seq_unpad_op

Add sequence unpad op
......@@ -75,7 +75,8 @@ paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'outp
paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
paddle.fluid.layers.sequence_expand ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None))
paddle.fluid.layers.sequence_expand_as ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_pad ArgSpec(args=['x', 'pad_value', 'maxlen'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_pad ArgSpec(args=['x', 'pad_value', 'maxlen', 'name'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.sequence_unpad ArgSpec(args=['x', 'length', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm_unit ArgSpec(args=['x_t', 'hidden_t_prev', 'cell_t_prev', 'forget_bias', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(0.0, None, None, None))
paddle.fluid.layers.reduce_sum ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None))
paddle.fluid.layers.reduce_mean ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None))
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_unpad_op.h"
namespace paddle {
namespace operators {
class SequenceUnpadOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceUnpadOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Length"),
"Input(Length) of SequenceUnpadOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceUnpadOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
"The rank of Input(X) can't be less than 2.");
auto len_dims = ctx->GetInputDim("Length");
PADDLE_ENFORCE(len_dims.size() == 2 && len_dims[1] == 1,
"The shape of Input(Length) should be [batch_size, 1].");
PADDLE_ENFORCE(
len_dims[0] == x_dims[0],
"Input(X) and Input(Length) should have the same first dimension.");
int64_t out_dim_0 = -1;
if (ctx->IsRuntime()) {
out_dim_0 = x_dims[0] * x_dims[1];
}
std::vector<int64_t> out_dims_vec{out_dim_0};
if (x_dims.size() == 2) {
out_dims_vec.push_back(1);
} else {
for (size_t i = 2; i < x_dims.size(); ++i) {
out_dims_vec.push_back(x_dims[i]);
}
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class SequenceUnpadOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(LoDTensor, default LoDTensor<float>) Input tensor which "
"contains the padded sequences with equal length.");
AddInput("Length",
"(LoDTensor) The input tensor which specifies the actual ength of "
"sequences after unpadding.");
AddOutput(
"Out",
"(LoDTensor) The output tensor which contains unpadded sequences.");
AddComment(R"DOC(
Sequence Unpad Operator
This operator removes the padding data in the input sequences and convert
them into sequences with actual length as output, identitied by lod
information.
Example:
Given input tensor Input(X):
X.data = [[ 1.0, 2.0, 3.0, 4.0, 5.0],
[ 6.0, 7.0, 8.0, 9.0, 10.0],
[11.0, 12.0, 13.0, 14.0, 15.0]],
`
in which there are 3 sequences padded to length 5, and the acutal length
specified by Input(Length):
Length.data = [[2], [3], [4]],
after unpadding, Output(Out) will be:
Out.data = [[1.0, 2.0, 6.0, 7.0, 8.0, 11.0, 12.0, 13.0, 14.0]]
Out.lod = [[0, 2, 5, 9]]
)DOC");
}
};
class SequenceUnpadGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceUnpadGradOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) of SequenceUnpadGradOp should not be null.");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_unpad, ops::SequenceUnpadOp,
ops::SequenceUnpadOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(sequence_unpad_grad, ops::SequenceUnpadGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_unpad,
ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
sequence_unpad_grad,
ops::SequenceUnpadGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceUnpadGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequenceUnpadGradOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceUnpadGradOpKernel<paddle::platform::CPUDeviceContext,
int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_unpad_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sequence_unpad,
ops::SequenceUnpadOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceUnpadOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequenceUnpadOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceUnpadOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
sequence_unpad_grad,
ops::SequenceUnpadGradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceUnpadGradOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequenceUnpadGradOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceUnpadGradOpKernel<paddle::platform::CUDADeviceContext,
int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_padding.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
template <typename DeviceContext, typename T>
class SequenceUnpadOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x_t = ctx.Input<LoDTensor>("X");
auto* len_t = ctx.Input<LoDTensor>("Length");
auto* out_t = ctx.Output<LoDTensor>("Out");
out_t->mutable_data<T>(ctx.GetPlace());
const int64_t* seq_len_ptr = nullptr;
if (platform::is_gpu_place(ctx.GetPlace())) {
LoDTensor seq_len_cpu;
seq_len_cpu.Resize(len_t->dims());
seq_len_ptr = seq_len_cpu.mutable_data<int64_t>(platform::CPUPlace());
framework::TensorCopy(*len_t, platform::CPUPlace(),
ctx.template device_context<DeviceContext>(),
&seq_len_cpu);
} else {
seq_len_ptr = len_t->data<int64_t>();
}
size_t batch_size = x_t->dims()[0];
std::vector<size_t> out_lod0(batch_size + 1, 0);
for (size_t i = 0; i < batch_size; ++i) {
out_lod0[i + 1] = out_lod0[i] + seq_len_ptr[i];
}
framework::LoD out_lod;
out_lod.push_back(out_lod0);
out_t->set_lod(out_lod);
std::vector<int64_t> out_dims_vec{static_cast<int64_t>(out_lod0.back())};
if (x_t->dims().size() == 2) {
out_dims_vec.push_back(1);
} else {
for (size_t i = 2; i < x_t->dims().size(); ++i) {
out_dims_vec.push_back(x_t->dims()[i]);
}
}
out_t->Resize(framework::make_ddim(out_dims_vec));
int64_t padded_length = x_t->dims()[1];
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *x_t, out_t,
padded_length, 0, false, math::kBatchLengthWidth);
}
};
template <typename DeviceContext, typename T>
class SequenceUnpadGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_x = ctx.Output<LoDTensor>(framework::GradVarName("X"));
if (d_x) {
const auto* d_out = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
const auto* x_t = ctx.Input<LoDTensor>("X");
d_x->mutable_data<T>(ctx.GetPlace());
int padded_length = x_t->dims()[1];
LoDTensor zero_pads;
zero_pads.Resize({1, 1});
zero_pads.mutable_data<T>(ctx.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
set_zero(dev_ctx, &zero_pads, static_cast<T>(0));
math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *d_out, d_x, zero_pads,
padded_length, 0, false, math::kBatchLengthWidth);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -56,6 +56,7 @@ __all__ = [
'sequence_expand',
'sequence_expand_as',
'sequence_pad',
'sequence_unpad',
'lstm_unit',
'reduce_sum',
'reduce_mean',
......@@ -2793,7 +2794,7 @@ def sequence_expand_as(x, y, name=None):
@templatedoc()
def sequence_pad(x, pad_value, maxlen=None):
def sequence_pad(x, pad_value, maxlen=None, name=None):
"""
${comment}
......@@ -2807,7 +2808,9 @@ def sequence_pad(x, pad_value, maxlen=None):
None or any positive int. When it is None, all sequences will be
padded up to the length of the longest one among them; when it a
certain positive value, it must be greater than the length of the
longest original sequence."
longest original sequence.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The padded sequence batch and the original lengths before
......@@ -2844,6 +2847,66 @@ def sequence_pad(x, pad_value, maxlen=None):
return out, length
def sequence_unpad(x, length, name=None):
"""
**Sequence Unpad Layer**
This layer removes the padding data in the input sequences and convert
them into sequences with actual length as output, identitied by lod
information.
.. code-block:: text
Example:
Given input Variable **x**:
x.data = [[ 1.0, 2.0, 3.0, 4.0, 5.0],
[ 6.0, 7.0, 8.0, 9.0, 10.0],
[11.0, 12.0, 13.0, 14.0, 15.0]],
in which there are 3 sequences padded to length 5, and the acutal length
specified by input Variable **length**:
length.data = [[2], [3], [4]],
after unpadding, the output Variable will be:
out.data = [[1.0, 2.0, 6.0, 7.0, 8.0, 11.0, 12.0, 13.0, 14.0]]
out.lod = [[2, 3, 4]]
Args:
x(Variable): Input Variable which contains the padded sequences with
equal length.
length(Variable): The Variable that specifies the actual ength of
sequences after unpadding.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The Variable contains the unpadded sequences.
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[10, 5], dtype='float32')
len = fluid.layers.data(name='length', shape=[1], dtype='int64')
out = fluid.layers.sequence_unpad(x=x, length=len)
"""
helper = LayerHelper('sequence_unpad', input=x, **locals())
dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype)
length.stop_gradient = True
helper.append_op(
type='sequence_unpad',
inputs={'X': x,
'Length': length},
outputs={'Out': out})
return out
def beam_search(pre_ids,
pre_scores,
ids,
......
......@@ -194,6 +194,14 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(layers.sequence_expand(x=x, y=y, ref_level=1))
print(str(program))
def test_sequence_unpad(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[10, 5], dtype='float32')
length = layers.data(name='length', shape=[1], dtype='int64')
self.assertIsNotNone(layers.sequence_unpad(x=x, length=length))
print(str(program))
def test_lstm_unit(self):
program = Program()
with program_guard(program):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import six
import numpy as np
from op_test import OpTest
class TestSequenceUnpadOp(OpTest):
def init(self):
self.length = [2, 3, 4]
self.x_shape = (3, 5)
self.dtype = "float32"
def compute(self):
assert len(self.length) == self.x_shape[0]
x = np.random.random(self.x_shape).astype(self.dtype)
out_lod = [self.length]
out = x[0, 0:self.length[0]]
for i in six.moves.xrange(1, x.shape[0]):
out = np.append(out, x[i, 0:self.length[i]], axis=0)
out_shape = (sum(self.length), )
if len(self.x_shape) == 2:
out_shape = out_shape + (1, )
else:
out_shape = out_shape + self.x_shape[2:]
self.inputs = {
'X': x,
'Length': np.array(self.length).astype('int64').reshape(-1, 1)
}
self.outputs = {'Out': (out.reshape(out_shape), out_lod)}
def setUp(self):
self.op_type = 'sequence_unpad'
self.init()
self.compute()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestSequenceUnpadOp2(TestSequenceUnpadOp):
def init(self):
self.length = [2, 3, 4]
self.x_shape = (3, 5, 4, 3)
self.dtype = "float32"
class TestSequenceUnpadOp3(TestSequenceUnpadOp):
def init(self):
self.length = [5, 2, 3, 4]
self.x_shape = (4, 5, 3, 3, 6)
self.dtype = "float64"
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册