提交 8d8d48a3 编写于 作者: F fengjiayi

Complete sequence_pad_op and its CPU kernel. Add unittests

上级 3c749fae
......@@ -70,9 +70,10 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
std::vector<T> pad_value = {0}, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_offsets = framework::ToAbsOffset(seq_tensor.lod())[lod_level];
auto seq_tensor_dims = seq_tensor.dims();
auto pad_tensor_dims = pad_tensor->dims();
auto seq_lod = seq_tensor.lod();
const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
const auto& seq_tensor_dims = seq_tensor.dims();
const auto& pad_tensor_dims = pad_tensor->dims();
if (pad_seq_len == -1) {
pad_seq_len = MaximumSequenceLength(seq_offsets);
}
......@@ -91,12 +92,21 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
// fill padding value
T* pad_data = pad_tensor->data<T>();
for (int i = 0; i < pad_tensor->numel() / step_width; ++i) {
memcpy(pad_data, pad_value.data(), step_width * sizeof(T));
for (int i = 0; i < pad_tensor->numel(); i += step_width) {
memcpy(pad_data + i, pad_value.data(), step_width * sizeof(T));
}
CopyValidData<T>(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len,
step_width, norm_by_times, kSeqToPad, layout);
// Set pad_tensor's lod info if possible
if (layout == kBatchLengthWidth) {
framework::LoD pad_lod(seq_lod.begin() + lod_level, seq_lod.end());
for (size_t i = 0; i < pad_lod[0].size(); ++i) {
pad_lod[0][i] = i * pad_seq_len;
}
pad_tensor->set_lod(pad_lod);
}
}
};
......@@ -109,8 +119,8 @@ class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
int lod_level = 0, bool norm_by_times = false,
const PadLayout& layout = kBatchLengthWidth) {
auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
auto seq_tensor_dims = seq_tensor->dims();
auto pad_tensor_dims = pad_tensor.dims();
const auto& seq_tensor_dims = seq_tensor->dims();
const auto& pad_tensor_dims = pad_tensor.dims();
if (pad_seq_len == -1) {
pad_seq_len = MaximumSequenceLength(seq_offsets);
}
......
......@@ -44,9 +44,6 @@ inline static void CheckDims(const framework::DDim& seq_tensor_dims,
"Value of 1st dimension of the sequence tensor should be "
"equal to sum of lengths of all sequences.");
PADDLE_ENFORCE(seq_tensor_dims.size() == 1 || seq_tensor_dims.size() == 2,
"seq_tensor's rank should be 1 or 2.");
PADDLE_ENFORCE(seq_tensor_dims.size() + 1 == pad_tensor_dims.size() ||
seq_tensor_dims.size() == pad_tensor_dims.size(),
"pad_tensor's rank should be 1 greater than seq_tensor's "
......
......@@ -21,82 +21,85 @@ class SequencePadOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequencePadOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PadValue"),
"Input(PadValue) of SequencePadOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequencePadOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
"The rank of Input(x) can't be less than 2.");
auto time_step_dims = framework::slice_ddim(x_dims, 1, x_dims.size());
auto pad_value_dims = ctx->GetInputDim("PadValue");
PADDLE_ENFORCE(pad_value_dims == framework::make_ddim({1}) ||
pad_value_dims == time_step_dims,
"The Input(PadValue) must be a scalar or a tensor whose "
"shape equals to time steps in sequences");
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
"Only support 2-D tensor, rank of Input(X) should be 2.");
int lod_level = ctx->Attrs().Get<int>("lod_level");
int64_t max_len = -1;
int64_t seq_num = -1;
int x_lod_size = -1;
int batch_dim_size = -1;
if (ctx->IsRuntime()) {
// run time
framework::Variable* x_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
auto& x_lod = x_var->Get<LoDTensor>().lod();
x_lod_size = x_lod.size();
auto x_abs_offset = framework::ToAbsOffset(x_lod)[lod_level];
PADDLE_ENFORCE_EQ(x_dims[0], static_cast<int64_t>(x_abs_offset.back()),
"The first dimension of `X` should be equal to sum "
"of all sequences' length.");
seq_num = x_abs_offset.size() - 1;
for (int64_t i = 1; i <= seq_num; ++i) {
int64_t seq_len = x_abs_offset[i] - x_abs_offset[i - 1];
max_len = max_len < seq_len ? seq_len : max_len;
const auto& x_lod = x_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE(!x_lod.empty(), "The Input(X) must hold lod info.");
const auto& x_lod_0 = x_lod[0];
PADDLE_ENFORCE_GE(x_lod_0.size(), 2,
"The Input(X)'s lod info is corrupted.");
PADDLE_ENFORCE_EQ(
x_dims[0], static_cast<int64_t>(x_lod_0.back()),
"The Input(X)'s lod info mismatches the actual tensor shape.");
int seq_num = x_lod_0.size() - 1;
int max_seq_len = math::MaximumSequenceLength(x_lod_0);
int padded_length = ctx->Attrs().Get<int>("padded_length");
if (padded_length == -1) {
padded_length = max_seq_len;
}
PADDLE_ENFORCE_GE(padded_length, max_seq_len,
"The Attr(padded_length) must be -1 or an int greater "
"than the length of the longest original sequence.");
batch_dim_size = padded_length * seq_num;
} else {
// compile time
framework::VarDesc* x_desc =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("X")[0]);
x_lod_size = x_desc->GetLoDLevel();
PADDLE_ENFORCE_GE(x_desc->GetLoDLevel(), 1);
}
PADDLE_ENFORCE(lod_level >= 0 && lod_level < x_lod_size,
"Invalid `lod_level` which should be at least 0 and less "
"than maximum lod level of `X`");
ctx->SetOutputDim("Out", {seq_num, max_len, x_dims[1]});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
auto out_dims = x_dims;
out_dims[0] = batch_dim_size;
ctx->SetOutputDim("Out", out_dims);
}
};
class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequencePadOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor, default LoDTensor<float>) Input variable which "
"should contain lod information. Length of each sequence would "
"be computed from the most bottom level lod.");
AddOutput("Out",
"(Tensor) Output variable which would be a common tensor "
"without lod. Each sequence would be padded to the maximum "
"length.");
AddAttr<float>("lod_level",
"(int, default 0) Specify which level lod to referred to.");
AddAttr<float>("pad_value",
"(float, default 0.0) Specify which value to be padded to "
"the end of each sequence.");
"should contain lod information.");
AddInput("PadValue",
"(LoDTensor), this Tensor holds values that will be fill into "
"padded steps. It can be a scalar or a tensor whose shape equals "
"to time steps in sequences. If it's a scalar, it will be "
"automatically broadcasted to the shape of time step.");
AddOutput(
"Out",
"(LoDTensor) The output vairable, which contains padded sequences.");
AddAttr<int>(
"padded_length",
"The length of padded sequences. It can be setted to -1 or "
"any positive int. When it is -1, all sequences will be padded up to "
"the length of the longest one among them; when it a certain positive "
"value, it must be greater than the length of the longest original "
"sequence.")
.SetDefault(-1);
AddComment(R"DOC(
)DOC");
......
......@@ -17,7 +17,13 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sequence_pad,
ops::SequencePadOpKernel<paddle::platform::CUDADeviceContext, float>);
ops::SequencePadOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequencePadOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequencePadOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequencePadOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
sequence_pad_grad,
ops::SequencePadGradOpKernel<paddle::platform::CUDADeviceContext, float>);
ops::SequencePadGradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequencePadGradOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequencePadGradOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequencePadGradOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -24,68 +26,24 @@ namespace operators {
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
template <typename DeviceContext, typename T>
struct CopyFunctor {
LoDTensor* lod_tensor_;
LoDTensor* pad_tensor_;
const LoD& ref_lod_;
const DeviceContext& ctx_;
bool is_lod_to_pad_;
CopyFunctor(LoDTensor* lod_tensor, const LoD& ref_lod, LoDTensor* pad_tensor,
const DeviceContext& ctx, bool is_lod_to_pad)
: lod_tensor_(lod_tensor),
pad_tensor_(pad_tensor),
ref_lod_(ref_lod),
ctx_(ctx),
is_lod_to_pad_(is_lod_to_pad) {}
void operator()() const {
/*
auto seq_num = ref_lod_.size() - 1;
auto max_len = pad_tensor_->dims()[0] / seq_num;
PADDLE_ENFORCE_EQ(max_len * seq_num, pad_tensor_->dims()[0],
"First dimension of padded tensor should be equal to "
"maximum sequence length mulplied by sequence number.");
for (size_t i = 1; i < ref_lod_.size(); ++i) {
auto seq_start = ref_lod_[i - 1];
auto seq_end = ref_lod_[i];
auto pad_start = (i - 1) * max_len;
auto pad_end = pad_start + (seq_end - seq_start);
auto sub_lod_tensor = lod_tensor_->Slice(seq_start, seq_end);
auto sub_pad_tensor = pad_tensor_->Slice(pad_start, pad_end);
if (is_lod_to_pad_) {
framework::TensorCopy(sub_lod_tensor, ctx.GetPlace(), &sub_pad_tensor);
} else {
framework::TensorCopy(sub_pad_tensor, ctx.GetPlace(), &sub_lod_tensor);
}
}
*/
}
};
template <typename DeviceContext, typename T>
class SequencePadOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
/*
auto* x = ctx.Input<LoDTensor>("X");
auto* out_ptr = ctx.Output<LoDTensor>("Out");
out_ptr->mutable_data<T>(ctx.GetPlace());
const auto* x = ctx.Input<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
// Resize();
const auto* pad_value = ctx.Input<LoDTensor>("PadValue");
const T* pad_value_data = pad_value->data<T>();
std::vector<T> pad_value_vec(pad_value_data,
pad_value_data + pad_value->numel());
T pad_value = static_cast<T>(ctx.Attr<float>("pad_value"));
int padded_length = ctx.Attr<int>("padded_length");
math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *x, *, false);
math::SetConstant<DeviceContext, T> set_func;
set_func(ctx.template device_context<DeviceContext>(), out_ptr, pad_value);
*/
ctx.template device_context<DeviceContext>(), *x, out, pad_value_vec,
padded_length, 0, false, math::kBatchLengthWidth);
}
};
......@@ -93,26 +51,17 @@ template <typename DeviceContext, typename T>
class SequencePadGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
/*
auto* x_ptr = ctx.Input<LoDTensor>("X");
auto* g_out_ptr = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
auto* g_x_ptr = ctx.Output<LoDTensor>(framework::GradVarName("X"));
math::SetConstant<DeviceContext, T> set_func;
set_func(ctx.template device_context<DeviceContext>(),
g_x_ptr,
static_cast<T>(0));
auto* d_x = ctx.Output<LoDTensor>(framework::GradVarName("X"));
if (d_x) {
const auto* d_out = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
d_x->mutable_data<T>(ctx.GetPlace());
auto& x_lod = x_ptr->lod();
auto& x_last_level_lod = x_lod[x_lod.size() - 1];
int padded_length = ctx.Attr<int>("padded_length");
CopyFunctor copy_func<DeviceContext, T>(g_out_ptr,
x_last_level_lod,
g_x_ptr,
ctx,
false);
copy_func();
*/
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *d_out, d_x,
padded_length, 0, false, math::kBatchLengthWidth);
}
}
};
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestSequencePadOp(OpTest):
def set_attr(self):
self.x_shape = [12, 4]
self.x_len_lod = [[2, 3, 4, 3]]
self.pad_value = [1.0]
self.padded_length = -1
self.dtype = 'float32'
def set_data(self):
x_data = np.random.uniform(0.1, 0.5, self.x_shape).astype(self.dtype)
pad_value_data = np.array(self.pad_value).astype(self.dtype)
self.inputs = {
'X': (x_data, self.x_len_lod),
'PadValue': pad_value_data
}
self.attrs = {'padded_length': self.padded_length}
def compute(self):
# get padded length
padded_length = self.padded_length
x_len_lod_0 = self.x_len_lod[0]
if padded_length == -1:
max_seq_len = 0
for l in x_len_lod_0:
max_seq_len = max(max_seq_len, l)
padded_length = max_seq_len
# do padding
x_data = self.inputs['X'][0]
pad_value_data = self.inputs['PadValue']
if pad_value_data.shape == (1, ):
pad_value_data = np.broadcast_to(
pad_value_data, shape=x_data.shape[1:])
padded_sequences = []
start_idx = 0
for l in x_len_lod_0:
end_idx = start_idx + l
seq = x_data[start_idx:end_idx]
to_pad_len = padded_length - l
for _ in range(to_pad_len):
seq = np.append(seq, pad_value_data[np.newaxis, :], axis=0)
padded_sequences.append(seq)
start_idx = end_idx
out_len_lod = self.x_len_lod[:]
out_len_lod_0 = [padded_length] * len(x_len_lod_0)
out_len_lod[0] = out_len_lod_0
out_data = np.concatenate(padded_sequences, axis=0)
self.outputs = {'Out': (out_data, out_len_lod)}
def setUp(self):
self.op_type = 'sequence_pad'
self.set_attr()
self.set_data()
self.compute()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestSequencePadOp2(TestSequencePadOp):
def set_attr(self):
self.x_shape = [12, 4]
self.x_len_lod = [[2, 3, 4, 3]]
self.pad_value = [1.0, 2.0, 3.0, 4.0]
self.padded_length = -1
self.dtype = 'float32'
class TestSequencePadOp3(TestSequencePadOp):
def set_attr(self):
self.x_shape = [12, 4]
self.x_len_lod = [[2, 3, 4, 3]]
self.pad_value = [1.0]
self.padded_length = 7
self.dtype = 'float32'
class TestSequencePadOp4(TestSequencePadOp):
def set_attr(self):
self.x_shape = [12, 4]
self.x_len_lod = [[2, 3, 4, 3]]
self.pad_value = [1.0, 2.0, 3.0, 4.0]
self.padded_length = 7
self.dtype = 'float32'
class TestSequencePadOp5(TestSequencePadOp):
def set_attr(self):
self.x_shape = [12, 2, 2]
self.x_len_lod = [[2, 3, 4, 3]]
self.pad_value = [1.0]
self.padded_length = -1
self.dtype = 'float32'
class TestSequencePadOp6(TestSequencePadOp):
def set_attr(self):
self.x_shape = [12, 2, 2]
self.x_len_lod = [[2, 3, 4, 3]]
self.pad_value = [[1.0, 2.0], [3.0, 4.0]]
self.padded_length = -1
self.dtype = 'float32'
class TestSequencePadOp7(TestSequencePadOp):
def set_attr(self):
self.x_shape = [12, 2, 2]
self.x_len_lod = [[2, 3, 4, 3]]
self.pad_value = [1.0]
self.padded_length = 7
self.dtype = 'float32'
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册