未验证 提交 3b6090e8 编写于 作者: C Chen Weihang 提交者: GitHub

Merge pull request #12887 from chenwhql/sequence_enumerate_op

Feat: add sequence enumerate op
......@@ -172,6 +172,7 @@ paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'],
paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,))
paddle.fluid.layers.pad2d ArgSpec(args=['input', 'paddings', 'mode', 'pad_value', 'data_format', 'name'], varargs=None, keywords=None, defaults=([0, 0, 0, 0], 'constant', 0.0, 'NCHW', None))
paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.sequence_enumerate ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_enumerate_op.h"
namespace paddle {
namespace operators {
class SequenceEnumerateOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("X"),
"Input(X) of SequecceEnumerate operator should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(X) of SequenceEnumerate operator should not be null.");
const auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(
x_dims.size(), 2UL,
"Input(X) of SequenceEnumerate operator's rank should be 2.");
PADDLE_ENFORCE_EQ(
x_dims[1], 1UL,
"Input(X) of SequenceEnumerate operator's 2nd dimension should be 1.");
const auto win_size = ctx->Attrs().Get<int>("win_size");
ctx->SetOutputDim("Out", {x_dims[0], win_size});
ctx->ShareLoD("X", "Out");
}
};
class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(2-D LoDTensor with the 2nd dimension equal to 1) "
"Input LoDTensor of SequenceEnumerate operator.");
AddOutput("Out",
"(2-D LoDTensor with the 2nd dimension equal to win_size) "
"Output LoDTensor of SequenceEnumerate operator.");
AddAttr<int>("win_size", "(int) The enumerate sequence window size.")
.AddCustomChecker([](const int& win_size) {
PADDLE_ENFORCE(win_size >= 2,
"The window size should be not less than 2.");
});
AddAttr<int>("pad_value", "(int) The enumerate sequence padding value.")
.SetDefault(0);
AddComment(R"DOC(
Sequence Enumerate Operator.
Generate a new sequence for the input index sequence, which enumerates all the
sub-sequences with length `win_size` of the input.
The enumerated sequence has the same 1st dimension with variable `input`, and
the 2nd dimension is `win_size`, padded by `pad_value` if necessary in generation.
Examples:
Case 1:
Input:
X.lod = [[0, 3, 5]]
X.data = [[1], [2], [3], [4], [5]]
X.dims = [5, 1]
Attrs:
win_size = 2
pad_value = 0
Output:
Out.lod = [[0, 3, 5]]
Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]
Out.dims = [5, 2]
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(sequence_enumerate, ops::SequenceEnumerateOp,
ops::SequenceEnumerateOpMaker);
REGISTER_OP_CPU_KERNEL(
sequence_enumerate,
ops::SequenceEnumerateKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::SequenceEnumerateKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/operators/sequence_enumerate_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using LoDTensor = framework::LoDTensor;
template <typename T>
__global__ void CalcOutPut(const T* in_data, const size_t* in_lod,
const size_t lod_len, const int64_t win_size,
const int64_t pad_value, T* out_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < in_lod[lod_len - 1]) {
int end_idx = 0;
// Get LoD interval of index
for (int i = 1; i < lod_len; ++i) {
if (index < in_lod[i]) {
end_idx = in_lod[i];
break;
}
}
for (size_t i = 0; i < win_size; ++i) {
int word_pos = index + i;
out_data[index * win_size + i] =
word_pos < end_idx ? in_data[word_pos] : pad_value;
}
}
}
template <typename T>
class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int win_size = context.Attr<int>("win_size");
int pad_value = context.Attr<int>("pad_value");
auto in_dims = in->dims();
auto in_lod = in->lod();
PADDLE_ENFORCE_EQ(
static_cast<uint64_t>(in_dims[0]), in_lod[0].back(),
"The actual input data's size mismatched with LoD information.");
/* Generate enumerate sequence set */
auto stream = context.cuda_device_context().stream();
auto lod0 = in_lod[0];
auto in_len = in->numel();
auto in_data = in->data<T>();
auto out_data = out->mutable_data<T>(context.GetPlace());
// Copy LoD to GPU
const size_t* dev_in_lod_ptr = lod0.CUDAData(context.GetPlace());
// Calc output tensor
CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
in_data, dev_in_lod_ptr, lod0.size(), win_size, pad_value, out_data);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
sequence_enumerate,
paddle::operators::SequenceEnumerateOpCUDAKernel<int32_t>,
paddle::operators::SequenceEnumerateOpCUDAKernel<int64_t>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class SequenceEnumerateKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int win_size = context.Attr<int>("win_size");
int pad_value = context.Attr<int>("pad_value");
auto in_dims = in->dims();
auto in_lod = in->lod();
PADDLE_ENFORCE_EQ(
static_cast<uint64_t>(in_dims[0]), in_lod[0].back(),
"The actual input data's size mismatched with LoD information.");
// Generate enumerate sequence set
auto lod0 = in_lod[0];
auto in_data = in->data<T>();
auto out_data = out->mutable_data<T>(context.GetPlace());
for (size_t i = 0; i < lod0.size() - 1; ++i) {
for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) {
for (int word_idx = 0; word_idx < win_size; ++word_idx) {
size_t word_pos = idx + word_idx;
out_data[win_size * idx + word_idx] =
word_pos < lod0[i + 1] ? in_data[word_pos] : pad_value;
}
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -111,6 +111,7 @@ __all__ = [
'stack',
'pad2d',
'unstack',
'sequence_enumerate',
]
......@@ -5823,6 +5824,51 @@ def flatten(x, axis=1, name=None):
return out
def sequence_enumerate(input, win_size, pad_value=0, name=None):
"""
Generate a new sequence for the input index sequence, which enumerates all the
sub-sequences with length `win_size` of the input.
The enumerated sequence has the same 1st dimension with variable `input`, and
the 2nd dimension is `win_size`, padded by `pad_value` if necessary in generation.
Examples:
Case 1:
Input:
X.lod = [[0, 3, 5]]
X.data = [[1], [2], [3], [4], [5]]
X.dims = [5, 1]
Attrs:
win_size = 2
pad_value = 0
Output:
Out.lod = [[0, 3, 5]]
Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]
Out.dims = [5, 2]
Args:
input (Variable): The input variable which is a index sequence.
win_size (int): The window size for enumerating all sub-sequences.
pad_value (int): The padding value, default 0.
Returns:
Variable: The enumerate sequence variable which is a LoDTensor.
Examples:
.. code-block:: python
x = fluid.layers.data(shape[30, 1], dtype='int32', lod_level=1)
out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0)
"""
helper = LayerHelper('sequence_enumerate', **locals())
out = helper.create_tmp_variable(helper.input_dtype(), stop_gradient=True)
helper.append_op(
type='sequence_enumerate',
inputs={'X': input},
outputs={'Out': out},
attrs={'win_size': win_size,
'pad_value': pad_value})
def sequence_mask(x, maxlen=None, dtype='int64', name=None):
"""
**SequenceMask Layer**
......@@ -5902,6 +5948,7 @@ def stack(x, axis=0):
helper.append_op(
type='stack', inputs={'X': x}, outputs={'Y': out},
attrs={'axis': axis})
return out
......
......@@ -549,6 +549,13 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(out)
print(str(program))
def test_sequence_enumerate(self):
program = Program()
with program_guard(program):
x = layers.data(name="input", shape=[1], dtype='int32', lod_level=1)
out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0)
print(str(program))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
def sequence_enumerate(input_seq, in_lod, win_size, pad_value):
lod0 = [0]
for i in range(0, len(in_lod[0])):
lod0.append(lod0[i] + in_lod[0][i])
out_seq = []
for i in range(0, len(lod0) - 1):
for idx in range(lod0[i], lod0[i + 1]):
single_seq = []
for word_idx in range(win_size):
word_pos = idx + word_idx
dat = input_seq[word_pos] if word_pos < lod0[i+1] \
else pad_value
single_seq.append(dat)
out_seq.append(single_seq)
return out_seq
class TestSequenceEnumerateOp(OpTest):
def setUp(self):
self.op_type = "sequence_enumerate"
self.init_test_case()
self.inputs = {'X': (self.in_seq, self.lod)}
self.attrs = {'win_size': self.win_size, 'pad_value': self.pad_value}
self.outputs = {'Out': (self.out_seq, self.lod)}
def test_check_output(self):
self.check_output()
def init_test_case(self):
self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
self.lod = [[9, 4, 11, 6]]
self.win_size = 2
self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int32")
class TesSequenceEnumerateOpInt64(TestSequenceEnumerateOp):
def init_test_case(self):
self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int64")
self.lod = [[9, 4, 11, 6]]
self.win_size = 2
self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int64")
class TestSequenceEnumerateOpLargeWinSize(TestSequenceEnumerateOp):
def init_test_case(self):
self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
self.lod = [[9, 4, 11, 6]]
self.win_size = 5
self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int32")
class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp):
def init_test_case(self):
self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
self.lod = [[9, 4, 11, 6]]
self.win_size = 30
self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int32")
class TestSequenceEnumerateOpLargePadValue(TestSequenceEnumerateOp):
def init_test_case(self):
self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
self.lod = [[9, 4, 11, 6]]
self.win_size = 5
self.pad_value = 5
out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int32")
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册