diff --git a/paddle/operators/sequence_slice_op.cc b/paddle/operators/sequence_slice_op.cc
new file mode 100755
index 0000000000000000000000000000000000000000..cbe0b4233160dd1f3ebdf6db8b5f6df392efdfe7
--- /dev/null
+++ b/paddle/operators/sequence_slice_op.cc
@@ -0,0 +1,132 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sequence_slice_op.h"
+
+namespace paddle {
+namespace operators {
+
+class SequenceSliceOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceSliceOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Offset"),
+                   "Input(Offset) of SequenceSliceOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Length"),
+                   "Input(Length) of SequenceSliceOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SequenceSliceOp should not be null.");
+    auto input_dims = ctx->GetInputDim("X");
+
+    auto offset_dim = ctx->GetInputDim("Offset");
+    auto length_dim = ctx->GetInputDim("Length");
+
+    PADDLE_ENFORCE_EQ(
+        offset_dim.size(), 2UL,
+        "Only one-level sequences are supported; the rank of Offset must be 2.");
+    PADDLE_ENFORCE_EQ(
+        length_dim.size(), 2UL,
+        "Only one-level sequences are supported; the rank of Length must be 2.");
+
+    // Initialize the output's dims to the maximum possible size here; the
+    // kernel resets them to the real dims once the values of Offset and
+    // Length are known.
+    ctx->SetOutputDim("Out", input_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.device_context());
+  }
+};
+
+class SequenceSliceGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "The gradient of Out should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
+                   "The gradient of X should not be null.");
+    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
+  }
+
+ protected:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.device_context());
+  }
+};
+
+class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SequenceSliceOpMaker(framework::OpProto* proto,
+                       framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X",
+             "(LoDTensor), "
+             "the input of SequenceSliceOp.");
+    AddInput("Offset",
+             "(Tensor), "
+             "a vector holding the start offset of the slice within each "
+             "input sequence.");
+    AddInput("Length",
+             "(Tensor), "
+             "a vector holding the length of the slice taken from each "
+             "input sequence.");
+    AddOutput("Out", "(LoDTensor), the output of SequenceSliceOp.");
+    AddComment(R"DOC(
+Sequence slice operator
+
+The operator crops a subsequence out of each input sequence, given a
+per-sequence start offset and subsequence length.
+It only supports one-level sequences (a LoDTensor with LoD level 1).
+- Case:
+    X = [[a1, a2;
+          b1, b2;
+          c1, c2]
+         [d1, d2;
+          e1, e2]]
+    LoD(X) = {{0, 3, 5}}; Dims(X) = (5, 2)
+    Offset = [[0], [1]]; Length = [[2], [1]]
+
+    Out = [[a1, a2;
+            b1, b2]
+           [e1, e2]]
+    LoD(Out) = {{0, 2, 3}}; Dims(Out) = (3, 2)
+NOTE: The first dimension of Offset and of Length must both equal the number
+of sequences in the input. Offsets start from 0.
+    )DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(sequence_slice, ops::SequenceSliceOp, ops::SequenceSliceOpMaker,
+            sequence_slice_grad, ops::SequenceSliceGradOp);
+REGISTER_OP_CPU_KERNEL(
+    sequence_slice,
+    ops::SequenceSliceOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    sequence_slice_grad,
+    ops::SequenceSliceGradOpKernel<paddle::platform::CPUPlace, float>);
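To make the Offset/Length semantics in the DOC block above concrete, here is a minimal numpy sketch of the forward computation using the same example data; the array values and variable names are illustrative only, not part of the op.

import numpy as np

# The DOC example: X holds two sequences stored back to back, row-wise.
x = np.array([[1., 1.],   # a
              [2., 2.],   # b
              [3., 3.],   # c
              [4., 4.],   # d
              [5., 5.]])  # e
lod = [0, 3, 5]   # sequence i occupies rows [lod[i], lod[i + 1])
offset = [0, 1]   # slice start within each sequence
length = [2, 1]   # slice length for each sequence

outs, out_lod = [], [0]
for i in range(len(lod) - 1):
    start = lod[i] + offset[i]
    outs.append(x[start:start + length[i]])
    out_lod.append(out_lod[-1] + length[i])
out = np.concatenate(outs, axis=0)

print(out)      # rows a, b, e  ->  Dims(Out) = (3, 2)
print(out_lod)  # [0, 2, 3]     ->  LoD(Out) = {{0, 2, 3}}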
diff --git a/paddle/operators/sequence_slice_op.cu b/paddle/operators/sequence_slice_op.cu
new file mode 100755
index 0000000000000000000000000000000000000000..a9f59dadba74d900fa5cc0601fb5b264ea19e34d
--- /dev/null
+++ b/paddle/operators/sequence_slice_op.cu
@@ -0,0 +1,23 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sequence_slice_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(
+    sequence_slice,
+    ops::SequenceSliceOpKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    sequence_slice_grad,
+    ops::SequenceSliceGradOpKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/sequence_slice_op.h b/paddle/operators/sequence_slice_op.h
new file mode 100755
index 0000000000000000000000000000000000000000..2c9b8464a1236a054cf1a38b9dc1d73588f8dd38
--- /dev/null
+++ b/paddle/operators/sequence_slice_op.h
@@ -0,0 +1,173 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/operators/strided_memcpy.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using LoD = framework::LoD;
+
+// Build the output LoD from the per-sequence slice lengths: output
+// sequence i spans length_data[i] rows.
+template <typename T>
+inline LoD SequenceSliceLoD(const T& in, const int64_t* offset_data,
+                            const int64_t* length_data) {
+  auto out_lod = in.lod();
+  size_t lod_offset = 0;
+
+  auto n = in.lod()[0].size() - 1;
+  out_lod[0][0] = 0;
+  for (size_t i = 0; i < n; ++i) {
+    lod_offset += length_data[i];
+    out_lod[0][i + 1] = lod_offset;
+  }
+  return out_lod;
+}
+
+template <typename Place, typename T>
+class SequenceSliceOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in = ctx.Input<LoDTensor>("X");
+    auto* offset = ctx.Input<Tensor>("Offset");
+    auto* length = ctx.Input<Tensor>("Length");
+    auto* out = ctx.Output<LoDTensor>("Out");
+
+    auto lod = in->lod();
+    PADDLE_ENFORCE_EQ(lod.size(), 1UL,
+                      "Only one-level sequences are supported.");
+    auto n = lod[0].size() - 1;
+
+    PADDLE_ENFORCE_EQ(
+        n, static_cast<size_t>(length->dims()[0]),
+        "The number of input sequences and the rows of Length should be equal.");
+    PADDLE_ENFORCE_EQ(
+        n, static_cast<size_t>(offset->dims()[0]),
+        "The number of input sequences and the rows of Offset should be equal.");
+
+    const int64_t* offset_data = offset->data<int64_t>();
+    const int64_t* length_data = length->data<int64_t>();
+    framework::Tensor offset_cpu;
+    framework::Tensor length_cpu;
+
+    // Offset and Length may live on the GPU; copy them to the host before
+    // reading their values.
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      offset_cpu.mutable_data<int64_t>(offset->dims(), platform::CPUPlace());
+      offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
+      offset_data = offset_cpu.data<int64_t>();
+
+      length_cpu.mutable_data<int64_t>(length->dims(), platform::CPUPlace());
+      length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
+      length_data = length_cpu.data<int64_t>();
+    }
+
+    for (size_t i = 0; i < n; ++i) {
+      PADDLE_ENFORCE_LE(0, offset_data[i],
+                        "The offset[%d] must be nonnegative.", i);
+      PADDLE_ENFORCE_LT(0, length_data[i],
+                        "The length[%d] must be greater than zero.", i);
+      PADDLE_ENFORCE_LE(lod[0][i] + offset_data[i] + length_data[i],
+                        lod[0][i + 1],
+                        "The slice of sequence[%d] runs past its end.", i);
+    }
+
+    out->mutable_data<T>(ctx.GetPlace());
+    auto out_lod = SequenceSliceLoD(*in, offset_data, length_data);
+    auto out_dims = in->dims();
+    out_dims[0] = out_lod[0][out_lod[0].size() - 1];
+    out->Resize(out_dims);
+    out->set_lod(out_lod);
+
+    auto in_stride = framework::stride(in->dims());
+    auto out_stride = framework::stride(out->dims());
+
+    // Copy each sequence's slice into the output, back to back.
+    size_t out_offset = 0;
+    for (size_t i = 0; i < n; ++i) {
+      Tensor in_t = in->Slice(
+          static_cast<int>(lod[0][i] + offset_data[i]),
+          static_cast<int>(lod[0][i] + offset_data[i] + length_data[i]));
+
+      StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(), in_stride,
+                       in_t.dims(), out_stride, out->data<T>() + out_offset);
+      out_offset += length_data[i] * in_stride[0];
+    }
+  }
+};
+
+template <typename Place, typename T>
+class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in = ctx.Input<LoDTensor>("X");
+    auto* offset = ctx.Input<Tensor>("Offset");
+    auto* length = ctx.Input<Tensor>("Length");
+    auto* out_grad =
+        ctx.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* x_grad =
+        ctx.Output<LoDTensor>(framework::GradVarName("X"));
+
+    const int64_t* offset_data = offset->data<int64_t>();
+    const int64_t* length_data = length->data<int64_t>();
+    framework::Tensor offset_cpu;
+    framework::Tensor length_cpu;
+
+    // As in the forward kernel, fetch Offset and Length to the host first.
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      offset_cpu.mutable_data<int64_t>(offset->dims(), platform::CPUPlace());
+      offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
+      offset_data = offset_cpu.data<int64_t>();
+
+      length_cpu.mutable_data<int64_t>(length->dims(), platform::CPUPlace());
+      length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
+      length_data = length_cpu.data<int64_t>();
+    }
+
+    auto lod = in->lod();
+    auto out_lod = out_grad->lod();
+
+    if (x_grad) {
+      x_grad->mutable_data<T>(ctx.GetPlace());
+      x_grad->set_lod(in->lod());
+      math::SetConstant<Place, T> set_zero;
+      set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
+
+      // Scatter each slice of Out's gradient back to the rows of X it was
+      // copied from; all other rows keep the zero gradient set above.
+      for (size_t i = 0; i < out_lod[0].size() - 1; ++i) {
+        Tensor out_grad_t =
+            out_grad->Slice(static_cast<int>(out_lod[0][i]),
+                            static_cast<int>(out_lod[0][i + 1]));
+        auto out_grad_stride = framework::stride(out_grad_t.dims());
+
+        auto x_grad_stride = framework::stride(x_grad->dims());
+
+        Tensor x_grad_t = x_grad->Slice(
+            static_cast<int>(lod[0][i] + offset_data[i]),
+            static_cast<int>(lod[0][i] + offset_data[i] + length_data[i]));
+
+        StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>(),
+                         out_grad_stride, out_grad_t.dims(), x_grad_stride,
+                         x_grad_t.data<T>());
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
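The backward kernel above zero-fills X's gradient and then scatters each slice of Out's gradient back to the rows it was sliced from. A numpy sketch of that scatter, reusing the example data from the forward sketch after sequence_slice_op.cc (values are illustrative):

import numpy as np

lod, offset, length = [0, 3, 5], [0, 1], [2, 1]
out_lod = [0, 2, 3]                    # LoD produced by SequenceSliceLoD
out_grad = np.ones((3, 2), 'float32')  # d(Loss)/d(Out), one row per Out row

x_grad = np.zeros((5, 2), 'float32')   # zero-filled, like SetConstant
for i in range(len(lod) - 1):
    start = lod[i] + offset[i]
    # copy this sequence's slice of Out's gradient back to the rows of X
    # it came from; un-sliced rows keep a zero gradient
    x_grad[start:start + length[i]] = out_grad[out_lod[i]:out_lod[i + 1]]

print(x_grad)  # rows 0, 1 and 4 receive gradient; rows 2 and 3 stay zero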
diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
index 6c09ca3d346e36b6f38ac5c89914e95a7383f719..a21f67a2d99e7eab39708e2a571d30d7e9f20ce6 100755
--- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
+++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
@@ -11,7 +11,6 @@ test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_l
 test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
 test_seq_slice_layer test_cross_entropy_over_beam test_roi_pool_layer test_pooling3D_layer
 test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer
-test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer
 test_scale_sub_region_layer test_dot_prod_layer test_l2_distance_layer)
 
 export whole_configs=(test_split_datasource)
diff --git a/python/paddle/v2/fluid/tests/test_sequence_slice_op.py b/python/paddle/v2/fluid/tests/test_sequence_slice_op.py
new file mode 100755
index 0000000000000000000000000000000000000000..4351d8e6d77c16e0012f9ae163b118fdbb793a8f
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_sequence_slice_op.py
@@ -0,0 +1,45 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestSequenceSliceOp(OpTest):
+    def set_data(self):
+        self.init_test_case()
+        # only one-level LoD is supported
+        x = np.random.random(self.x_dim).astype('float32')
+        lod = self.x_lod
+        offset = np.array(self.offset).astype("int64")
+        length = np.array(self.length).astype("int64")
+
+        self.inputs = {'X': (x, lod), 'Offset': offset, 'Length': length}
+        outs = []  # collects the sliced sub-sequences; concatenated below
+        out_lod = [[0]]
+        out_lod_offset = 0
+        for i in range(len(offset)):
+            sub_x = x[lod[0][i] + offset[i, 0]:lod[0][i] + offset[i, 0] +
+                      length[i, 0], :]
+            out_lod_offset = out_lod_offset + len(sub_x)
+            outs.append(sub_x)
+            out_lod[0].append(out_lod_offset)
+        outs = np.concatenate(outs, axis=0)
+        self.outputs = {'Out': (outs, out_lod)}
+
+    def init_test_case(self):
+        self.x_dim = (100, 3, 2)
+        self.x_lod = [[0, 20, 40, 60, 80, 100]]
+        self.offset = [[1], [2], [3], [4], [5]]
+        self.length = [[10], [8], [6], [4], [2]]
+
+    def setUp(self):
+        self.op_type = "sequence_slice"
+        self.set_data()
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+if __name__ == '__main__':
+    unittest.main()
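The kernel checks allow a zero offset and a slice that ends exactly at a sequence boundary (both appear in the op's DOC example), but the default test case exercises neither. A hypothetical variant along these lines, not part of this patch, could cover those edges:

# Hypothetical boundary case: zero offset on the first sequence and a slice
# that ends exactly at that sequence's boundary (0 + 20 == lod[0][1]),
# exercising the PADDLE_ENFORCE_LE checks in sequence_slice_op.h.
class TestSequenceSliceOpBoundary(TestSequenceSliceOp):
    def init_test_case(self):
        self.x_dim = (100, 3, 2)
        self.x_lod = [[0, 20, 40, 60, 80, 100]]
        self.offset = [[0], [2], [3], [4], [5]]
        self.length = [[20], [8], [6], [4], [2]]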