Commit b24afd81 authored by wanghaox

update the sub_sequence_op to sequence_slice_op code.

Parent f23d6cc4
--- a/paddle/operators/sub_sequence_op.cc
+++ b/paddle/operators/sequence_slice_op.cc
@@ -12,37 +12,39 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/sub_sequence_op.h"
+#include "paddle/operators/sequence_slice_op.h"
 
 namespace paddle {
 namespace operators {
 
-class SubSequenceOp : public framework::OperatorWithKernel {
+class SequenceSliceOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of SubSequenceOp should not be null.");
+                   "Input(X) of SequenceSliceOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Offset"),
+                   "Input(Offset) of SequenceSliceOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Length"),
+                   "Input(Length) of SequenceSliceOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of SubSequenceOp should not be null.");
+                   "Output(Out) of SequenceSliceOp should not be null.");
     auto input_dims = ctx->GetInputDim("X");
-    auto offsets = ctx->Attrs().Get<std::vector<int>>("offset");
-    auto sizes = ctx->Attrs().Get<std::vector<int>>("size");
-    auto dim_0 = 0;
-    for (size_t i = 0; i < sizes.size(); ++i) {
-      dim_0 += sizes[i];
-    }
-    framework::DDim out_dims = input_dims;
-    out_dims[0] = dim_0;
-    ctx->SetOutputDim("Out", out_dims);
+    ctx->SetOutputDim("Out", input_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.device_context());
   }
 };
 
-class SubSequenceGradOp : public framework::OperatorWithKernel {
+class SequenceSliceGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -53,34 +55,50 @@
                    "The gradient of X should not be null.");
     ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
   }
+
+ protected:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.device_context());
+  }
 };
 
-class SubSequenceOpMaker : public framework::OpProtoAndCheckerMaker {
+class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SubSequenceOpMaker(framework::OpProto* proto,
-                     framework::OpAttrChecker* op_checker)
+  SequenceSliceOpMaker(framework::OpProto* proto,
+                       framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "(LoDTensor), "
-             "the variable-length input of SubSequenceOp");
-    AddAttr<std::vector<int>>(
-        "offset",
-        "A list<int> to describes offset for sub sequence item.");
-    AddAttr<std::vector<int>>(
-        "size",
-        "A list<int> to describes size for sub sequence item.");
+    AddInput("X",
+             "(LoDTensor), "
+             "the input of SequenceSliceOp.");
+    AddInput("Offset",
+             "(Tensor), "
+             "a vector<int> that describes the offset of each sub-sequence.");
+    AddInput("Length",
+             "(Tensor), "
+             "a vector<int> that describes the length of each sub-sequence.");
     AddOutput("Out",
-              "(Tensor), Variable-length output of "
-              "sequence_concat Op.");
+              "(LoDTensor), the output of SequenceSliceOp.");
     AddComment(R"DOC(
-Sub Sequence operator
+Sequence slice operator
 
-The operator crop a subsequence from given sequence with given start offset and subsequence size.
+This operator crops a subsequence from a given sequence, using a per-sequence start offset and subsequence length.
 It only supports sequences (LoD tensors with one LoD level).
 
 - Case:
-    LoD(x) = {{0, 3, 6, 10}}; Dims(x0) = (10, 3, 2)
-    offset = (0, 1, 1); size = (2, 1, 2)
-    LoD(Out) = {{0, 2, 3, 5}}; Dims(Out) = (5,3,2)
+    X = [[a1, a2;
+          b1, b2;
+          c1, c2]
+         [d1, d2;
+          e1, e2]]
+    LoD(X) = {{0, 3, 5}}; Dims(X) = (5, 1, 2)
+    Offset = (0, 1); Length = (2, 1)
+
+    Out = [[a1, a2;
+            b1, b2]
+           [e1, e2]]
+    LoD(Out) = {{0, 2, 3}}
 
-NOTE: The length of the input, offset and size should be the same. The offset start from 0.
+NOTE: The number of sequences in X and the sizes of the Offset and Length arrays must be the same. Offsets start from 0.
 )DOC");
   }
 };
@@ -89,11 +107,11 @@
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(sub_sequence, ops::SubSequenceOp, ops::SubSequenceOpMaker,
-            sub_sequence_grad, ops::SubSequenceGradOp);
+REGISTER_OP(sequence_slice, ops::SequenceSliceOp, ops::SequenceSliceOpMaker,
+            sequence_slice_grad, ops::SequenceSliceGradOp);
 REGISTER_OP_CPU_KERNEL(
-    sub_sequence,
-    ops::SubSequenceOpKernel<paddle::platform::CPUPlace, float>);
+    sequence_slice,
+    ops::SequenceSliceOpKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
-    sub_sequence_grad,
-    ops::SubSequenceGradOpKernel<paddle::platform::CPUPlace, float>);
+    sequence_slice_grad,
+    ops::SequenceSliceGradOpKernel<paddle::platform::CPUPlace, float>);
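
To make the DOC case above concrete, here is a minimal NumPy sketch of the forward semantics (an illustration under assumed names and shapes, not the Paddle kernel): sequence i contributes rows [lod[i] + offset[i], lod[i] + offset[i] + length[i]) of X to Out, and the output LoD is the running sum of the lengths. The sketch returns a compact result; the kernel above instead writes into a zero-filled tensor that keeps X's full shape.

    import numpy as np

    # Illustrative sketch of sequence_slice forward semantics, not the Paddle kernel.
    def sequence_slice(x, lod0, offset, length):
        outs, out_lod0 = [], [0]
        for i in range(len(lod0) - 1):
            start = lod0[i] + offset[i]            # offset is sequence-local
            outs.append(x[start: start + length[i]])
            out_lod0.append(out_lod0[-1] + length[i])
        return np.concatenate(outs, axis=0), [out_lod0]

    # The Case from the DOC comment: 5 rows a..e, LoD {{0, 3, 5}}.
    x = np.array([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]], dtype='float32')
    out, out_lod = sequence_slice(x, [0, 3, 5], [0, 1], [2, 1])
    print(out)      # rows a, b and e -> [[1,1],[2,2],[5,5]]
    print(out_lod)  # [[0, 2, 3]]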
--- a/paddle/operators/sub_sequence_op.cu
+++ b/paddle/operators/sequence_slice_op.cu
@@ -12,14 +12,12 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#define EIGEN_USE_GPU
-#include "paddle/operators/sub_sequence_op.h"
+#include "paddle/operators/sequence_slice_op.h"
 
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(
-    sub_sequence,
-    ops::SubSequenceOpKernel<paddle::platform::GPUPlace, float>);
+    sequence_slice,
+    ops::SequenceSliceOpKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
-    sub_sequence_grad,
-    ops::SubSequenceGradOpKernel<paddle::platform::GPUPlace, float>);
+    sequence_slice_grad,
+    ops::SequenceSliceGradOpKernel<paddle::platform::GPUPlace, float>);
--- a/paddle/operators/sub_sequence_op.h
+++ b/paddle/operators/sequence_slice_op.h
@@ -13,8 +13,8 @@
 limitations under the License. */
 
 #pragma once
-#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
 #include "paddle/operators/strided_memcpy.h"
 
 namespace paddle {
@@ -25,109 +25,124 @@
 using LoD = framework::LoD;
 
 template <typename T>
-LoD subsequenceLoD(const T* in, const std::vector<int> offsets,
-                   const std::vector<int> sizes) {
-  auto out_lod = in->lod();
+LoD SequenceSliceLoD(const T& in, const int64_t* offset_data,
+                     const int64_t* length_data) {
+  auto out_lod = in.lod();
   size_t lod_offset = 0;
-  auto n = in->lod()[0].size() - 1;
+  auto n = in.lod()[0].size() - 1;
   out_lod[0][0] = 0;
   for (size_t i = 0; i < n; ++i) {
-    lod_offset += sizes[i];
+    lod_offset += length_data[i];
     out_lod[0][i + 1] = lod_offset;
   }
   return out_lod;
 }
 
 template <typename Place, typename T>
-class SubSequenceOpKernel : public framework::OpKernel<T> {
+class SequenceSliceOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* in = ctx.Input<LoDTensor>("X");
-    std::vector<int> offsets = ctx.Attr<std::vector<int>>("offset");
-    std::vector<int> sizes = ctx.Attr<std::vector<int>>("size");
+    auto* offset = ctx.Input<Tensor>("Offset");
+    auto* length = ctx.Input<Tensor>("Length");
     auto* out = ctx.Output<LoDTensor>("Out");
-    auto offset_len = offsets.size();
-    auto size_len = sizes.size();
+    const int64_t* offset_data = offset->data<int64_t>();
+    const int64_t* length_data = length->data<int64_t>();
+
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      framework::Tensor offset_cpu;
+      offset_cpu.mutable_data<int64_t>(offset->dims(), platform::CPUPlace());
+      offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
+      offset_data = offset_cpu.data<int64_t>();
+
+      framework::Tensor length_cpu;
+      length_cpu.mutable_data<int64_t>(length->dims(), platform::CPUPlace());
+      length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
+      length_data = length_cpu.data<int64_t>();
+    }
 
     auto lod = in->lod();
     auto n = lod[0].size() - 1;
 
     PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
-    PADDLE_ENFORCE_EQ(n, offset_len,
-                      "The length of input and offset should be the same")
-    PADDLE_ENFORCE_EQ(n, size_len,
-                      "The length of input and size should be the same")
+    PADDLE_ENFORCE_EQ(offset->dims().size(), 1UL,
+                      "Offset must be a 1-D tensor.");
+    PADDLE_ENFORCE_EQ(length->dims().size(), 1UL,
+                      "Length must be a 1-D tensor.");
+    PADDLE_ENFORCE_EQ(
+        n, length->dims()[0],
+        "The size of the input sequence and the length array must match.")
+    PADDLE_ENFORCE_EQ(
+        n, offset->dims()[0],
+        "The size of the input sequence and the offset array must match.")
 
     for (size_t i = 0; i < n; ++i) {
-      auto offset = offsets[i];
-      auto size = sizes[i];
-      PADDLE_ENFORCE_LT(lod[0][i] + offset + size, lod[0][i + 1],
-                        "The target tensor's length overflow")
+      PADDLE_ENFORCE_LE(0, offset_data[i], "The offset must be non-negative.")
+      PADDLE_ENFORCE_LT(0, length_data[i],
+                        "The length must be greater than zero.")
+      PADDLE_ENFORCE_LE(lod[0][i] + offset_data[i] + length_data[i],
+                        lod[0][i + 1], "The target tensor's length overflows.")
     }
 
     out->mutable_data<T>(ctx.GetPlace());
-    auto out_lod = subsequenceLoD(in, offsets, sizes);
+    auto out_lod = SequenceSliceLoD(*in, offset_data, length_data);
     out->set_lod(out_lod);
+    math::SetConstant<Place, T> set_zero;
+    set_zero(ctx.device_context(), out, static_cast<T>(0));
 
     auto in_stride = framework::stride(in->dims());
     auto out_stride = framework::stride(out->dims());
 
     size_t out_offset = 0;
     for (size_t i = 0; i < n; ++i) {
-      auto offset = offsets[i];
-      auto size = sizes[i];
-
-      Tensor in_t = in->Slice(static_cast<int>(lod[0][i] + offset),
-                              static_cast<int>(lod[0][i] + offset + size));
+      Tensor in_t =
+          in->Slice(static_cast<int>(lod[0][i] + offset_data[i]),
+                    static_cast<int>(lod[0][i] + offset_data[i] +
+                                     length_data[i]));
 
       StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(),
                        in_stride, in_t.dims(), out_stride,
                        out->data<T>() + out_offset);
-      out_offset += size * in_stride[0];
+      out_offset += length_data[i] * in_stride[0];
     }
   }
 };
 
 template <typename Place, typename T>
-class SubSequenceGradOpKernel : public framework::OpKernel<T> {
+class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* in = ctx.Input<LoDTensor>("X");
-    std::vector<int> offsets = ctx.Attr<std::vector<int>>("offset");
-    std::vector<int> sizes = ctx.Attr<std::vector<int>>("size");
+    auto* offset = ctx.Input<Tensor>("Offset");
+    auto* length = ctx.Input<Tensor>("Length");
     auto* out_grad =
         ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
     auto* x_grad =
         ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
-    auto offset_len = offsets.size();
-    auto size_len = sizes.size();
-
-    auto lod = in->lod();
-    auto n = lod[0].size() - 1;
-
-    // check input data format
-    PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
-    PADDLE_ENFORCE_EQ(n, offset_len,
-                      "The length of input and offset should be the same")
-    PADDLE_ENFORCE_EQ(n, size_len,
-                      "The length of input and size should be the same")
-
-    for (size_t i = 0; i < n; ++i) {
-      auto offset = offsets[i];
-      auto size = sizes[i];
-      PADDLE_ENFORCE_LT(lod[0][i] + offset + size, lod[0][i + 1],
-                        "The target tensor's length overflow")
-    }
-
-    auto out_lod = subsequenceLoD(in, offsets, sizes);
+    const int64_t* offset_data = offset->data<int64_t>();
+    const int64_t* length_data = length->data<int64_t>();
+
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      framework::Tensor offset_cpu;
+      offset_cpu.mutable_data<int64_t>(offset->dims(), platform::CPUPlace());
+      offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
+      offset_data = offset_cpu.data<int64_t>();
+
+      framework::Tensor length_cpu;
+      length_cpu.mutable_data<int64_t>(length->dims(), platform::CPUPlace());
+      length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
+      length_data = length_cpu.data<int64_t>();
+    }
+
+    auto lod = in->lod();
+    auto out_lod = SequenceSliceLoD(*in, offset_data, length_data);
 
     x_grad->set_lod(lod);
     x_grad->mutable_data<T>(ctx.GetPlace());
-    auto temp = framework::EigenVector<T>::Flatten(*x_grad);
-    temp.device(ctx.GetEigenDevice<Place>()) =
-        temp.constant(static_cast<T>(0));
+    math::SetConstant<Place, T> set_zero;
+    set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
 
     auto out_grad_stride = framework::stride(out_grad->dims());
 
@@ -139,11 +154,9 @@
       auto x_grad_stride = framework::stride(x_grad->dims());
 
-      auto offset = offsets[i];
-      auto size = sizes[i];
-
-      Tensor x_grad_t = x_grad->Slice(
-          static_cast<int>(lod[0][i] + offset),
-          static_cast<int>(lod[0][i] + offset + size));
+      Tensor x_grad_t = x_grad->Slice(
+          static_cast<int>(lod[0][i] + offset_data[i]),
+          static_cast<int>(lod[0][i] + offset_data[i] + length_data[i]));
 
       StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>(),
                        out_grad_stride, out_grad_t.dims(), x_grad_stride,
......
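
Note that SequenceSliceLoD only consults the per-sequence lengths: the output LoD is a prefix sum of Length, and Offset never enters the computation. A hedged Python equivalent of that helper, assuming a level-1 LoD:

    # Python sketch of SequenceSliceLoD (illustrative, not the Paddle helper):
    # the output LoD level 0 is the running sum of the requested lengths.
    def sequence_slice_lod(in_lod, lengths):
        out_lod0 = [0] * len(in_lod[0])       # same entry count as the input LoD
        for i, l in enumerate(lengths):
            out_lod0[i + 1] = out_lod0[i] + l  # prefix sum of lengths
        return [out_lod0]

    # Matches the DOC case: lengths (2, 1) give LoD(Out) = [[0, 2, 3]]
    assert sequence_slice_lod([[0, 3, 5]], [2, 1]) == [[0, 2, 3]]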
--- a/.../test_sub_sequence_op.py
+++ b/.../test_sequence_slice_op.py
@@ -3,31 +3,29 @@
 import numpy as np
 import sys
 from op_test import OpTest
 
-class TestSubSequenceOp(OpTest):
+class TestSequenceSliceOp(OpTest):
     def set_data(self):
         # only support one level of LoD
         x = np.random.random((100, 3, 2)).astype('float32')
         lod = [[0, 20, 40, 60, 80, 100]]
-        offsets = np.array([1, 2, 3, 4, 5]).flatten()
-        sizes = np.array([10, 8, 6, 4, 2]).flatten()
-        self.inputs = {'X': (x, lod)}
-        self.attrs = {'offset': offsets, 'size': sizes}
+        offset = np.array([1, 2, 3, 4, 5]).flatten().astype("int64")
+        length = np.array([10, 8, 6, 4, 2]).flatten().astype("int64")
+        self.inputs = {'X': (x, lod), 'Offset': offset, 'Length': length}
 
-        outs = []
+        outs = np.zeros((100, 3, 2)).astype('float32')
         out_lod = [[0]]
         out_lod_offset = 0
-        for i in range(len(offsets)):
-            sub_x = x[lod[0][i] + offsets[i]: lod[0]
-                      [i] + offsets[i] + sizes[i], :]
-            outs.append(sub_x)
+        for i in range(len(offset)):
+            sub_x = x[lod[0][i] + offset[i]: lod[0]
+                      [i] + offset[i] + length[i], :]
             out_lod_offset = out_lod_offset + len(sub_x)
+            outs[out_lod[0][i]: out_lod_offset, :] = sub_x
             out_lod[0].append(out_lod_offset)
 
-        outs = np.concatenate(outs, axis=0)
-        self.outputs = {'Out': outs}
+        self.outputs = {'Out': (outs, out_lod)}
 
     def setUp(self):
-        self.op_type = "sub_sequence"
+        self.op_type = "sequence_slice"
         self.set_data()
 
     def test_check_output(self):
......
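
One detail of the updated test worth spelling out: because InferShape keeps Dims(Out) equal to Dims(X) and the kernel zero-fills Out before copying, the expected output is a zero array of the full input shape whose leading rows hold the concatenated slices. A small standalone sketch of that construction (toy shapes and values, assumed for illustration):

    import numpy as np

    # Build the reference output the same way the test above does, on a toy case.
    x = np.arange(20, dtype='float32').reshape(10, 2)
    lod0 = [0, 4, 10]                      # two sequences: rows 0-3 and 4-9
    offset = np.array([1, 2], dtype='int64')
    length = np.array([2, 3], dtype='int64')

    outs = np.zeros_like(x)                # Out keeps the full input shape, zero-filled
    out_lod0 = [0]
    for i in range(len(offset)):
        start = lod0[i] + offset[i]
        sub_x = x[start: start + length[i]]
        outs[out_lod0[-1]: out_lod0[-1] + len(sub_x)] = sub_x
        out_lod0.append(out_lod0[-1] + len(sub_x))

    print(out_lod0)  # [0, 2, 5]: rows 0-4 hold the slices, rows 5-9 stay zero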