/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence2batch.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

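// Reorder the rows of `src` into `dst` according to `index_lod`. With
// indexed_src == true, row i of dst is taken from row index_lod[i] of src
// (used below to gather initial states into the length-sorted batch order);
// with indexed_src == false, row i of src is scattered to row index_lod[i]
// of dst (used to restore the original sequence order). The index vector
// comes from batch_gate->lod()[2], which records the length-sorted order of
// the input sequences.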
template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx,
                             const framework::Tensor& src,
                             framework::Vector<size_t> index_lod,
                             framework::Tensor* dst, bool indexed_src) {
  math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
  row_shuffle(ctx, src, index_lod, *dst, indexed_src);
}

template <typename DeviceContext, typename T>
class LSTMKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<LoDTensor>("Input");
    auto* weight = ctx.Input<Tensor>("Weight");
    auto* bias = ctx.Input<Tensor>("Bias");

    auto* hidden_t0 = ctx.Input<Tensor>("H0");
    auto* cell_t0 = ctx.Input<Tensor>("C0");

    auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
    batch_gate->mutable_data<T>(ctx.GetPlace());
    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
    hidden_out->mutable_data<T>(ctx.GetPlace());
    auto* cell_out = ctx.Output<LoDTensor>("Cell");
    cell_out->mutable_data<T>(ctx.GetPlace());

    bool is_reverse = ctx.Attr<bool>("is_reverse");
    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
    auto& device_ctx = ctx.template device_context<DeviceContext>();
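    // LoDTensor2BatchFunctor sorts the sequences by length (longest first)
    // and regroups the rows so that the n-th batch holds time step n of
    // every sequence that is still active; the resulting batch offsets and
    // row order are recorded in batch_gate->lod().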
    to_batch(device_ctx, *input, *batch_gate, true, is_reverse);

    auto in_dims = input->dims();
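    // The input holds the four gate pre-activations concatenated along
    // dim 1, so in_dims[1] == 4 * frame_size.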
    int frame_size = static_cast<int>(in_dims[1] / 4);
    framework::DDim dims({in_dims[0], frame_size});

    if (bias) {
      Tensor b = *bias;
      b.Resize({bias->numel(), 1});
      Tensor gate_bias = b.Slice(0, 4 * frame_size);
      math::RowwiseAdd<DeviceContext, T> add_bias;
      add_bias(device_ctx, *batch_gate, gate_bias, batch_gate);
    }

    math::LstmMetaValue<T> lstm_value;
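    // With use_peepholes, the first 4 * frame_size entries of Bias are the
    // gate biases (already added above) and the last 3 * frame_size entries
    // are the peephole weights for the input, forget, and output gates.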
    if (bias && ctx.Attr<bool>("use_peepholes")) {
      T* bias_data = const_cast<T*>(bias->data<T>());
      // the code style in LstmMetaValue will be updated later.

      lstm_value.check_ig = bias_data + 4 * frame_size;
      lstm_value.check_fg = lstm_value.check_ig + frame_size;
      lstm_value.check_og = lstm_value.check_fg + frame_size;
    } else {
      lstm_value.check_ig = nullptr;
      lstm_value.check_fg = nullptr;
      lstm_value.check_og = nullptr;
    }
    lstm_value.prev_state_value = nullptr;
    Tensor ordered_c0;

    framework::Vector<size_t> order(batch_gate->lod()[2]);

    if (cell_t0) {
      // Since the batch computation for LSTM reorders the input sequences
      // by length, the initial cell state also needs to be reordered.
      ReorderInitState<DeviceContext, T>(device_ctx, *cell_t0, order,
                                         &ordered_c0, true);
      lstm_value.prev_state_value = ordered_c0.data<T>();
    }

    // Use local variables for the intermediate batch-layout results.
    LoDTensor batch_hidden, batch_cell;
    auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
    batch_hidden.mutable_data<T>(dims, ctx.GetPlace());
    batch_cell.mutable_data<T>(dims, ctx.GetPlace());
    batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());

    auto batch_starts = batch_gate->lod()[0];
    size_t num_batch = batch_starts.size() - 1;
    auto gate_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("gate_activation"));
    auto cell_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("cell_activation"));
    auto cand_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("candidate_activation"));

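    // Walk the batches in time order: in iteration n, rows [bstart, bend)
    // of the batched tensors hold time step n of every sequence that is at
    // least n + 1 steps long.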
    for (size_t n = 0; n < num_batch; n++) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);

      Tensor gate_t = batch_gate->Slice(bstart, bend);
      Tensor out_t = batch_hidden.Slice(bstart, bend);
      Tensor cell_t = batch_cell.Slice(bstart, bend);
      Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend);

      int cur_batch_size = bend - bstart;

      if (n > 0) {
        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
        auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
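        // gate_t already holds the input projection (and bias); with
        // beta == 1.0 the matmul accumulates the recurrent term on top:
        // gate_t += pre_hidden_t * weight.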
        math::matmul<DeviceContext, T>(device_ctx, pre_hidden_t, false, *weight,
                                       false, static_cast<T>(1.0), &gate_t,
                                       static_cast<T>(1.0));
      } else if (hidden_t0) {
        // If n == 0 and there is no initial hidden state, i.e. H0 is all
        // zeros, the computation of W_h * H0 is skipped.
        // If n == 0 and an initial hidden state is given, compute W_h * H0.

        // Since the batch computation for LSTM reorders the input sequences
        // by length, the initial hidden state also needs to be reordered.
        Tensor ordered_h0;
        ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
                                           &ordered_h0, true);
        math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false, *weight,
                                       false, static_cast<T>(1.0), &gate_t,
                                       static_cast<T>(1.0));
      }

      lstm_value.gate_value = gate_t.data<T>();
      lstm_value.output_value = out_t.data<T>();
      lstm_value.state_value = cell_t.data<T>();
      lstm_value.state_active_value = cell_pre_act_t.data<T>();
      math::LstmUnitFunctor<DeviceContext, T>::compute(
          device_ctx, lstm_value, frame_size, cur_batch_size, gate_act,
          cell_act, cand_act);
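      // Point prev_state_value at this step's cell state so that the next
      // iteration reads it as c_{t-1}.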
      lstm_value.prev_state_value = lstm_value.state_value;
    }

    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    batch_hidden.set_lod(batch_gate->lod());
    // Restore the hidden output from batch layout back to LoDTensor
    // (sequence) layout.
    to_seq(device_ctx, batch_hidden, *hidden_out);

    batch_cell.set_lod(batch_gate->lod());
    // Restore the cell state output from batch layout back to LoDTensor
    // (sequence) layout.
    to_seq(device_ctx, batch_cell, *cell_out);
  }
};

template <typename DeviceContext, typename T>
class LSTMGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<LoDTensor>("Input");
    auto* weight = ctx.Input<Tensor>("Weight");
    auto* bias = ctx.Input<Tensor>("Bias");

    auto* hidden_out = ctx.Input<LoDTensor>("Hidden");
    auto* cell_out = ctx.Input<LoDTensor>("Cell");

    auto* batch_gate = ctx.Input<LoDTensor>("BatchGate");
    auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");

    auto* hidden_g = ctx.Input<LoDTensor>(framework::GradVarName("Hidden"));

    auto* in_g = ctx.Output<LoDTensor>(framework::GradVarName("Input"));
    auto* weight_g = ctx.Output<Tensor>(framework::GradVarName("Weight"));
    auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto* h0 = ctx.Input<Tensor>("H0");
    auto* c0 = ctx.Input<Tensor>("C0");

    auto* h0_g = ctx.Output<Tensor>(framework::GradVarName("H0"));
    auto* c0_g = ctx.Output<Tensor>(framework::GradVarName("C0"));

    auto& device_ctx = ctx.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> zero;
    if (weight_g) {
      weight_g->mutable_data<T>(ctx.GetPlace());
      zero(device_ctx, weight_g, static_cast<T>(0.0));
    }

    // ordered_h0/ordered_c0 hold the reordered initial hidden/cell states.
    // ordered_h0_g/ordered_c0_g hold the reordered gradients of the initial
    // hidden/cell states.
    Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
    framework::Vector<size_t> order(batch_gate->lod()[2]);

    if (c0) {
      ReorderInitState<DeviceContext, T>(device_ctx, *c0, order, &ordered_c0,
                                         true);
    }
    if (c0 && c0_g) {
      ordered_c0_g.mutable_data<T>(c0_g->dims(), ctx.GetPlace());
    }

    auto in_dims = input->dims();
    auto out_dims = hidden_g->dims();
    int frame_size = static_cast<int>(in_dims[1] / 4);
    PADDLE_ENFORCE_EQ(frame_size, out_dims[1]);

    math::LstmMetaValue<T> lstm_value;
    if (bias && ctx.Attr<bool>("use_peepholes")) {
      T* bias_data = const_cast<T*>(bias->data<T>());
      lstm_value.check_ig = bias_data + 4 * frame_size;
      lstm_value.check_fg = lstm_value.check_ig + frame_size;
      lstm_value.check_og = lstm_value.check_fg + frame_size;
    } else {
      lstm_value.check_ig = nullptr;
      lstm_value.check_fg = nullptr;
      lstm_value.check_og = nullptr;
    }

    math::LstmMetaGrad<T> lstm_grad;

    if (bias && bias_g) {
      bias_g->mutable_data<T>(ctx.GetPlace());
      zero(device_ctx, bias_g, static_cast<T>(0.0));
    }
    if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
      T* bias_g_data = bias_g->data<T>();
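      // The peephole gradients are written directly into the tail of the
      // bias gradient buffer, mirroring the forward bias layout.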
      lstm_grad.check_ig_grad = bias_g_data + 4 * frame_size;
      lstm_grad.check_fg_grad = lstm_grad.check_ig_grad + frame_size;
      lstm_grad.check_og_grad = lstm_grad.check_fg_grad + frame_size;
    } else {
      lstm_grad.check_ig_grad = nullptr;
      lstm_grad.check_fg_grad = nullptr;
      lstm_grad.check_og_grad = nullptr;
    }

    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;

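    // Re-batch the forward results with the LoD already computed for
    // batch_gate; the final `false` tells to_batch to reuse dst's LoD
    // instead of recomputing it.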
    auto ToBatch = [&batch_gate, &to_batch](
        const DeviceContext& ctx, const framework::LoDTensor& src,
        const framework::DDim& dims, framework::LoDTensor& dst) {
      dst.mutable_data<T>(dims, ctx.GetPlace());
      dst.set_lod(batch_gate->lod());
      to_batch(ctx, src, dst, false);
    };

    LoDTensor batch_hidden, batch_hidden_g, batch_cell;
    ToBatch(device_ctx, *hidden_out, out_dims, batch_hidden);
    ToBatch(device_ctx, *hidden_g, out_dims, batch_hidden_g);
    ToBatch(device_ctx, *cell_out, out_dims, batch_cell);

    LoDTensor batch_cell_g, batch_gate_g;
    batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
    // TODO(qingqing) support the case where the output cell state has a
    // gradient.
    // to_batch(device_ctx, *cell_g, batch_cell_g, false);
    zero(device_ctx, &batch_cell_g, static_cast<T>(0.0));
    batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
    batch_gate_g.set_lod(batch_gate->lod());

    auto gate_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("gate_activation"));
    auto cell_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("cell_activation"));
    auto cand_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("candidate_activation"));

    auto batch_starts = batch_gate->lod()[0];
    size_t num_batch = batch_starts.size() - 1;
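    // Back-propagation through time: walk the batch steps in reverse.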
    for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);

      Tensor gate = batch_gate->Slice(bstart, bend);
      Tensor cell = batch_cell.Slice(bstart, bend);
      Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend);
      lstm_value.gate_value = gate.data<T>();
      lstm_value.state_value = cell.data<T>();
      lstm_value.state_active_value = cell_pre_act.data<T>();

      Tensor out_g = batch_hidden_g.Slice(bstart, bend);
      Tensor gate_g = batch_gate_g.Slice(bstart, bend);
      Tensor cell_g = batch_cell_g.Slice(bstart, bend);
      lstm_grad.state_grad = cell_g.data<T>();
      lstm_grad.gate_grad = gate_g.data<T>();
      lstm_grad.output_grad = out_g.data<T>();

      if (n > 0) {
        int bstart_pre = static_cast<int>(batch_starts[n - 1]);
        Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
        Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
        lstm_value.prev_state_value = cell_pre.data<T>();
        lstm_grad.prev_state_grad = cell_pre_g.data<T>();
      } else {
        lstm_value.prev_state_value = c0 ? ordered_c0.data<T>() : nullptr;
        lstm_grad.prev_state_grad = c0_g ? ordered_c0_g.data<T>() : nullptr;
      }

      int cur_batch_size = bend - bstart;
      math::LstmUnitGradFunctor<DeviceContext, T>::compute(
          device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size,
          gate_act, cell_act, cand_act);

      if (n > 0) {
        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
        auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
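        // Propagate the gate gradient to the previous hidden state:
        // pre_hidden_g += gate_g * weight^T.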
        math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight, true,
                                       static_cast<T>(1.0), &pre_hidden_g,
                                       static_cast<T>(1.0));
        if (weight_g) {
          /* backward weight */
          auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end);
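          // Accumulate the weight gradient:
          // weight_g += pre_hidden^T * gate_g.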
          math::matmul<DeviceContext, T>(device_ctx, pre_hidden, true, gate_g,
                                         false, static_cast<T>(1.0), weight_g,
                                         static_cast<T>(1.0));
        }
      } else {
        if (h0 && weight_g) {
          ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
                                             &ordered_h0, true);
          math::matmul<DeviceContext, T>(device_ctx, ordered_h0, true, gate_g,
                                         false, static_cast<T>(1.0), weight_g,
                                         static_cast<T>(1.0));
        }
        if (h0 && h0_g) {
          ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
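          // Gradient w.r.t. the (reordered) initial hidden state:
          // ordered_h0_g = gate_g * weight^T (beta == 0.0 overwrites).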
          math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight,
                                         true, static_cast<T>(1.0),
                                         &ordered_h0_g, static_cast<T>(0.0));
        }
      }
    }

    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    if (in_g) {
      /* backward data */
      in_g->mutable_data<T>(ctx.GetPlace());
      to_seq(device_ctx, batch_gate_g, *in_g);
    }
    if (bias && bias_g) {
      /* backward bias */
      Tensor b_g = *bias_g;
      b_g.Resize({bias_g->numel(), 1});
      Tensor gate_bias_g = b_g.Slice(0, 4 * frame_size);
      math::ColwiseSum<DeviceContext, T> col_sum;
      col_sum(device_ctx, batch_gate_g, &gate_bias_g);
    }

    if (h0 && h0_g) {
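      // Scatter the reordered gradient back to the original sequence order
      // (indexed_src == false).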
      ReorderInitState<DeviceContext, T>(device_ctx, ordered_h0_g, order, h0_g,
                                         false);
    }
    if (c0 && c0_g) {
      ReorderInitState<DeviceContext, T>(device_ctx, ordered_c0_g, order, c0_g,
                                         false);
    }
  }
};

}  // namespace operators
}  // namespace paddle