/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

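// Reorders the rows of src into dst according to index. When indexed_src is
// true this gathers rows (dst[i] = src[index[i]]); when false it scatters
// them back (dst[index[i]] = src[i]). It is used to put the initial states
// H0/C0 into the length-sorted order produced by batching, and to undo that
// reordering for their gradients.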
template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx,
                             const framework::Tensor& src, const size_t* index,
                             framework::Tensor* dst, bool indexed_src) {
  math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
  row_shuffle(ctx, src, index, *dst, indexed_src);
}

template <typename DeviceContext, typename T>
class LSTMKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<LoDTensor>("Input");
    auto* weight = ctx.Input<Tensor>("Weight");
    auto* bias = ctx.Input<Tensor>("Bias");

    auto* hidden_t0 = ctx.Input<Tensor>("H0");
    auto* cell_t0 = ctx.Input<Tensor>("C0");

    auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
    batch_gate->mutable_data<T>(ctx.GetPlace());
    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
    hidden_out->mutable_data<T>(ctx.GetPlace());
    auto* cell_out = ctx.Output<LoDTensor>("Cell");
    cell_out->mutable_data<T>(ctx.GetPlace());

    bool is_reverse = ctx.Attr<bool>("is_reverse");
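    // Convert the sequence-major input LoDTensor into a batch-major layout:
    // sequences are sorted by descending length, and the rows belonging to
    // the same time step across all sequences form one contiguous batch.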
    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
    auto& device_ctx = ctx.template device_context<DeviceContext>();
    to_batch(device_ctx, *input, *batch_gate, true, is_reverse);

    auto in_dims = input->dims();
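    // The second dimension of Input packs the pre-activations of the four
    // LSTM gates, hence frame_size is a quarter of it.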
    int frame_size = static_cast<int>(in_dims[1] / 4);
    framework::DDim dims({in_dims[0], frame_size});

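    // The first 4 * frame_size elements of Bias are the gate biases, added
    // row-wise to every gate pre-activation; with use_peepholes, the
    // remaining 3 * frame_size elements are the peephole weights (see the
    // check_* pointers below).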
    if (bias) {
      Tensor b = *bias;
      b.Resize({bias->numel(), 1});
      Tensor gate_bias = b.Slice(0, 4 * frame_size);
      math::RowwiseAdd<DeviceContext, T> add_bias;
      add_bias(device_ctx, *batch_gate, gate_bias, batch_gate);
    }

    math::LstmMetaValue<T> lstm_value;
    if (bias && ctx.Attr<bool>("use_peepholes")) {
      T* bias_data = const_cast<T*>(bias->data<T>());
      // the code style in LstmMetaValue will be updated later.

      lstm_value.check_ig = bias_data + 4 * frame_size;
      lstm_value.check_fg = lstm_value.check_ig + frame_size;
      lstm_value.check_og = lstm_value.check_fg + frame_size;
    } else {
      lstm_value.check_ig = nullptr;
      lstm_value.check_fg = nullptr;
      lstm_value.check_og = nullptr;
    }
    lstm_value.prev_state_value = nullptr;
    Tensor ordered_c0;
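    // Level 2 of the batch LoD maps each position in the length-sorted
    // order to the sequence's index in the original input.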
    const size_t* order = batch_gate->lod()[2].data();
    if (cell_t0) {
      // Since batch computing for LSTM reorders the input sequences by
      // length, the initial cell state also needs to be reordered.
      ReorderInitState<DeviceContext, T>(device_ctx, *cell_t0, order,
                                         &ordered_c0, true);
      lstm_value.prev_state_value = ordered_c0.data<T>();
    }

    // Use local variables as intermediate buffers here; batch_cell_pre_act
    // is an op output because the backward pass needs it.
    LoDTensor batch_hidden, batch_cell;
    auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
    batch_hidden.mutable_data<T>(dims, ctx.GetPlace());
    batch_cell.mutable_data<T>(dims, ctx.GetPlace());
    batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());

    auto batch_starts = batch_gate->lod()[0];
    size_t num_batch = batch_starts.size() - 1;
    auto gate_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("gate_activation"));
    auto cell_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("cell_activation"));
    auto cand_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("candidate_activation"));

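    // Each iteration handles one time step for every sequence that is still
    // active at that step; [bstart, bend) delimits this time-step batch in
    // the batch-major tensors.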
    for (size_t n = 0; n < num_batch; n++) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);

      Tensor gate_t = batch_gate->Slice(bstart, bend);
      Tensor out_t = batch_hidden.Slice(bstart, bend);
      Tensor cell_t = batch_cell.Slice(bstart, bend);
      Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend);

      int cur_batch_size = bend - bstart;

      if (n > 0) {
        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
        auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
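        // gate_t += pre_hidden_t * weight: add the recurrent projection of
        // the previous hidden state to the precomputed input projections
        // (alpha and beta are both 1.0, so the existing gate values are
        // kept).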
        math::matmul<DeviceContext, T>(device_ctx, pre_hidden_t, false, *weight,
                                       false, static_cast<T>(1.0), &gate_t,
                                       static_cast<T>(1.0));
      } else if (hidden_t0) {
        // If n == 0 and there is no initial hidden state (i.e., H0 is all
        // zeros), the calculation W_h * H0 is skipped.
        // If n == 0 and an initial hidden state is given, calculate W_h * H0.

        // Since batch computing for LSTM reorders the input sequences by
        // length, the initial hidden state also needs to be reordered.
        Tensor ordered_h0;
        ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
                                           &ordered_h0, true);
        math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false, *weight,
                                       false, static_cast<T>(1.0), &gate_t,
                                       static_cast<T>(1.0));
      }

      lstm_value.gate_value = gate_t.data<T>();
      lstm_value.output_value = out_t.data<T>();
      lstm_value.state_value = cell_t.data<T>();
      lstm_value.state_active_value = cell_pre_act_t.data<T>();
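      // Run the element-wise LSTM unit on this time-step batch: gate
      // activations, peephole connections (if configured), cell-state
      // update, and the hidden output.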
      math::LstmUnitFunctor<DeviceContext, T>::compute(
          device_ctx, lstm_value, frame_size, cur_batch_size, gate_act,
          cell_act, cand_act);
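      // The cell state just computed becomes the previous state of the next
      // time step.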
      lstm_value.prev_state_value = lstm_value.state_value;
    }

    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    batch_hidden.set_lod(batch_gate->lod());
    // restore the hidden output from batch-major layout back to sequence
    // order (LoDTensor)
    to_seq(device_ctx, batch_hidden, *hidden_out);

    batch_cell.set_lod(batch_gate->lod());
    // restore the cell-state output from batch-major layout back to
    // sequence order (LoDTensor)
    to_seq(device_ctx, batch_cell, *cell_out);
  }
};

template <typename DeviceContext, typename T>
class LSTMGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<LoDTensor>("Input");
    auto* weight = ctx.Input<Tensor>("Weight");
    auto* bias = ctx.Input<Tensor>("Bias");

    auto* hidden_out = ctx.Input<LoDTensor>("Hidden");
    auto* cell_out = ctx.Input<LoDTensor>("Cell");

    auto* batch_gate = ctx.Input<LoDTensor>("BatchGate");
    auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");

    auto* hidden_g = ctx.Input<LoDTensor>(framework::GradVarName("Hidden"));

    auto* in_g = ctx.Output<LoDTensor>(framework::GradVarName("Input"));
    auto* weight_g = ctx.Output<Tensor>(framework::GradVarName("Weight"));
    auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto* h0 = ctx.Input<Tensor>("H0");
    auto* c0 = ctx.Input<Tensor>("C0");

    auto* h0_g = ctx.Output<Tensor>(framework::GradVarName("H0"));
    auto* c0_g = ctx.Output<Tensor>(framework::GradVarName("C0"));

    auto& device_ctx = ctx.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> zero;
    if (weight_g) {
      weight_g->mutable_data<T>(ctx.GetPlace());
      zero(device_ctx, weight_g, static_cast<T>(0.0));
    }

    // ordered_h0/ordered_c0 hold the reordered initial hidden/cell states;
    // ordered_h0_g/ordered_c0_g hold the reordered gradients of those
    // initial states.
    Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
    const size_t* order = batch_gate->lod()[2].data();
    if (c0) {
      ReorderInitState<DeviceContext, T>(device_ctx, *c0, order, &ordered_c0,
                                         true);
    }
    if (c0 && c0_g) {
      ordered_c0_g.mutable_data<T>(c0_g->dims(), ctx.GetPlace());
    }

    auto in_dims = input->dims();
    auto out_dims = hidden_g->dims();
    int frame_size = static_cast<int>(in_dims[1] / 4);
    PADDLE_ENFORCE_EQ(frame_size, out_dims[1]);

    math::LstmMetaValue<T> lstm_value;
    if (bias && ctx.Attr<bool>("use_peepholes")) {
      T* bias_data = const_cast<T*>(bias->data<T>());
      lstm_value.check_ig = bias_data + 4 * frame_size;
      lstm_value.check_fg = lstm_value.check_ig + frame_size;
      lstm_value.check_og = lstm_value.check_fg + frame_size;
    } else {
      lstm_value.check_ig = nullptr;
      lstm_value.check_fg = nullptr;
      lstm_value.check_og = nullptr;
    }

    math::LstmMetaGrad<T> lstm_grad;

    if (bias && bias_g) {
      bias_g->mutable_data<T>(ctx.GetPlace());
      zero(device_ctx, bias_g, static_cast<T>(0.0));
    }
    if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
      T* bias_g_data = bias_g->data<T>();
      lstm_grad.check_ig_grad = bias_g_data + 4 * frame_size;
      lstm_grad.check_fg_grad = lstm_grad.check_ig_grad + frame_size;
      lstm_grad.check_og_grad = lstm_grad.check_fg_grad + frame_size;
    } else {
      lstm_grad.check_ig_grad = nullptr;
      lstm_grad.check_fg_grad = nullptr;
      lstm_grad.check_og_grad = nullptr;
    }

    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;

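    // Re-batch the forward outputs and the hidden gradient using the batch
    // LoD computed in the forward pass (the false flag reuses the LoD set
    // on dst instead of recomputing the order).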
    auto ToBatch = [&batch_gate, &to_batch](
        const DeviceContext& ctx, const framework::LoDTensor& src,
        const framework::DDim& dims, framework::LoDTensor& dst) {
      dst.mutable_data<T>(dims, ctx.GetPlace());
      dst.set_lod(batch_gate->lod());
      to_batch(ctx, src, dst, false);
    };

    LoDTensor batch_hidden, batch_hidden_g, batch_cell;
    ToBatch(device_ctx, *hidden_out, out_dims, batch_hidden);
    ToBatch(device_ctx, *hidden_g, out_dims, batch_hidden_g);
    ToBatch(device_ctx, *cell_out, out_dims, batch_cell);

    LoDTensor batch_cell_g, batch_gate_g;
    batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
    // TODO(qingqing) support the case where the output cell state has a
    // gradient.
    // to_batch(device_ctx, *cell_g, batch_cell_g, false);
    zero(device_ctx, &batch_cell_g, static_cast<T>(0.0));
    batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
    batch_gate_g.set_lod(batch_gate->lod());

    auto gate_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("gate_activation"));
    auto cell_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("cell_activation"));
    auto cand_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("candidate_activation"));

    auto batch_starts = batch_gate->lod()[0];
    size_t num_batch = batch_starts.size() - 1;
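    // Backpropagation through time: walk the time-step batches in reverse.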
    for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);

      Tensor gate = batch_gate->Slice(bstart, bend);
      Tensor cell = batch_cell.Slice(bstart, bend);
      Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend);
      lstm_value.gate_value = gate.data<T>();
      lstm_value.state_value = cell.data<T>();
      lstm_value.state_active_value = cell_pre_act.data<T>();

      Tensor out_g = batch_hidden_g.Slice(bstart, bend);
      Tensor gate_g = batch_gate_g.Slice(bstart, bend);
      Tensor cell_g = batch_cell_g.Slice(bstart, bend);
      lstm_grad.state_grad = cell_g.data<T>();
      lstm_grad.gate_grad = gate_g.data<T>();
      lstm_grad.output_grad = out_g.data<T>();

      if (n > 0) {
        int bstart_pre = static_cast<int>(batch_starts[n - 1]);
        Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
        Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
        lstm_value.prev_state_value = cell_pre.data<T>();
        lstm_grad.prev_state_grad = cell_pre_g.data<T>();
      } else {
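        // At the first time step, the previous state is the (reordered)
        // initial cell state, if one was provided.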
        lstm_value.prev_state_value = c0 ? ordered_c0.data<T>() : nullptr;
        lstm_grad.prev_state_grad = c0_g ? ordered_c0_g.data<T>() : nullptr;
      }

      int cur_batch_size = bend - bstart;
      math::LstmUnitGradFunctor<DeviceContext, T>::compute(
          device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size,
          gate_act, cell_act, cand_act);

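      // Propagate the gate gradients through the recurrent connection:
      // pre_hidden_g += gate_g * weight^T, and accumulate the weight
      // gradient as weight_g += pre_hidden^T * gate_g.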
      if (n > 0) {
        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
        auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
        math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight, true,
                                       static_cast<T>(1.0), &pre_hidden_g,
                                       static_cast<T>(1.0));
        if (weight_g) {
          /* backward weight */
          auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end);
          math::matmul<DeviceContext, T>(device_ctx, pre_hidden, true, gate_g,
                                         false, static_cast<T>(1.0), weight_g,
                                         static_cast<T>(1.0));
        }
      } else {
        if (h0 && weight_g) {
          ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
                                             &ordered_h0, true);
          math::matmul<DeviceContext, T>(device_ctx, ordered_h0, true, gate_g,
                                         false, static_cast<T>(1.0), weight_g,
                                         static_cast<T>(1.0));
        }
        if (h0 && h0_g) {
          ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
          math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight,
                                         true, static_cast<T>(1.0),
                                         &ordered_h0_g, static_cast<T>(0.0));
        }
      }
    }

    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    if (in_g) {
      /* backward data */
      in_g->mutable_data<T>(ctx.GetPlace());
      to_seq(device_ctx, batch_gate_g, *in_g);
    }
    if (bias && bias_g) {
      /* backward bias */
      Tensor b_g = *bias_g;
      b_g.Resize({bias_g->numel(), 1});
      Tensor gate_bias_g = b_g.Slice(0, 4 * frame_size);
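      // The gate-bias gradient is the column-wise sum of the gate gradients
      // over all rows (all time steps of all sequences).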
      math::ColwiseSum<DeviceContext, T> col_sum;
      col_sum(device_ctx, batch_gate_g, &gate_bias_g);
    }

    if (h0 && h0_g) {
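      // Scatter the reordered H0 gradient back into the original sequence
      // order (indexed_src == false writes dst[index[i]] = src[i]).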
      ReorderInitState<DeviceContext, T>(device_ctx, ordered_h0_g, order, h0_g,
                                         false);
    }
    if (c0 && c0_g) {
      ReorderInitState<DeviceContext, T>(device_ctx, ordered_c0_g, order, c0_g,
                                         false);
    }
  }
};

}  // namespace operators
}  // namespace paddle