// sequence_slice_op.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/strided_memcpy.h"

namespace paddle {
namespace operators {

// Short local names for the framework tensor/LoD types used by the kernels.
using LoD = framework::LoD;
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

template <typename T>
28 29
inline LoD SequenceSliceLoD(const T& in, const int64_t* offset_data,
                           const int64_t* length_data) {
30
  auto out_lod = in.lod();
31 32
  size_t lod_offset = 0;

33
  auto n = in.lod()[0].size() - 1;
34 35
  out_lod[0][0] = 0;
  for (size_t i = 0; i < n; ++i) {
36
    lod_offset += length_data[i];
37 38 39 40 41 42
    out_lod[0][i+1] = lod_offset;
  }
  return out_lod;
}

template <typename Place, typename T>
43
class SequenceSliceOpKernel : public framework::OpKernel<T> {
44 45 46
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<LoDTensor>("X");
47 48
    auto* offset = ctx.Input<Tensor>("Offset");
    auto* length = ctx.Input<Tensor>("Length");
49 50
    auto* out = ctx.Output<LoDTensor>("Out");

51 52 53 54 55 56 57 58 59 60 61 62
    auto lod = in->lod();
    auto n = lod[0].size() - 1;

    PADDLE_ENFORCE_EQ(lod.size(), 1UL,
                      "Only support one level sequence now.");
    PADDLE_ENFORCE_EQ(
        n, length->dims()[0],
        "The size of input-sequence and length-array should be the same")
    PADDLE_ENFORCE_EQ(
        n, offset->dims()[0],
        "The size of input-sequence and offset-array should be the same")

63 64
    const int64_t* offset_data = offset->data<int64_t>();
    const int64_t* length_data = length->data<int64_t>();
65 66
    framework::Tensor offset_cpu;
    framework::Tensor length_cpu;
67 68 69 70 71 72 73 74 75 76

    if (platform::is_gpu_place(ctx.GetPlace())) {
      offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
      offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
      offset_data = offset_cpu.data<int64_t>();

      length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
      length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
      length_data = length_cpu.data<int64_t>();
    }
77 78

    for (size_t i = 0; i < n; ++i) {
79 80 81 82 83 84 85
      PADDLE_ENFORCE_LT(0, offset_data[i],
                "The offset must greater than zero")
      PADDLE_ENFORCE_LT(0, length_data[i],
                "The length must greater than zero")
      PADDLE_ENFORCE_LT(
          lod[0][i] + offset_data[i] + length_data[i],
          lod[0][i + 1],
W
wanghaox 已提交
86 87
          "The target tensor's length overflow")
    }
88 89

    out->mutable_data<T>(ctx.GetPlace());
90
    auto out_lod = SequenceSliceLoD(*in, offset_data, length_data);
91 92 93
    auto out_dims = in->dims();
    out_dims[0] = out_lod[0][out_lod[0].size() - 1];
    out->Resize(out_dims);
94 95 96 97 98 99 100
    out->set_lod(out_lod);

    auto in_stride = framework::stride(in->dims());
    auto out_stride = framework::stride(out->dims());

    size_t out_offset = 0;
    for (size_t i = 0; i < n; ++i) {
101 102 103
      Tensor in_t =
          in->Slice(static_cast<int>(lod[0][i] + offset_data[i]),
                    static_cast<int>(lod[0][i] + offset_data[i] +
104
                                     length_data[i]));
105 106 107 108

      StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(),
                       in_stride, in_t.dims(), out_stride,
                       out->data<T>() + out_offset);
109
      out_offset += length_data[i] * in_stride[0];
110 111 112 113 114
    }
  }
};

template <typename Place, typename T>
115
class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
116 117 118
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<LoDTensor>("X");
119 120
    auto* offset = ctx.Input<Tensor>("Offset");
    auto* length = ctx.Input<Tensor>("Length");
121 122 123 124 125
    auto* out_grad =
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto* x_grad =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));

126 127
    const int64_t* offset_data = offset->data<int64_t>();
    const int64_t* length_data = length->data<int64_t>();
W
wanghaox 已提交
128 129
    framework::Tensor offset_cpu;
    framework::Tensor length_cpu;
130

131 132 133 134
    if (platform::is_gpu_place(ctx.GetPlace())) {
      offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
      offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
      offset_data = offset_cpu.data<int64_t>();
135

136 137 138
      length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
      length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
      length_data = length_cpu.data<int64_t>();
139 140
    }

141
    auto lod = in->lod();
142
    auto out_lod = out_grad->lod();
143

W
wanghaox 已提交
144 145 146 147
    if (x_grad) {
      x_grad->mutable_data<T>(ctx.GetPlace());
      math::SetConstant<Place, T> set_zero;
      set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
148

W
wanghaox 已提交
149
      auto out_grad_stride = framework::stride(out_grad->dims());
150

W
wanghaox 已提交
151 152 153 154 155
      for (size_t i = 0; i < out_lod[0].size() - 1; ++i) {
        Tensor out_grad_t =
            out_grad->Slice(static_cast<int>(out_lod[0][i]),
                            static_cast<int>(out_lod[0][i + 1]));
        auto out_grad_stride = framework::stride(out_grad_t.dims());
156

W
wanghaox 已提交
157
        auto x_grad_stride = framework::stride(x_grad->dims());
158

W
wanghaox 已提交
159 160 161
        Tensor x_grad_t = x_grad->Slice(
            static_cast<int>(lod[0][i] + offset_data[i]),
            static_cast<int>(lod[0][i] + offset_data[i] + length_data[i]));
162

W
wanghaox 已提交
163 164 165 166
        StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>(),
                        out_grad_stride, out_grad_t.dims(), x_grad_stride,
                        x_grad_t.data<T>());
      }
167 168 169 170 171 172
    }
  }
};

}  // namespace operators
}  // namespace paddle