cudnn_lstm_op.cu.cc 12.3 KB
Newer Older
P
phlrain 已提交
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduozh 已提交
15
#include "paddle/fluid/framework/op_registry.h"
S
sneaxiy 已提交
16
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
C
chengduozh 已提交
17
#include "paddle/fluid/operators/math/math_function.h"
G
GaoWei8 已提交
18
#include "paddle/fluid/platform/cudnn_desc.h"
G
GaoWei8 已提交
19
#include "paddle/fluid/platform/cudnn_helper.h"
L
liuhongyu 已提交
20 21 22 23 24 25 26

namespace paddle {
namespace operators {

// Local shorthands for the framework tensor types.
typedef framework::LoDTensor LoDTensor;
typedef framework::Tensor Tensor;

C
chengduozh 已提交
27
template <typename T>
class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
 public:
  // Forward pass of the cudnn_lstm op, implemented with cuDNN RNN calls.
  //
  // Inputs : Input, InitH, InitC, W.
  // Outputs: Out, LastH, LastC, Reserve (read back by CudnnLSTMGPUGradKernel)
  //          and StateOut (handed to ScopedRNNBase::Create; presumably the
  //          dropout state -- confirm against cudnn_desc.h).
  //
  // Dispatch:
  //   * "is_test" selects cudnnRNNForwardInference* vs cudnnRNNForwardTraining*.
  //   * An empty "sequence_length" attribute selects the unpadded APIs;
  //     otherwise the padded *Ex APIs are used, which require
  //     CUDNN_VERSION >= 7201 and raise Unavailable when built without them.
  void Compute(const framework::ExecutionContext &ctx) const override {
    const Tensor *x = ctx.Input<Tensor>("Input");
    const Tensor *init_h = ctx.Input<Tensor>("InitH");
    const Tensor *init_c = ctx.Input<Tensor>("InitC");
    const Tensor *w = ctx.Input<Tensor>("W");

    Tensor *out = ctx.Output<Tensor>("Out");
    Tensor *last_h = ctx.Output<Tensor>("LastH");
    Tensor *last_c = ctx.Output<Tensor>("LastC");
    Tensor *reserve = ctx.Output<Tensor>("Reserve");
    Tensor *state_out = ctx.Output<Tensor>("StateOut");

    const T *x_data = x->data<T>();
    const T *init_h_data = init_h->data<T>();
    const T *init_c_data = init_c->data<T>();
    const T *w_data = w->data<T>();

    T *out_data = out->mutable_data<T>(ctx.GetPlace());
    T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
    T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());

    float dropout_prob = ctx.Attr<float>("dropout_prob");
    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    int hidden_size = ctx.Attr<int>("hidden_size");
    int num_layers = ctx.Attr<int>("num_layers");
    bool is_test = ctx.Attr<bool>("is_test");
    int seed = ctx.Attr<int>("seed");
    auto sequence_length = ctx.Attr<std::vector<int>>("sequence_length");

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();

    // Input is read as [seq_length, batch_size, input_size].
    int seq_length = x->dims()[0];
    int batch_size = x->dims()[1];
    int input_size = x->dims()[2];
    int weight_numel = w->numel();
    // Tells ScopedRNNBase whether StateOut already holds data from an
    // earlier run.
    bool state_initialized = state_out->IsInitialized();

    // Filled in by rnn.Create below; zero-initialised so they are never read
    // indeterminate.
    size_t workspace_size = 0;
    size_t reserve_size = 0;

    platform::ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
                                num_layers, dropout_prob, seed, weight_numel,
                                state_initialized, is_bidirec);
    rnn.Create<T>(handle, ctx.GetPlace(), sequence_length, &workspace_size,
                  &reserve_size, state_out);

    // Scratch space required by cuDNN for the duration of this call.
    framework::Tensor workspace_data;
    workspace_data.Resize({static_cast<int64_t>(workspace_size)});
    workspace_data.mutable_data<uint8_t>(ctx.GetPlace());

    // Reserve space is an op output so the grad kernel can reuse it.
    auto *reserve_data = reserve->mutable_data<uint8_t>(
        {static_cast<int64_t>(reserve_size)}, ctx.GetPlace());

    if (is_test) {
      if (sequence_length.empty()) {
        // Inference on unpadded input/output.
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference(
            handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), x_data,
            rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
            rnn.w_desc(), w_data, rnn.y_desc(), out_data, rnn.hy_desc(),
            last_h_data, rnn.cy_desc(), last_c_data,
            workspace_data.data<uint8_t>(), workspace_size));
      } else {
#if CUDNN_VERSION >= 7201
        // Inference on padded input/output.
        PADDLE_ENFORCE_CUDA_SUCCESS(
            platform::dynload::cudnnRNNForwardInferenceEx(
                handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.hx_desc(),
                init_h_data, rnn.cx_desc(), init_c_data, rnn.w_desc(), w_data,
                rnn.y_seq_desc(), out_data, rnn.hy_desc(), last_h_data,
                rnn.cy_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
                nullptr, nullptr, nullptr, nullptr,
                workspace_data.data<uint8_t>(), workspace_size));
#else
        PADDLE_THROW(platform::errors::Unavailable(
            "The padded input is supported by "
            "cudnnRNNForwardInferenceEx, but it only works when "
            "the version of cudnn is at least 7.2.1"));
#endif
      }
    } else {
      if (sequence_length.empty()) {
        // Training on unpadded input/output.
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
            handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), x_data,
            rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
            rnn.w_desc(), w_data, rnn.y_desc(), out_data, rnn.hy_desc(),
            last_h_data, rnn.cy_desc(), last_c_data,
            workspace_data.data<uint8_t>(), workspace_size, reserve_data,
            reserve_size));
      } else {
#if CUDNN_VERSION >= 7201
        // Training on padded input/output.
        PADDLE_ENFORCE_CUDA_SUCCESS(
            platform::dynload::cudnnRNNForwardTrainingEx(
                handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.hx_desc(),
                init_h_data, rnn.cx_desc(), init_c_data, rnn.w_desc(), w_data,
                rnn.y_seq_desc(), out_data, rnn.hy_desc(), last_h_data,
                rnn.cy_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
                nullptr, nullptr, nullptr, nullptr,
                workspace_data.data<uint8_t>(), workspace_size, reserve_data,
                reserve_size));
#else
        PADDLE_THROW(platform::errors::Unavailable(
            "The padded input is supported by "
            "cudnnRNNForwardTrainingEx, but it only works when "
            "the version of cudnn is at least 7.2.1"));
#endif
      }
    }
  }
};

C
chengduozh 已提交
152
template <typename T>
class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
 public:
  // Backward pass of the cudnn_lstm op, implemented with cuDNN RNN calls.
  //
  // Reads the forward inputs (Input, W, InitH, InitC), the forward results
  // (Out, Reserve, StateOut) and the upstream gradients of Out, LastH and
  // LastC. Writes the gradients of Input, W, InitH and InitC.
  //
  // The RNN descriptors are rebuilt from the same attributes as the forward
  // kernel. NOTE(review): this assumes Reserve comes from a training-mode
  // forward run with identical attributes; that is not checked here.
  // An empty "sequence_length" selects the unpadded APIs; otherwise the
  // padded *Ex APIs are used, which require CUDNN_VERSION >= 7201.
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *input = ctx.Input<Tensor>("Input");
    auto *weight = ctx.Input<Tensor>("W");
    auto *init_h = ctx.Input<Tensor>("InitH");
    auto *init_c = ctx.Input<Tensor>("InitC");
    auto *reserve = ctx.Input<Tensor>("Reserve");
    auto *state_out = ctx.Input<Tensor>("StateOut");

    auto *out = ctx.Input<Tensor>("Out");
    auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto *last_h_grad = ctx.Input<Tensor>(framework::GradVarName("LastH"));
    auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("LastC"));

    auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    auto *weight_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
    auto *init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
    auto *init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();

    auto input_dims = input->dims();
    auto init_h_dims = init_h->dims();
    auto init_c_dims = init_c->dims();

    auto *input_data = input->data<T>();
    auto *weight_data = weight->data<T>();
    auto *init_h_data = init_h->data<T>();
    auto *init_c_data = init_c->data<T>();
    auto *out_data = out->data<T>();
    auto *out_grad_data = out_grad->data<T>();
    auto *last_h_grad_data = last_h_grad->data<T>();
    auto *last_c_grad_data = last_c_grad->data<T>();

    // cuDNN's backward-weights call accumulates into dW, so dW must start
    // from zero.
    math::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
    auto *weight_grad_data = weight_grad->mutable_data<T>(ctx.GetPlace());
    zero(dev_ctx, weight_grad, static_cast<T>(0.0));

    auto *in_grad_data = in_grad->mutable_data<T>(input_dims, ctx.GetPlace());
    auto *init_h_grad_data =
        init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
    auto *init_c_grad_data =
        init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());

    float dropout_prob = ctx.Attr<float>("dropout_prob");
    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    int hidden_size = ctx.Attr<int>("hidden_size");
    int num_layers = ctx.Attr<int>("num_layers");
    int seed = ctx.Attr<int>("seed");
    auto sequence_length = ctx.Attr<std::vector<int>>("sequence_length");

    // Input is read as [seq_length, batch_size, input_size].
    int seq_length = input_dims[0];
    int batch_size = input_dims[1];
    int input_size = input_dims[2];
    int weight_numel = weight->numel();

    // Filled in by rnn.Create below.
    size_t workspace_size = 0;
    size_t reserve_size = 0;

    // StateOut was produced by the forward pass, so it is reported as already
    // initialised (state_initialized = true).
    platform::ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
                                num_layers, dropout_prob, seed, weight_numel,
                                true, is_bidirec);
    rnn.Create<T>(handle, ctx.GetPlace(), sequence_length, &workspace_size,
                  &reserve_size, const_cast<Tensor *>(state_out));

    // Scratch space required by cuDNN for the duration of this call.
    framework::Tensor workspace_data;
    workspace_data.Resize({static_cast<int64_t>(workspace_size)});
    workspace_data.mutable_data<uint8_t>(ctx.GetPlace());

    // The cuDNN signatures take a non-const reserve pointer.
    auto *reserve_data = const_cast<uint8_t *>(reserve->data<uint8_t>());

    if (sequence_length.empty()) {
      // Unpadded input/output.
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
          handle, rnn.rnn_desc(), seq_length, rnn.y_desc(), out_data,
          rnn.y_desc(), out_grad_data, rnn.hy_desc(), last_h_grad_data,
          rnn.cy_desc(), last_c_grad_data, rnn.w_desc(), weight_data,
          rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data, rnn.x_desc(),
          in_grad_data, rnn.hx_desc(), init_h_grad_data, rnn.cx_desc(),
          init_c_grad_data, workspace_data.data<uint8_t>(), workspace_size,
          reserve_data, reserve_size));

      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
          handle, rnn.rnn_desc(), seq_length, rnn.x_desc(), input_data,
          rnn.hx_desc(), init_h_data, rnn.y_desc(), out_data,
          workspace_data.data<uint8_t>(), workspace_size, rnn.w_desc(),
          weight_grad_data, reserve_data, reserve_size));
    } else {
#if CUDNN_VERSION >= 7201
      // Padded input/output.
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
          handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data, rnn.y_seq_desc(),
          out_grad_data, nullptr, nullptr, rnn.hy_desc(), last_h_grad_data,
          rnn.cy_desc(), last_c_grad_data, rnn.w_desc(), weight_data,
          rnn.hx_desc(), init_h_data, rnn.cx_desc(), init_c_data,
          rnn.x_seq_desc(), in_grad_data, rnn.hx_desc(), init_h_grad_data,
          rnn.cx_desc(), init_c_grad_data, nullptr, nullptr,
          workspace_data.data<uint8_t>(), workspace_size, reserve_data,
          reserve_size));

      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx(
          handle, rnn.rnn_desc(), rnn.x_seq_desc(), input_data, rnn.hx_desc(),
          init_h_data, rnn.y_seq_desc(), out_data,
          workspace_data.data<uint8_t>(), workspace_size, rnn.w_desc(),
          weight_grad_data, reserve_data, reserve_size));
#else
      PADDLE_THROW(platform::errors::Unavailable(
          "The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
          "cudnnRNNBackwardWeightsEx, but it only works when the version "
          "of cudnn is at least 7.2.1"));
#endif
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Forward kernel, instantiated for float and double.
REGISTER_OP_CUDA_KERNEL(cudnn_lstm,
                        ops::CudnnLSTMGPUKernel<float>,
                        ops::CudnnLSTMGPUKernel<double>);

// Backward kernel, instantiated for float and double.
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad,
                        ops::CudnnLSTMGPUGradKernel<float>,
                        ops::CudnnLSTMGPUGradKernel<double>);