diff --git a/paddle/operators/sequence_slice_op.cc b/paddle/operators/sequence_slice_op.cc index 3374f04269bbb153ee4681ef551f371c0497d5a9..2fae37e0bf643d59b458cd2bae67a53537e62ad5 100755 --- a/paddle/operators/sequence_slice_op.cc +++ b/paddle/operators/sequence_slice_op.cc @@ -90,11 +90,11 @@ class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker { "a vector to describe the length of every input sequence for " "sub sequence item."); AddOutput("Out", - "(LoDTensor), The output of SequenceSliceOp."); + "(LoDTensor), the output of SequenceSliceOp."); AddComment(R"DOC( Sequence slice operator -The operator crop a subsequence from given sequence with given start offset and subsequence length. +The operator crops a subsequence from given sequence with given start offset and subsequence length. It only supports sequence (LoD Tensor with level number is 1). - Case: X = [[a1, a2; @@ -109,7 +109,7 @@ It only supports sequence (LoD Tensor with level number is 1). b1, b2] [e1, e2]] LoD(Out) = {{0, 2, 3}}; Dims(Out) = (3, 2) -NOTE: The length of the input, offset and length should be the same. The offset start from 0. +NOTE: The first dimension size of input, the size of offset and Length, should be equal. The offset start from 0. )DOC"); } }; diff --git a/paddle/operators/sequence_slice_op.h b/paddle/operators/sequence_slice_op.h index 7210f489e2638cc05b6ed076f3033d4b984ff2d7..cbb950b152a69d6d74dd32fd5afc02ee38603f1c 100755 --- a/paddle/operators/sequence_slice_op.h +++ b/paddle/operators/sequence_slice_op.h @@ -83,7 +83,8 @@ class SequenceSliceOpKernel : public framework::OpKernel { PADDLE_ENFORCE_LT( lod[0][i] + offset_data[i] + length_data[i], lod[0][i + 1], - "The target tensor's length overflow")} + "The target tensor's length overflow") + } out->mutable_data(ctx.GetPlace()); auto out_lod = SequenceSliceLoD(*in, offset_data, length_data); @@ -140,27 +141,29 @@ class SequenceSliceGradOpKernel : public framework::OpKernel { auto lod = in->lod(); auto out_lod = out_grad->lod(); - x_grad->mutable_data(ctx.GetPlace()); - math::SetConstant set_zero; - set_zero(ctx.device_context(), x_grad, static_cast(0)); + if (x_grad) { + x_grad->mutable_data(ctx.GetPlace()); + math::SetConstant set_zero; + set_zero(ctx.device_context(), x_grad, static_cast(0)); - auto out_grad_stride = framework::stride(out_grad->dims()); + auto out_grad_stride = framework::stride(out_grad->dims()); - for (size_t i = 0; i < out_lod[0].size() - 1; ++i) { - Tensor out_grad_t = - out_grad->Slice(static_cast(out_lod[0][i]), - static_cast(out_lod[0][i + 1])); - auto out_grad_stride = framework::stride(out_grad_t.dims()); + for (size_t i = 0; i < out_lod[0].size() - 1; ++i) { + Tensor out_grad_t = + out_grad->Slice(static_cast(out_lod[0][i]), + static_cast(out_lod[0][i + 1])); + auto out_grad_stride = framework::stride(out_grad_t.dims()); - auto x_grad_stride = framework::stride(x_grad->dims()); + auto x_grad_stride = framework::stride(x_grad->dims()); - Tensor x_grad_t = x_grad->Slice( - static_cast(lod[0][i] + offset_data[i]), - static_cast(lod[0][i] + offset_data[i] + length_data[i])); + Tensor x_grad_t = x_grad->Slice( + static_cast(lod[0][i] + offset_data[i]), + static_cast(lod[0][i] + offset_data[i] + length_data[i])); - StridedMemcpy(ctx.device_context(), out_grad_t.data(), - out_grad_stride, out_grad_t.dims(), x_grad_stride, - x_grad_t.data()); + StridedMemcpy(ctx.device_context(), out_grad_t.data(), + out_grad_stride, out_grad_t.dims(), x_grad_stride, + x_grad_t.data()); + } } } };