/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

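// Reorder the rows of `src` into `dst` according to `index_lod` using
// CopyMatrixRowsFunctor. With `indexed_src` set to true the initial
// hidden/cell states are gathered into batch order; with false the
// (batch-ordered) gradients are scattered back to the original sequence
// order.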
template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx,
                             const framework::Tensor& src,
                             framework::Vector<size_t> index_lod,
                             framework::Tensor* dst, bool indexed_src) {
  math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
  row_shuffle(ctx, src, index_lod, dst, indexed_src);
}

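// Forward kernel of the LSTM operator: converts the input sequences into
// batch layout, adds the gate bias, applies the recurrence step by step, and
// writes the resulting hidden and cell states back in sequence (LoD) layout.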
template <typename DeviceContext, typename T>
class LSTMKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<LoDTensor>("Input");
    auto* weight = ctx.Input<Tensor>("Weight");
    auto* bias = ctx.Input<Tensor>("Bias");

    auto* hidden_t0 = ctx.Input<Tensor>("H0");
    auto* cell_t0 = ctx.Input<Tensor>("C0");

    auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
    batch_gate->mutable_data<T>(ctx.GetPlace());
    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
    hidden_out->mutable_data<T>(ctx.GetPlace());
    auto* cell_out = ctx.Output<LoDTensor>("Cell");
    cell_out->mutable_data<T>(ctx.GetPlace());

    bool is_reverse = ctx.Attr<bool>("is_reverse");
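    // Reorganize the input from sequence (LoD) layout into batch layout so
    // that the same time step of all sequences can be processed together.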
    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
    auto& device_ctx = ctx.template device_context<DeviceContext>();
    to_batch(device_ctx, *input, batch_gate, true, is_reverse);

    auto in_dims = input->dims();
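    // The second dimension of the input packs the four LSTM gates, so the
    // hidden size (frame_size) is in_dims[1] / 4.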
    int frame_size = static_cast<int>(in_dims[1] / 4);
    framework::DDim dims({in_dims[0], frame_size});

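    // Add the gate bias (the first 4 * frame_size elements of Bias) to every
    // row of the batched gates.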
    if (bias) {
      Tensor b = *bias;
      b.Resize({bias->numel(), 1});
      Tensor gate_bias = b.Slice(0, 4 * frame_size);
      math::RowwiseAdd<DeviceContext, T> add_bias;
      add_bias(device_ctx, *batch_gate, gate_bias, batch_gate);
    }

    math::LstmMetaValue<T> lstm_value;
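    // With peephole connections enabled, the three peephole weight vectors
    // (for the input, forget and output gates) are stored in Bias right after
    // the gate bias, each of length frame_size.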
    if (bias && ctx.Attr<bool>("use_peepholes")) {
      T* bias_data = const_cast<T*>(bias->data<T>());
      // the code style in LstmMetaValue will be updated later.

      lstm_value.check_ig = bias_data + 4 * frame_size;
      lstm_value.check_fg = lstm_value.check_ig + frame_size;
      lstm_value.check_og = lstm_value.check_fg + frame_size;
    } else {
      lstm_value.check_ig = nullptr;
      lstm_value.check_fg = nullptr;
      lstm_value.check_og = nullptr;
    }
    lstm_value.prev_state_value = nullptr;
    Tensor ordered_c0;

    framework::Vector<size_t> order(batch_gate->lod()[2]);

    if (cell_t0) {
      // Since batch computing for LSTM reorders the input sequences
      // according to their length, the initial cell state also needs
      // to be reordered.
      ReorderInitState<DeviceContext, T>(device_ctx, *cell_t0, order,
                                         &ordered_c0, true);
      lstm_value.prev_state_value = ordered_c0.data<T>();
    }

    // Use local variables to hold the intermediate batch hidden and cell
    // states here.
    LoDTensor batch_hidden, batch_cell;
    auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
    batch_hidden.mutable_data<T>(dims, ctx.GetPlace());
    batch_cell.mutable_data<T>(dims, ctx.GetPlace());
    batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());

    auto batch_starts = batch_gate->lod()[0];
    size_t num_batch = batch_starts.size() - 1;
    auto gate_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("gate_activation"));
    auto cell_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("cell_activation"));
    auto cand_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("candidate_activation"));

    auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
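    // Process the batched time steps in order: add the recurrent term
    // W_h * h_{t-1} (or W_h * H0 for the first step) into the gates, then run
    // the LSTM cell via LstmUnitFunctor.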
    for (size_t n = 0; n < num_batch; n++) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);

      Tensor gate_t = batch_gate->Slice(bstart, bend);
      Tensor out_t = batch_hidden.Slice(bstart, bend);
      Tensor cell_t = batch_cell.Slice(bstart, bend);
      Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend);

      int cur_batch_size = bend - bstart;

      if (n > 0) {
        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
        auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
        blas.MatMul(pre_hidden_t, false, *weight, false, static_cast<T>(1.0),
                    &gate_t, static_cast<T>(1.0));
      } else if (hidden_t0) {
        // If n == 0 and there is no initial hidden state, i.e. H0 is all
        // zeros, the calculation of W_h * H0 is skipped.
        // If n == 0 and an initial hidden state is given, calculate W_h * H0.

        // Since batch computing for LSTM reorders the input sequences
        // according to their length, the initial hidden state also needs
        // to be reordered.
        Tensor ordered_h0;
        ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
                                           &ordered_h0, true);
        blas.MatMul(ordered_h0, false, *weight, false, static_cast<T>(1.0),
                    &gate_t, static_cast<T>(1.0));
      }

      lstm_value.gate_value = gate_t.data<T>();
      lstm_value.output_value = out_t.data<T>();
      lstm_value.state_value = cell_t.data<T>();
      lstm_value.state_active_value = cell_pre_act_t.data<T>();
      T cell_clip = 0.0;
      math::LstmUnitFunctor<DeviceContext, T>::compute(
          device_ctx, lstm_value, frame_size, cur_batch_size, cell_clip,
          gate_act, cell_act, cand_act);
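      // The updated cell state becomes the previous state of the next step.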
      lstm_value.prev_state_value = lstm_value.state_value;
    }

    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    batch_hidden.set_lod(batch_gate->lod());
    // Restore the output hidden state in LoDTensor from the batch hidden.
    to_seq(device_ctx, batch_hidden, hidden_out);

    batch_cell.set_lod(batch_gate->lod());
    // Restore the output cell state in LoDTensor from the batch cell.
    to_seq(device_ctx, batch_cell, cell_out);
  }
};

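// Backward kernel of the LSTM operator: replays the batched time steps in
// reverse order and accumulates the gradients of the input, weight, bias and
// the optional initial states H0/C0.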
template <typename DeviceContext, typename T>
class LSTMGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<LoDTensor>("Input");
    auto* weight = ctx.Input<Tensor>("Weight");
    auto* bias = ctx.Input<Tensor>("Bias");

    auto* hidden_out = ctx.Input<LoDTensor>("Hidden");
    auto* cell_out = ctx.Input<LoDTensor>("Cell");

    auto* batch_gate = ctx.Input<LoDTensor>("BatchGate");
    auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");

    auto* hidden_g = ctx.Input<LoDTensor>(framework::GradVarName("Hidden"));

    auto* in_g = ctx.Output<LoDTensor>(framework::GradVarName("Input"));
    auto* weight_g = ctx.Output<Tensor>(framework::GradVarName("Weight"));
    auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto* h0 = ctx.Input<Tensor>("H0");
    auto* c0 = ctx.Input<Tensor>("C0");

    auto* h0_g = ctx.Output<Tensor>(framework::GradVarName("H0"));
    auto* c0_g = ctx.Output<Tensor>(framework::GradVarName("C0"));

    auto& device_ctx = ctx.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> zero;
    if (weight_g) {
      weight_g->mutable_data<T>(ctx.GetPlace());
      zero(device_ctx, weight_g, static_cast<T>(0.0));
    }

    // ordered_h0/c0 is the reordered hidden/cell initialization.
    // ordered_h0_g/c0_g is the reordered gradient of hidden/cell
    // initialization.
    Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
    framework::Vector<size_t> order(batch_gate->lod()[2]);

    if (c0) {
      ReorderInitState<DeviceContext, T>(device_ctx, *c0, order, &ordered_c0,
                                         true);
    }
    if (c0 && c0_g) {
      ordered_c0_g.mutable_data<T>(c0_g->dims(), ctx.GetPlace());
    }

    auto in_dims = input->dims();
    auto out_dims = hidden_g->dims();
    int frame_size = static_cast<int>(in_dims[1] / 4);
    PADDLE_ENFORCE_EQ(
        frame_size, out_dims[1],
        platform::errors::InvalidArgument(
            "The second dimension of Input(" +
                framework::GradVarName("Hidden") +
                ") should be %d, but received %d in LSTM@Grad operator.",
            frame_size, out_dims[1]));

    math::LstmMetaValue<T> lstm_value;
    if (bias && ctx.Attr<bool>("use_peepholes")) {
      T* bias_data = const_cast<T*>(bias->data<T>());
      lstm_value.check_ig = bias_data + 4 * frame_size;
      lstm_value.check_fg = lstm_value.check_ig + frame_size;
      lstm_value.check_og = lstm_value.check_fg + frame_size;
    } else {
      lstm_value.check_ig = nullptr;
      lstm_value.check_fg = nullptr;
      lstm_value.check_og = nullptr;
    }

    math::LstmMetaGrad<T> lstm_grad;

    if (bias && bias_g) {
      bias_g->mutable_data<T>(ctx.GetPlace());
      zero(device_ctx, bias_g, static_cast<T>(0.0));
    }
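    // When peepholes are enabled, their gradients are accumulated directly
    // into the tail of the bias gradient, mirroring the forward layout.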
    if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
      T* bias_g_data = bias_g->data<T>();
      lstm_grad.check_ig_grad = bias_g_data + 4 * frame_size;
      lstm_grad.check_fg_grad = lstm_grad.check_ig_grad + frame_size;
      lstm_grad.check_og_grad = lstm_grad.check_fg_grad + frame_size;
    } else {
      lstm_grad.check_ig_grad = nullptr;
      lstm_grad.check_fg_grad = nullptr;
      lstm_grad.check_og_grad = nullptr;
    }

    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;

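    // Helper that copies a sequence-layout LoDTensor into batch layout,
    // reusing the batch order already stored in batch_gate->lod().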
    auto ToBatch = [&batch_gate, &to_batch](
        const DeviceContext& ctx, const framework::LoDTensor& src,
        const framework::DDim& dims, framework::LoDTensor& dst) {
      dst.mutable_data<T>(dims, ctx.GetPlace());
      dst.set_lod(batch_gate->lod());
      to_batch(ctx, src, &dst, false);
    };

    LoDTensor batch_hidden, batch_hidden_g, batch_cell;
    ToBatch(device_ctx, *hidden_out, out_dims, batch_hidden);
    ToBatch(device_ctx, *hidden_g, out_dims, batch_hidden_g);
    ToBatch(device_ctx, *cell_out, out_dims, batch_cell);

    LoDTensor batch_cell_g, batch_gate_g;
    batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
    // TODO(qingqing): support the case where the output cell has a gradient.
    // to_batch(device_ctx, *cell_g, batch_cell_g, false);
    zero(device_ctx, &batch_cell_g, static_cast<T>(0.0));
    batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
    batch_gate_g.set_lod(batch_gate->lod());

    auto gate_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("gate_activation"));
    auto cell_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("cell_activation"));
    auto cand_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("candidate_activation"));

    auto batch_starts = batch_gate->lod()[0];
    size_t num_batch = batch_starts.size() - 1;
    auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
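    // Walk the batched time steps in reverse order, back-propagating through
    // the LSTM cell and the recurrent connection W_h.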
    for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);

      Tensor gate = batch_gate->Slice(bstart, bend);
      Tensor cell = batch_cell.Slice(bstart, bend);
      Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend);
      lstm_value.gate_value = gate.data<T>();
      lstm_value.state_value = cell.data<T>();
      lstm_value.state_active_value = cell_pre_act.data<T>();

      Tensor out_g = batch_hidden_g.Slice(bstart, bend);
      Tensor gate_g = batch_gate_g.Slice(bstart, bend);
      Tensor cell_g = batch_cell_g.Slice(bstart, bend);
      lstm_grad.state_grad = cell_g.data<T>();
      lstm_grad.gate_grad = gate_g.data<T>();
      lstm_grad.output_grad = out_g.data<T>();

      if (n > 0) {
        int bstart_pre = static_cast<int>(batch_starts[n - 1]);
        Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
        Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
        lstm_value.prev_state_value = cell_pre.data<T>();
        lstm_grad.prev_state_grad = cell_pre_g.data<T>();
      } else {
        lstm_value.prev_state_value = c0 ? ordered_c0.data<T>() : nullptr;
        lstm_grad.prev_state_grad = c0_g ? ordered_c0_g.data<T>() : nullptr;
      }

      // lstm_value.output_value and lstm_grad.state_active_grad are not used
      // in the backward pass, so they are set to nullptr.
      lstm_value.output_value = nullptr;
      lstm_grad.state_active_grad = nullptr;
      int cur_batch_size = bend - bstart;
      T cell_clip = 0.0;
      math::LstmUnitGradFunctor<DeviceContext, T>::compute(
          device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size,
          cell_clip, gate_act, cell_act, cand_act);

      if (n > 0) {
        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
        auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
        blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
                    &pre_hidden_g, static_cast<T>(1.0));
        if (weight_g) {
          /* backward weight */
          auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end);
          blas.MatMul(pre_hidden, true, gate_g, false, static_cast<T>(1.0),
                      weight_g, static_cast<T>(1.0));
        }
      } else {
        if (h0 && weight_g) {
          ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
                                             &ordered_h0, true);
          blas.MatMul(ordered_h0, true, gate_g, false, static_cast<T>(1.0),
                      weight_g, static_cast<T>(1.0));
        }
        if (h0 && h0_g) {
          ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
          blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
                      &ordered_h0_g, static_cast<T>(0.0));
        }
      }
    }

    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    if (in_g) {
      /* backward data */
      in_g->mutable_data<T>(ctx.GetPlace());
      to_seq(device_ctx, batch_gate_g, in_g);
    }
    if (bias && bias_g) {
      /* backward bias */
      Tensor b_g = *bias_g;
      b_g.Resize({bias_g->numel(), 1});
      Tensor gate_bias_g = b_g.Slice(0, 4 * frame_size);
      math::ColwiseSum<DeviceContext, T> col_sum;
      col_sum(device_ctx, batch_gate_g, &gate_bias_g);
    }

    if (h0 && h0_g) {
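      // Scatter the reordered H0 gradient back to the original sequence
      // order.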
      ReorderInitState<DeviceContext, T>(device_ctx, ordered_h0_g, order, h0_g,
                                         false);
    }
    if (c0 && c0_g) {
      ReorderInitState<DeviceContext, T>(device_ctx, ordered_c0_g, order, c0_g,
                                         false);
    }
  }
};

}  // namespace operators
}  // namespace paddle