lstm_op.h 14.5 KB
Newer Older
D
dangqingqing 已提交
1 2
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
D
dangqingqing 已提交
14 15 16

#pragma once
#include "paddle/framework/op_registry.h"
17 18 19
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"
D
dangqingqing 已提交
20 21 22 23

namespace paddle {
namespace operators {

D
dangqingqing 已提交
24 25 26
// Short aliases for the framework tensor types used by the LSTM kernels below.
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

D
dangqingqing 已提交
27 28 29 30 31 32 33 34 35
// Permutes the rows of an initial-state tensor (H0/C0 or their gradients).
//
// `dst` is (re)allocated with the same dims as `src` on the context's place,
// then filled by CopyMatrixRowsFunctor using the row mapping `index`.
// `indexed_src` is forwarded unchanged to that functor: the kernels below pass
// `true` when moving an initial state into batch order and `false` when moving
// a reordered gradient back (exact gather/scatter semantics are defined by
// CopyMatrixRowsFunctor — confirm there).
template <typename Place, typename T>
inline void ReorderInitState(const platform::DeviceContext& ctx,
                             const framework::Tensor& src, const size_t* index,
                             framework::Tensor* dst, bool indexed_src) {
  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
  math::CopyMatrixRowsFunctor<Place, T> copy_rows;
  copy_rows(ctx, src, index, *dst, indexed_src);
}

D
dangqingqing 已提交
36 37 38
// Forward kernel of the LSTM operator.
//
// Inputs:  "Input" (LoDTensor whose width is 4 * frame_size: the projected
//          inputs of the four gates), "Weight" (hidden-to-gate weights,
//          presumably [frame_size, 4 * frame_size] — TODO confirm), optional
//          "Bias", and optional initial states "H0" / "C0".
// Outputs: "Hidden" and "Cell" in sequence (LoD) order, plus "BatchGate" and
//          "BatchCellPreAct", the batch-ordered intermediates that the
//          backward kernel reads.
//
// The sequences are regrouped into time-step batches. Each batch adds the
// recurrent term from the previous step's hidden state and runs through
// LstmUnitFunctor. The results are then scattered back to sequence order.
template <typename Place, typename T>
class LSTMKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<LoDTensor>("Input");
    auto* weight = ctx.Input<Tensor>("Weight");
    auto* bias = ctx.Input<Tensor>("Bias");

    auto* hidden_t0 = ctx.Input<Tensor>("H0");
    auto* cell_t0 = ctx.Input<Tensor>("C0");

    auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
    batch_gate->mutable_data<T>(ctx.GetPlace());
    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
    hidden_out->mutable_data<T>(ctx.GetPlace());
    auto* cell_out = ctx.Output<LoDTensor>("Cell");
    cell_out->mutable_data<T>(ctx.GetPlace());

    // Copy the input into batch layout. This also computes BatchGate's LoD,
    // which every later step relies on. `is_reverse` is forwarded to the
    // functor.
    bool is_reverse = ctx.Attr<bool>("is_reverse");
    math::LoDTensor2BatchFunctor<Place, T> to_batch;
    auto& device_ctx = ctx.device_context();
    to_batch(device_ctx, *input, *batch_gate, true, is_reverse);

    auto in_dims = input->dims();
    int frame_size = static_cast<int>(in_dims[1] / 4);
    framework::DDim dims({in_dims[0], frame_size});

    if (bias) {
      // The first 4 * frame_size entries of Bias are the gate biases; add
      // them to every row of the batched gates.
      Tensor b = *bias;
      b.Resize({bias->numel(), 1});
      Tensor gate_bias = b.Slice(0, 4 * frame_size);
      math::RowwiseAdd<Place, T> add_bias;
      add_bias(device_ctx, *batch_gate, gate_bias, batch_gate);
    }

    math::LstmMetaValue<T> lstm_value;
    if (bias && ctx.Attr<bool>("use_peepholes")) {
      // With peepholes, Bias stores three more frame_size-long vectors after
      // the gate biases: checkIg, checkFg and checkOg, in that order.
      T* bias_data = const_cast<T*>(bias->data<T>());
      // the code style in LstmMetaValue will be updated later.
      lstm_value.checkIg = bias_data + 4 * frame_size;
      lstm_value.checkFg = lstm_value.checkIg + frame_size;
      lstm_value.checkOg = lstm_value.checkFg + frame_size;
    } else {
      lstm_value.checkIg = nullptr;
      lstm_value.checkFg = nullptr;
      lstm_value.checkOg = nullptr;
    }
    lstm_value.prevStateValue = nullptr;
    Tensor ordered_c0;
    const size_t* order = batch_gate->lod()[2].data();
    if (cell_t0) {
      // Batch computing for LSTM reorders the input sequences by length, so
      // the initial cell state must be reordered the same way.
      ReorderInitState<Place, T>(device_ctx, *cell_t0, order, &ordered_c0,
                                 true);
      lstm_value.prevStateValue = ordered_c0.data<T>();
    }

    // Batch-ordered hidden/cell states are kept in local tensors. They are
    // converted back to sequence order into the "Hidden"/"Cell" outputs at
    // the end.
    LoDTensor batch_hidden, batch_cell;
    auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
    batch_hidden.mutable_data<T>(dims, ctx.GetPlace());
    batch_cell.mutable_data<T>(dims, ctx.GetPlace());
    batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());

    // lod()[0] holds the starting row of every time-step batch. Bind by
    // reference: the offsets are only read, so copying the vector is waste.
    const auto& batch_starts = batch_gate->lod()[0];
    size_t num_batch = batch_starts.size() - 1;
    auto gate_act = ctx.Attr<std::string>("gate_activation");
    auto cell_act = ctx.Attr<std::string>("cell_activation");
    auto cand_act = ctx.Attr<std::string>("candidate_activation");

    for (size_t n = 0; n < num_batch; n++) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);

      Tensor gate_t = batch_gate->Slice(bstart, bend);
      Tensor out_t = batch_hidden.Slice(bstart, bend);
      Tensor cell_t = batch_cell.Slice(bstart, bend);
      Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend);

      int cur_batch_size = bend - bstart;

      if (n > 0) {
        // gate_t += h_prev * Weight. Only the first cur_batch_size rows of
        // the previous batch are used; this assumes batch sizes do not grow
        // over time.
        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
        auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
        math::matmul<Place, T>(device_ctx, pre_hidden_t, false, *weight, false,
                               static_cast<T>(1.0), &gate_t,
                               static_cast<T>(1.0));
      } else if (hidden_t0) {
        // If n == 0 and there is no initialized hidden state, H0 is zeros and
        // the calculation W_h * H0 is skipped.
        // If n == 0 and there is an initialized hidden state, calculate
        // W_h * H0.

        // Batch computing for LSTM reorders the input sequences by length, so
        // the initial hidden state must be reordered the same way.
        Tensor ordered_h0;
        ReorderInitState<Place, T>(device_ctx, *hidden_t0, order, &ordered_h0,
                                   true);
        math::matmul<Place, T>(device_ctx, ordered_h0, false, *weight, false,
                               static_cast<T>(1.0), &gate_t,
                               static_cast<T>(1.0));
      }

      lstm_value.gateValue = gate_t.data<T>();
      lstm_value.outputValue = out_t.data<T>();
      lstm_value.stateValue = cell_t.data<T>();
      lstm_value.stateActiveValue = cell_pre_act_t.data<T>();
      math::LstmUnitFunctor<Place, T>::compute(device_ctx, lstm_value,
                                               frame_size, cur_batch_size,
                                               gate_act, cell_act, cand_act);
      // This step's cell state becomes the previous state of the next step.
      lstm_value.prevStateValue = lstm_value.stateValue;
    }

    math::Batch2LoDTensorFunctor<Place, T> to_seq;
    batch_hidden.set_lod(batch_gate->lod());
    // restore the output hidden in LoDTensor from the batch hidden
    to_seq(device_ctx, batch_hidden, *hidden_out);

    batch_cell.set_lod(batch_gate->lod());
    // restore the output cell state in LoDTensor from the batch cell
    to_seq(device_ctx, batch_cell, *cell_out);
  }
};

// Backward kernel of the LSTM operator.
//
// Reads the forward outputs ("Hidden", "Cell"), the forward intermediates
// ("BatchGate", "BatchCellPreAct") and the gradient of "Hidden". Produces the
// gradients of "Input", "Weight", "Bias", "H0" and "C0", each only when that
// output is requested. The gradient flowing into "Cell" is currently taken as
// zero (see the TODO below).
//
// Time-step batches are processed from last to first. The gradients of H0/C0
// are computed in batch order and reordered back to sequence order at the end.
template <typename Place, typename T>
class LSTMGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<LoDTensor>("Input");
    auto* weight = ctx.Input<Tensor>("Weight");
    auto* bias = ctx.Input<Tensor>("Bias");

    auto* hidden_out = ctx.Input<LoDTensor>("Hidden");
    auto* cell_out = ctx.Input<LoDTensor>("Cell");

    auto* batch_gate = ctx.Input<LoDTensor>("BatchGate");
    auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");

    auto* hidden_g = ctx.Input<LoDTensor>(framework::GradVarName("Hidden"));

    auto* in_g = ctx.Output<LoDTensor>(framework::GradVarName("Input"));
    auto* weight_g = ctx.Output<Tensor>(framework::GradVarName("Weight"));
    auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto* h0 = ctx.Input<Tensor>("H0");
    auto* c0 = ctx.Input<Tensor>("C0");

    auto* h0_g = ctx.Output<Tensor>(framework::GradVarName("H0"));
    auto* c0_g = ctx.Output<Tensor>(framework::GradVarName("C0"));

    auto& device_ctx = ctx.device_context();
    math::SetConstant<Place, T> zero;
    if (weight_g) {
      // The weight gradient is accumulated (beta = 1) in the loop below, so it
      // must start at zero.
      weight_g->mutable_data<T>(ctx.GetPlace());
      zero(device_ctx, weight_g, static_cast<T>(0.0));
    }

    // ordered_h0/c0 is the reordered hidden/cell initialization.
    // ordered_h0_g/c0_g is the reordered gradient of hidden/cell
    // initialization.
    Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
    const size_t* order = batch_gate->lod()[2].data();
    if (c0) {
      ReorderInitState<Place, T>(device_ctx, *c0, order, &ordered_c0, true);
    }
    // ordered_c0_g is only allocated when both C0 and its gradient exist.
    // Every later use of it is guarded by this same flag, so an unallocated
    // tensor is never read.
    const bool has_c0_g = c0 && c0_g;
    if (has_c0_g) {
      ordered_c0_g.mutable_data<T>(c0_g->dims(), ctx.GetPlace());
    }

    auto in_dims = input->dims();
    auto out_dims = hidden_g->dims();
    int frame_size = static_cast<int>(in_dims[1] / 4);
    PADDLE_ENFORCE_EQ(frame_size, out_dims[1]);

    math::LstmMetaValue<T> lstm_value;
    if (bias && ctx.Attr<bool>("use_peepholes")) {
      // Peephole vectors follow the 4 * frame_size gate biases in Bias.
      T* bias_data = const_cast<T*>(bias->data<T>());
      lstm_value.checkIg = bias_data + 4 * frame_size;
      lstm_value.checkFg = lstm_value.checkIg + frame_size;
      lstm_value.checkOg = lstm_value.checkFg + frame_size;
    } else {
      lstm_value.checkIg = nullptr;
      lstm_value.checkFg = nullptr;
      lstm_value.checkOg = nullptr;
    }

    math::LstmMetaGrad<T> lstm_grad;

    if (bias && bias_g) {
      bias_g->mutable_data<T>(ctx.GetPlace());
      zero(device_ctx, bias_g, static_cast<T>(0.0));
    }
    if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
      // Peephole gradients go into the tail of the bias gradient, mirroring
      // the layout of Bias.
      T* bias_g_data = bias_g->data<T>();
      lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
      lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
      lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
    } else {
      lstm_grad.checkIgGrad = nullptr;
      lstm_grad.checkFgGrad = nullptr;
      lstm_grad.checkOgGrad = nullptr;
    }

    math::LoDTensor2BatchFunctor<Place, T> to_batch;

    // Copies a sequence-ordered tensor into batch layout, reusing the batch
    // LoD that the forward pass stored on BatchGate. The `false` argument
    // presumably tells the functor not to recompute that LoD.
    auto ToBatch = [&batch_gate, &to_batch](
        const platform::DeviceContext& ctx, const framework::LoDTensor& src,
        const framework::DDim& dims, framework::LoDTensor& dst) {
      dst.mutable_data<T>(dims, ctx.GetPlace());
      dst.set_lod(batch_gate->lod());
      to_batch(ctx, src, dst, false);
    };

    LoDTensor batch_hidden, batch_hidden_g, batch_cell;
    ToBatch(device_ctx, *hidden_out, out_dims, batch_hidden);
    ToBatch(device_ctx, *hidden_g, out_dims, batch_hidden_g);
    ToBatch(device_ctx, *cell_out, out_dims, batch_cell);

    LoDTensor batch_cell_g, batch_gate_g;
    batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
    // TODO(qingqing) support the case output cell has gradient.
    // to_batch(device_ctx, *cell_g, batch_cell_g, false);
    zero(device_ctx, &batch_cell_g, static_cast<T>(0.0));
    batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
    batch_gate_g.set_lod(batch_gate->lod());

    auto gate_act = ctx.Attr<std::string>("gate_activation");
    auto cell_act = ctx.Attr<std::string>("cell_activation");
    auto cand_act = ctx.Attr<std::string>("candidate_activation");

    // lod()[0] holds the starting row of every time-step batch. Bind by
    // reference: the offsets are only read, so copying the vector is waste.
    const auto& batch_starts = batch_gate->lod()[0];
    size_t num_batch = batch_starts.size() - 1;
    for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);

      Tensor gate = batch_gate->Slice(bstart, bend);
      Tensor cell = batch_cell.Slice(bstart, bend);
      Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend);
      lstm_value.gateValue = gate.data<T>();
      lstm_value.stateValue = cell.data<T>();
      lstm_value.stateActiveValue = cell_pre_act.data<T>();

      Tensor out_g = batch_hidden_g.Slice(bstart, bend);
      Tensor gate_g = batch_gate_g.Slice(bstart, bend);
      Tensor cell_g = batch_cell_g.Slice(bstart, bend);
      lstm_grad.stateGrad = cell_g.data<T>();
      lstm_grad.gateGrad = gate_g.data<T>();
      lstm_grad.outputGrad = out_g.data<T>();

      if (n > 0) {
        // The previous time step's cell state and its gradient live in the
        // previous batch.
        int bstart_pre = static_cast<int>(batch_starts[n - 1]);
        Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
        Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
        lstm_value.prevStateValue = cell_pre.data<T>();
        lstm_grad.prevStateGrad = cell_pre_g.data<T>();
      } else {
        // At the first step, the previous state is the (reordered) C0, if any.
        lstm_value.prevStateValue = c0 ? ordered_c0.data<T>() : nullptr;
        lstm_grad.prevStateGrad = has_c0_g ? ordered_c0_g.data<T>() : nullptr;
      }

      int cur_batch_size = bend - bstart;
      math::LstmUnitGradFunctor<Place, T>::compute(
          device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size,
          gate_act, cell_act, cand_act);

      if (n > 0) {
        // Propagate gate_g * Weight^T into the previous step's hidden
        // gradient. Both this and the weight-gradient update accumulate
        // (beta = 1).
        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
        auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
        math::matmul<Place, T>(device_ctx, gate_g, false, *weight, true,
                               static_cast<T>(1.0), &pre_hidden_g,
                               static_cast<T>(1.0));
        if (weight_g) {
          /* backward weight */
          auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end);
          math::matmul<Place, T>(device_ctx, pre_hidden, true, gate_g, false,
                                 static_cast<T>(1.0), weight_g,
                                 static_cast<T>(1.0));
        }
      } else {
        if (h0 && weight_g) {
          ReorderInitState<Place, T>(device_ctx, *h0, order, &ordered_h0, true);
          math::matmul<Place, T>(device_ctx, ordered_h0, true, gate_g, false,
                                 static_cast<T>(1.0), weight_g,
                                 static_cast<T>(1.0));
        }
        if (h0 && h0_g) {
          // beta = 0: this overwrites ordered_h0_g, so no zero-fill is needed.
          ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
          math::matmul<Place, T>(device_ctx, gate_g, false, *weight, true,
                                 static_cast<T>(1.0), &ordered_h0_g,
                                 static_cast<T>(0.0));
        }
      }
    }

    math::Batch2LoDTensorFunctor<Place, T> to_seq;
    if (in_g) {
      /* backward data */
      in_g->mutable_data<T>(ctx.GetPlace());
      to_seq(device_ctx, batch_gate_g, *in_g);
    }
    if (bias && bias_g) {
      /* backward bias */
      // Gate-bias gradient = column sums of batch_gate_g, computed as
      // batch_gate_g^T * ones(m). gemv writes only the first n entries of
      // bias_g (n is BatchGate's width), so any peephole gradients stored
      // after them are left intact.
      int m = static_cast<int>(batch_gate_g.dims()[0]);
      int n = static_cast<int>(batch_gate_g.dims()[1]);

      Tensor ones;
      ones.mutable_data<T>({m}, ctx.GetPlace());
      math::SetConstant<Place, T> set;
      set(device_ctx, &ones, static_cast<T>(1.0));

      math::gemv<Place, T>(device_ctx, true, m, n, static_cast<T>(1.0),
                           batch_gate_g.data<T>(), ones.data<T>(),
                           static_cast<T>(0.0), bias_g->data<T>());
    }

    // Move the initial-state gradients from batch order back to the caller's
    // row order.
    if (h0 && h0_g) {
      ReorderInitState<Place, T>(device_ctx, ordered_h0_g, order, h0_g, false);
    }
    if (has_c0_g) {
      ReorderInitState<Place, T>(device_ctx, ordered_c0_g, order, c0_g, false);
    }
  }
};

}  // namespace operators
}  // namespace paddle