Commit 23b0388f authored by wanghaox

add sub sequence operator code and unittest

Parent ce08645d
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sub_sequence_op.h"
namespace paddle {
namespace operators {
class SubSequenceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SubSequenceOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SubSequenceOp should not be null.");
auto input_dims = ctx->GetInputDim("X");
auto offsets = ctx->Attrs().Get<std::vector<int>>("offset");
auto sizes = ctx->Attrs().Get<std::vector<int>>("size");
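// The first dimension of Out is the total number of selected rows,
// i.e. the sum of all sub-sequence sizes; the remaining dimensions
// are the same as the input's.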
auto dim_0 = 0;
for (size_t i = 0; i < sizes.size(); ++i) {
dim_0 += sizes[i];
}
framework::DDim out_dims = input_dims;
out_dims[0] = dim_0;
ctx->SetOutputDim("Out", out_dims);
}
};
class SubSequenceGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SubSequenceGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"The gradient of Out should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"The gradient of X should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
class SubSequenceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SubSequenceOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor), "
"the variable-length input of SubSequenceOp");
AddAttr<std::vector<int>>(
"offset",
"A list<int> to describes offset for sub sequence item.");
AddAttr<std::vector<int>>(
"size",
"A list<int> to describes size for sub sequence item.");
AddOutput("Out",
"(Tensor), Variable-length output of "
"sequence_concat Op.");
AddComment(R"DOC(
Sub Sequence operator
The operator crops a sub-sequence from a given sequence with the given start offset and sub-sequence size.
It only supports sequences (i.e. LoDTensors with LoD level 1).
- Case:
LoD(x) = {{0, 3, 6, 10}}; Dims(x) = (10, 3, 2)
offset = (0, 1, 1); size = (2, 1, 2)
LoD(Out) = {{0, 2, 3, 5}}; Dims(Out) = (5, 3, 2)
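Here Out keeps rows [0, 2) of the first sequence, rows [4, 5) of the
second, and rows [7, 9) of the third, giving 5 rows in total.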
NOTE: The number of sequences in the input and the lengths of the offset and size lists must all be the same. Offsets start from 0.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sub_sequence, ops::SubSequenceOp, ops::SubSequenceOpMaker,
sub_sequence_grad, ops::SubSequenceGradOp);
REGISTER_OP_CPU_KERNEL(
sub_sequence,
ops::SubSequenceOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sub_sequence_grad,
ops::SubSequenceGradOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/sub_sequence_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
sub_sequence,
ops::SubSequenceOpKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sub_sequence_grad,
ops::SubSequenceGradOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/strided_memcpy.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
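// Builds the LoD of the output: output sequence i contains sizes[i]
// rows, so its offsets are the running sum of the sizes.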
template <typename T>
LoD subsequenceLoD(const T* in, const std::vector<int>& offsets,
const std::vector<int>& sizes) {
auto out_lod = in->lod();
size_t lod_offset = 0;
auto n = in->lod()[0].size() - 1;
out_lod[0][0] = 0;
for (size_t i = 0; i < n; ++i) {
lod_offset += sizes[i];
out_lod[0][i + 1] = lod_offset;
}
return out_lod;
}
template <typename Place, typename T>
class SubSequenceOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<LoDTensor>("X");
std::vector<int> offsets = ctx.Attr<std::vector<int>>("offset");
std::vector<int> sizes = ctx.Attr<std::vector<int>>("size");
auto* out = ctx.Output<LoDTensor>("Out");
auto offset_len = offsets.size();
auto size_len = sizes.size();
auto lod = in->lod();
auto n = lod[0].size() - 1;
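// Check that the input is a one-level sequence and that each
// requested sub-sequence lies inside its source sequence.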
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_EQ(n, offset_len,
"The length of input and offset should be the same");
PADDLE_ENFORCE_EQ(n, size_len,
"The length of input and size should be the same");
for (size_t i = 0; i < n; ++i) {
auto offset = offsets[i];
auto size = sizes[i];
PADDLE_ENFORCE_LE(lod[0][i] + offset + size, lod[0][i + 1],
"The sub-sequence must not extend past the end of its source sequence");
}
out->mutable_data<T>(ctx.GetPlace());
auto out_lod = subsequenceLoD(in, offsets, sizes);
out->set_lod(out_lod);
auto in_stride = framework::stride(in->dims());
auto out_stride = framework::stride(out->dims());
size_t out_offset = 0;
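// Copy the selected sub-sequences one by one, packing them
// contiguously along dimension 0 of the output.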
for (size_t i = 0; i < n; ++i) {
auto offset = offsets[i];
auto size = sizes[i];
Tensor in_t = in->Slice(static_cast<int>(lod[0][i] + offset),
static_cast<int>(lod[0][i] + offset + size));
StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(),
in_stride, in_t.dims(), out_stride,
out->data<T>() + out_offset);
out_offset += size * in_stride[0];
}
}
};
template <typename Place, typename T>
class SubSequenceGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<LoDTensor>("X");
std::vector<int> offsets = ctx.Attr<std::vector<int>>("offset");
std::vector<int> sizes = ctx.Attr<std::vector<int>>("size");
auto* out_grad =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto* x_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto offset_len = offsets.size();
auto size_len = sizes.size();
auto lod = in->lod();
auto n = lod[0].size() - 1;
// check input data format
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_EQ(n, offset_len,
"The length of input and offset should be the same");
PADDLE_ENFORCE_EQ(n, size_len,
"The length of input and size should be the same");
for (size_t i = 0; i < n; ++i) {
auto offset = offsets[i];
auto size = sizes[i];
PADDLE_ENFORCE_LE(lod[0][i] + offset + size, lod[0][i + 1],
"The sub-sequence must not extend past the end of its source sequence");
}
auto out_lod = subsequenceLoD(in, offsets, sizes);
x_grad->set_lod(lod);
x_grad->mutable_data<T>(ctx.GetPlace());
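// Zero the input gradient first; rows outside the selected
// sub-sequences receive no gradient from Out.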
auto temp = framework::EigenVector<T>::Flatten(*x_grad);
temp.device(ctx.GetEigenDevice<Place>()) = temp.constant(static_cast<T>(0));
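// Scatter each slice of the output gradient back to the rows it
// was copied from in the forward pass.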
for (size_t i = 0; i < out_lod[0].size() - 1; ++i) {
Tensor out_grad_t =
out_grad->Slice(static_cast<int>(out_lod[0][i]),
static_cast<int>(out_lod[0][i + 1]));
auto out_grad_stride = framework::stride(out_grad_t.dims());
auto x_grad_stride = framework::stride(x_grad->dims());
auto offset = offsets[i];
auto size = sizes[i];
Tensor x_grad_t = x_grad->Slice(static_cast<int>(lod[0][i] + offset),
static_cast<int>(lod[0][i] + offset + size));
StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>(),
out_grad_stride, out_grad_t.dims(), x_grad_stride,
x_grad_t.data<T>());
}
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
import sys
from op_test import OpTest
class TestSubSequenceOp(OpTest):
def set_data(self):
# only supports one-level LoD
x = np.random.random((100, 3, 2)).astype('float32')
lod = [[0, 20, 40, 60, 80, 100]]
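# five sequences of 20 rows each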
offsets = np.array([1, 2, 3, 4, 5]).flatten()
sizes = np.array([10, 8, 6, 4, 2]).flatten()
self.inputs = {'X': (x, lod)}
self.attrs = {'offset': offsets, 'size': sizes}
outs = []
out_lod = [[0]]
out_lod_offset = 0
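# build the expected output by slicing each sequence at its offset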
for i in range(len(offsets)):
sub_x = x[lod[0][i] + offsets[i]:lod[0][i] + offsets[i] + sizes[i], :]
outs.append(sub_x)
out_lod_offset = out_lod_offset + len(sub_x)
out_lod[0].append(out_lod_offset)
outs = np.concatenate(outs, axis=0)
self.outputs = {'Out': (outs, out_lod)}
def setUp(self):
self.op_type = "sub_sequence"
self.set_data()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()