diff --git a/paddle/operators/sequence_slice_op.cc b/paddle/operators/sequence_slice_op.cc
index f1e1c862a0ae04df745a55b42c77cc8c23e74860..a7e659b76338c893db546ad581669b32fc034d5c 100755
--- a/paddle/operators/sequence_slice_op.cc
+++ b/paddle/operators/sequence_slice_op.cc
@@ -12,37 +12,39 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/sub_sequence_op.h"
+#include "paddle/operators/sequence_slice_op.h"
 
 namespace paddle {
 namespace operators {
 
-class SubSequenceOp : public framework::OperatorWithKernel {
+class SequenceSliceOp : public framework::OperatorWithKernel {
  public:
  using framework::OperatorWithKernel::OperatorWithKernel;
 
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of SubSequenceOp should not be null.");
+                   "Input(X) of SequenceSliceOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Offset"),
+                   "Input(Offset) of SequenceSliceOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Length"),
+                   "Input(Length) of SequenceSliceOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of SubSequenceOp should not be null.");
+                   "Output(Out) of SequenceSliceOp should not be null.");
    auto input_dims = ctx->GetInputDim("X");
-    auto offsets = ctx->Attrs().Get<std::vector<int>>("offset");
-    auto sizes = ctx->Attrs().Get<std::vector<int>>("size");
-
-    auto dim_0 = 0;
-    for (size_t i = 0; i < sizes.size(); ++i) {
-      dim_0 += sizes[i];
+    ctx->SetOutputDim("Out", input_dims);
  }
 
-    framework::DDim out_dims = input_dims;
-    out_dims[0] = dim_0;
-    ctx->SetOutputDim("Out", out_dims);
+ protected:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.device_context());
  }
};
 
-class SubSequenceGradOp : public framework::OperatorWithKernel {
+class SequenceSliceGradOp : public framework::OperatorWithKernel {
  public:
  using framework::OperatorWithKernel::OperatorWithKernel;
 
@@ -53,34 +55,50 @@ class SubSequenceGradOp : public framework::OperatorWithKernel {
                   "The gradient of X should not be null.");
    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
  }
+
+ protected:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.device_context());
+  }
};
 
-class SubSequenceOpMaker : public framework::OpProtoAndCheckerMaker {
+class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SubSequenceOpMaker(framework::OpProto* proto,
-                     framework::OpAttrChecker* op_checker)
+  SequenceSliceOpMaker(framework::OpProto* proto,
+                       framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "(LoDTensor), "
-             "the variable-length input of SubSequenceOp");
-    AddAttr<std::vector<int>>(
-        "offset",
-        "A list to describes offset for sub sequence item.");
-    AddAttr<std::vector<int>>(
-        "size",
-        "A list to describes size for sub sequence item.");
+    AddInput("X",
+             "(LoDTensor), "
+             "the input of SequenceSliceOp.");
+    AddInput("Offset",
+             "(Tensor), "
+             "A vector to describe the offset for each sub-sequence item.");
+    AddInput("Length",
+             "(Tensor), "
+             "A vector to describe the length for each sub-sequence item.");
    AddOutput("Out",
-              "(Tensor), Variable-length output of "
-              "sequence_concat Op.");
+              "(LoDTensor), the output of SequenceSliceOp.");
    AddComment(R"DOC(
-Sub Sequence operator
-
-The operator crop a subsequence from given sequence with given start offset and subsequence size.
+Sequence slice operator
+
+The operator crops a subsequence from the given sequence, using the given start offset and subsequence length.
 It only supports sequence (LoD Tensor with level number is 1).
-    Case:
-      LoD(x) = {{0, 3, 6, 10}}; Dims(x0) = (10, 3, 2)
-      offset = (0, 1, 1); size = (2, 1, 2)
-      LoD(Out) = {{0, 2, 3, 5}}; Dims(Out) = (5,3,2)
-NOTE: The length of the input, offset and size should be the same. The offset start from 0.
+      X = [[a1, a2;
+            b1, b2;
+            c1, c2]
+           [d1, d2;
+            e1, e2]]
+      LoD(X) = {{0, 3, 5}}; Dims(X) = (5, 2)
+      Offset = (0, 1); Length = (2, 1)
+
+      Out = [[a1, a2;
+              b1, b2]
+             [e1, e2]]
+      LoD(Out) = {{0, 2, 3}}; Dims(Out) = (3, 2)
+NOTE: The first dimension of Offset and Length should equal the number of sequences in Input(X). The offset starts from 0.
)DOC");
  }
};
@@ -89,11 +107,11 @@ NOTE: The length of the input, offset and size should be the same. The offset st
}  // namespace operators
}  // namespace paddle
 
namespace ops = paddle::operators;
-REGISTER_OP(sub_sequence, ops::SubSequenceOp, ops::SubSequenceOpMaker,
-            sub_sequence_grad, ops::SubSequenceGradOp);
+REGISTER_OP(sequence_slice, ops::SequenceSliceOp, ops::SequenceSliceOpMaker,
+            sequence_slice_grad, ops::SequenceSliceGradOp);
 REGISTER_OP_CPU_KERNEL(
-    sub_sequence,
-    ops::SubSequenceOpKernel<paddle::platform::CPUPlace, float>);
+    sequence_slice,
+    ops::SequenceSliceOpKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
-    sub_sequence_grad,
-    ops::SubSequenceGradOpKernel<paddle::platform::CPUPlace, float>);
+    sequence_slice_grad,
+    ops::SequenceSliceGradOpKernel<paddle::platform::CPUPlace, float>);
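
For reference, the semantics the new DOC block describes can be sketched in a few lines of NumPy. This is an illustrative reconstruction, not part of the patch; the helper name sequence_slice_ref and the plain-list LoD convention are assumptions made for the example. Note that the kernel itself keeps Out at X's full first-dimension size and zero-fills the tail (see SetOutputDim and SetConstant below); the sketch returns just the concatenated slices the DOC shows.

import numpy as np

def sequence_slice_ref(x, lod0, offset, length):
    # For each sequence i (rows lod0[i]:lod0[i+1] of x), keep
    # length[i] rows starting offset[i] rows into the sequence.
    pieces, out_lod0 = [], [0]
    for i in range(len(lod0) - 1):
        start = lod0[i] + offset[i]
        pieces.append(x[start:start + length[i]])
        out_lod0.append(out_lod0[-1] + int(length[i]))
    return np.concatenate(pieces, axis=0), out_lod0

# The DOC case: two sequences of 3 and 2 rows, Offset = (0, 1), Length = (2, 1).
x = np.arange(10, dtype='float32').reshape(5, 2)
out, out_lod0 = sequence_slice_ref(x, [0, 3, 5], [0, 1], [2, 1])
# out contains rows 0, 1 and 4 of x; out_lod0 == [0, 2, 3], matching LoD(Out).
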
diff --git a/paddle/operators/sequence_slice_op.cu b/paddle/operators/sequence_slice_op.cu
index d4127347cb622549ce11a7e61e6f518bb31a6611..a9f59dadba74d900fa5cc0601fb5b264ea19e34d 100755
--- a/paddle/operators/sequence_slice_op.cu
+++ b/paddle/operators/sequence_slice_op.cu
@@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#define EIGEN_USE_GPU
-
-#include "paddle/operators/sub_sequence_op.h"
+#include "paddle/operators/sequence_slice_op.h"
 
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(
-    sub_sequence,
-    ops::SubSequenceOpKernel<paddle::platform::GPUPlace, float>);
+    sequence_slice,
+    ops::SequenceSliceOpKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
-    sub_sequence_grad,
-    ops::SubSequenceGradOpKernel<paddle::platform::GPUPlace, float>);
+    sequence_slice_grad,
+    ops::SequenceSliceGradOpKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/sequence_slice_op.h b/paddle/operators/sequence_slice_op.h
index cd291a382b7ad1ea7da377de5572acb9b1e90652..7599a0abf402a90a8b843ce0143ff597b8b80333 100755
--- a/paddle/operators/sequence_slice_op.h
+++ b/paddle/operators/sequence_slice_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
 #include "paddle/operators/strided_memcpy.h"
 
 namespace paddle {
@@ -25,109 +25,124 @@ using LoDTensor = framework::LoDTensor;
 using LoD = framework::LoD;
 
 template <typename T>
-LoD subsequenceLoD(const T* in, const std::vector<int> offsets,
-                   const std::vector<int> sizes) {
-  auto out_lod = in->lod();
+LoD SequenceSliceLoD(const T& in, const int64_t* offset_data,
+                     const int64_t* length_data) {
+  auto out_lod = in.lod();
  size_t lod_offset = 0;
 
-  auto n = in->lod()[0].size() - 1;
+  auto n = in.lod()[0].size() - 1;
  out_lod[0][0] = 0;
  for (size_t i = 0; i < n; ++i) {
-    lod_offset += sizes[i];
+    lod_offset += length_data[i];
    out_lod[0][i+1] = lod_offset;
  }
  return out_lod;
}
 
 template <typename Place, typename T>
-class SubSequenceOpKernel : public framework::OpKernel<T> {
+class SequenceSliceOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<LoDTensor>("X");
-    std::vector<int> offsets = ctx.Attr<std::vector<int>>("offset");
-    std::vector<int> sizes = ctx.Attr<std::vector<int>>("size");
+    auto* offset = ctx.Input<Tensor>("Offset");
+    auto* length = ctx.Input<Tensor>("Length");
    auto* out = ctx.Output<LoDTensor>("Out");
 
-    auto offset_len = offsets.size();
-    auto size_len = sizes.size();
+    const int64_t* offset_data = offset->data<int64_t>();
+    const int64_t* length_data = length->data<int64_t>();
+
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      framework::Tensor offset_cpu;
+      offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
+      offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
+      offset_data = offset_cpu.data<int64_t>();
+
+      framework::Tensor length_cpu;
+      length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
+      length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
+      length_data = length_cpu.data<int64_t>();
+    }
 
    auto lod = in->lod();
    auto n = lod[0].size() - 1;
 
    PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
-    PADDLE_ENFORCE_EQ(n, offset_len,
-                      "The length of input and offset should be the same")
-    PADDLE_ENFORCE_EQ(n, size_len,
-                      "The length of input and size should be the same")
+    PADDLE_ENFORCE_EQ(offset->dims().size(), 1UL,
+                      "The dimension of the offset tensor should be 1.");
+    PADDLE_ENFORCE_EQ(length->dims().size(), 1UL,
+                      "The dimension of the length tensor should be 1.");
+    PADDLE_ENFORCE_EQ(
+        n, length->dims()[0],
+        "The size of input-sequence and length-array should be the same")
+    PADDLE_ENFORCE_EQ(
+        n, offset->dims()[0],
+        "The size of input-sequence and offset-array should be the same")
 
    for (size_t i = 0; i < n; ++i) {
-      auto offset = offsets[i];
-      auto size = sizes[i];
-      PADDLE_ENFORCE_LT(lod[0][i] + offset + size, lod[0][i + 1],
-                        "The target tensor's length overflow")
+      PADDLE_ENFORCE_LT(0, offset_data[i],
+                        "The offset must be greater than zero")
+      PADDLE_ENFORCE_LT(0, length_data[i],
+                        "The length must be greater than zero")
+      PADDLE_ENFORCE_LT(lod[0][i] + offset_data[i] + length_data[i],
+                        lod[0][i + 1], "The target tensor's length overflows")
    }
 
    out->mutable_data<T>(ctx.GetPlace());
-    auto out_lod = subsequenceLoD(in, offsets, sizes);
+    auto out_lod = SequenceSliceLoD(*in, offset_data, length_data);
    out->set_lod(out_lod);
+    math::SetConstant<Place, T> set_zero;
+    set_zero(ctx.device_context(), out, static_cast<T>(0));
 
    auto in_stride = framework::stride(in->dims());
    auto out_stride = framework::stride(out->dims());
 
    size_t out_offset = 0;
    for (size_t i = 0; i < n; ++i) {
-      auto offset = offsets[i];
-      auto size = sizes[i];
-
-      Tensor in_t = in->Slice(static_cast<int>(lod[0][i] + offset),
-                              static_cast<int>(lod[0][i] + offset + size));
+      Tensor in_t =
+          in->Slice(static_cast<int>(lod[0][i] + offset_data[i]),
+                    static_cast<int>(lod[0][i] + offset_data[i] +
+                                     length_data[i]));
 
      StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(), in_stride,
                       in_t.dims(), out_stride, out->data<T>() + out_offset);
-      out_offset += size * in_stride[0];
+      out_offset += length_data[i] * in_stride[0];
    }
  }
};
 
 template <typename Place, typename T>
-class SubSequenceGradOpKernel : public framework::OpKernel<T> {
+class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<LoDTensor>("X");
-    std::vector<int> offsets = ctx.Attr<std::vector<int>>("offset");
-    std::vector<int> sizes = ctx.Attr<std::vector<int>>("size");
+    auto* offset = ctx.Input<Tensor>("Offset");
+    auto* length = ctx.Input<Tensor>("Length");
    auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
    auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
 
-    auto offset_len = offsets.size();
-    auto size_len = sizes.size();
+    const int64_t* offset_data = offset->data<int64_t>();
+    const int64_t* length_data = length->data<int64_t>();
 
-    auto lod = in->lod();
-    auto n = lod[0].size() - 1;
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      framework::Tensor offset_cpu;
+      offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
+      offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
+      offset_data = offset_cpu.data<int64_t>();
 
-    // check input data format
-    PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
-    PADDLE_ENFORCE_EQ(n, offset_len,
-                      "The length of input and offset should be the same")
-    PADDLE_ENFORCE_EQ(n, size_len,
-                      "The length of input and size should be the same")
-
-    for (size_t i = 0; i < n; ++i) {
-      auto offset = offsets[i];
-      auto size = sizes[i];
-      PADDLE_ENFORCE_LT(lod[0][i] + offset + size, lod[0][i + 1],
-                        "The target tensor's length overflow")
+      framework::Tensor length_cpu;
+      length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
+      length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
+      length_data = length_cpu.data<int64_t>();
    }
 
-    auto out_lod = subsequenceLoD(in, offsets, sizes);
+    auto lod = in->lod();
+    auto out_lod = SequenceSliceLoD(*in, offset_data, length_data);
 
    x_grad->set_lod(lod);
    x_grad->mutable_data<T>(ctx.GetPlace());
-    auto temp = framework::EigenVector<T>::Flatten(*x_grad);
-    temp.device(ctx.GetEigenDevice<Place>()) = temp.constant(static_cast<T>(0));
+    math::SetConstant<Place, T> set_zero;
+    set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
 
    auto out_grad_stride = framework::stride(out_grad->dims());
 
@@ -139,11 +154,9 @@ class SubSequenceGradOpKernel : public framework::OpKernel<T> {
 
      auto x_grad_stride = framework::stride(x_grad->dims());
 
-      auto offset = offsets[i];
-      auto size = sizes[i];
-
-      Tensor x_grad_t = x_grad->Slice(static_cast<int>(lod[0][i] + offset),
-                                      static_cast<int>(lod[0][i] + offset + size));
+      Tensor x_grad_t = x_grad->Slice(
+          static_cast<int>(lod[0][i] + offset_data[i]),
+          static_cast<int>(lod[0][i] + offset_data[i] + length_data[i]));
 
      StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>(),
                       out_grad_stride, out_grad_t.dims(), x_grad_stride,
                       x_grad_t.data<T>());
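
The backward kernel above zero-fills X@GRAD via math::SetConstant and then strided-copies each slice of Out@GRAD back onto the rows it was taken from. A NumPy sketch of that scatter, under the same assumptions as the earlier sketch (the helper name is hypothetical, the LoD is a plain list of row offsets):

import numpy as np

def sequence_slice_grad_ref(x_shape, lod0, offset, length, out_grad):
    # X@GRAD is zero everywhere except the sliced rows, which receive
    # the matching rows of Out@GRAD (the SetConstant + StridedMemcpy loop).
    x_grad = np.zeros(x_shape, dtype=out_grad.dtype)
    out_pos = 0
    for i in range(len(lod0) - 1):
        start = lod0[i] + offset[i]
        x_grad[start:start + length[i]] = out_grad[out_pos:out_pos + length[i]]
        out_pos += length[i]
    return x_grad
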
diff --git a/python/paddle/v2/framework/tests/test_sequence_slice_op.py b/python/paddle/v2/framework/tests/test_sequence_slice_op.py
index 73d81947bba95e1bae52cceea5ba73de40aec045..47b616b743427797a2dc47f9e7839ab220121224 100755
--- a/python/paddle/v2/framework/tests/test_sequence_slice_op.py
+++ b/python/paddle/v2/framework/tests/test_sequence_slice_op.py
@@ -3,31 +3,29 @@ import numpy as np
 import sys
 from op_test import OpTest
 
-class TestSubSequenceOp(OpTest):
+class TestSequenceSliceOp(OpTest):
     def set_data(self):
         # only supprot one level LoD
         x = np.random.random((100, 3, 2)).astype('float32')
         lod = [[0, 20, 40, 60, 80, 100]]
-        offsets = np.array([1, 2, 3, 4, 5]).flatten()
-        sizes = np.array([10, 8, 6, 4, 2]).flatten()
+        offset = np.array([1, 2, 3, 4, 5]).flatten().astype("int64")
+        length = np.array([10, 8, 6, 4, 2]).flatten().astype("int64")
 
-        self.inputs = {'X': (x, lod)}
-        self.attrs = {'offset': offsets, 'size': sizes}
-        outs = []
+        self.inputs = {'X': (x, lod), 'Offset': offset, 'Length': length}
+        outs = np.zeros((100, 3, 2)).astype('float32')
         out_lod = [[0]]
         out_lod_offset = 0
-        for i in range(len(offsets)):
-            sub_x = x[lod[0][i] + offsets[i]: lod[0]
-                      [i] + offsets[i] + sizes[i], :]
-            outs.append(sub_x)
+        for i in range(len(offset)):
+            sub_x = x[lod[0][i] + offset[i]: lod[0]
+                      [i] + offset[i] + length[i], :]
             out_lod_offset = out_lod_offset + len(sub_x)
+            outs[out_lod[0][i]: out_lod_offset, :] = sub_x
             out_lod[0].append(out_lod_offset)
-        outs = np.concatenate(outs, axis=0)
-        self.outputs = {'Out': outs}
+        self.outputs = {'Out': (outs, out_lod)}
 
     def setUp(self):
-        self.op_type = "sub_sequence"
+        self.op_type = "sequence_slice"
         self.set_data()
 
     def test_check_output(self):
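
A quick sanity check on the expected output the test builds (illustrative only, not part of the patch): the test fills a zeroed (100, 3, 2) buffer instead of concatenating the slices because InferShape keeps Dims(Out) equal to Dims(X) and the kernel zero-fills Out before copying, so only the first sum(length) = 30 rows are populated.

import numpy as np

length = [10, 8, 6, 4, 2]
out_lod0 = [0] + np.cumsum(length).tolist()
assert out_lod0 == [0, 10, 18, 24, 28, 30]  # cumulative slice lengths
# Rows 30..99 of the expected 'Out' stay zero, matching the kernel's
# SetConstant initialisation.
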