/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <paddle/fluid/platform/device_context.h>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/gru_gpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
#include "paddle/fluid/operators/math/gru_compute.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
  // Forward pass of one GRU step on the GPU.
  //
  // value.gate_value is a [batch_size, frame_size * 3] buffer: the first
  // frame_size * 2 columns hold the update/reset gate pre-activations and
  // the last frame_size columns hold the candidate-state pre-activation
  // (established by the GEMM strides below: ldc == frame_size * 3 and the
  // candidate GEMM writes at offset frame_size * 2).
  //
  // batch_size == 1 takes a fused fast path (two collective kernels over
  // 16-wide tiles) and returns early; the batched path accumulates gate
  // pre-activations with GEMMs and launches elementwise kernels on a
  // 32x32-thread grid. All work is enqueued on the context's stream.
  static void compute(const platform::CUDADeviceContext &context,
                      GRUMetaValue<T> value, int frame_size, int batch_size,
                      const detail::ActivationType active_node,
                      const detail::ActivationType active_gate,
                      bool origin_mode) {
    auto stream = context.stream();
    dim3 threads;
    dim3 grid;
    if (batch_size == 1) {
      // Fast path: fused kernels for a single sequence step.
      constexpr int tiled_size = 16;
      // Ceil-divide: one block per tile across both gates (2 * frame_size).
      int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
      threads = dim3(tiled_size, 1);
      grid = dim3(frame_blocks, 1);

      detail::KeFastCollectiveGruGate<T,
                                      tiled_size><<<grid, threads, 0, stream>>>(
          value.gate_value, value.prev_out_value, value.gate_weight,
          value.reset_output_value, frame_size, active_gate);

      // Output stage covers frame_size elements only.
      frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
      grid = dim3(frame_blocks, 1);
      detail::KeFastCollectiveGruOut<T,
                                     tiled_size><<<grid, threads, 0, stream>>>(
          value.state_weight, value.prev_out_value, value.output_value,
          value.gate_value, value.reset_output_value, frame_size, active_node,
          origin_mode);

      return;
    }

    // Batched path. batch_size > 1 is guaranteed from here on (the fast
    // path returned above), so only the is_batch=true kernel variants are
    // launched; the original is_batch=false branches below this point were
    // unreachable and have been removed.
    threads = dim3(32, 32);
    grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);

    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
    if (value.prev_out_value) {
      // gate_value[:, 0:2*frame_size] += prev_out_value * gate_weight
      // (gate pre-activations). ldc is frame_size * 3 because the
      // candidate-state columns share the same buffer.
      // NOTE(review): assumes math::GetBlas GEMM uses row-major
      // (transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc)
      // semantics — confirm against paddle/fluid/operators/math/blas.h.
      blas.GEMM(false, false, batch_size, frame_size * 2, frame_size, 1,
                value.prev_out_value, frame_size, value.gate_weight,
                frame_size * 2, 1, value.gate_value, frame_size * 3);
    }

    // Apply the gate activation and produce reset_output_value
    // (presumably reset-gated previous output; see gru_resetOutput<T>).
    detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
                                    /* is_batch= */ true,
                                    T><<<grid, threads, 0, stream>>>(
        detail::forward::gru_resetOutput<T>(), value.gate_value,
        value.reset_output_value, value.prev_out_value, frame_size,
        batch_size, active_gate);

    if (value.prev_out_value) {
      // gate_value[:, 2*frame_size:3*frame_size] +=
      //     reset_output_value * state_weight  (candidate pre-activation).
      blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
                value.reset_output_value, frame_size, value.state_weight,
                frame_size, 1, value.gate_value + frame_size * 2,
                frame_size * 3);
    }

    // Apply the node activation and combine with prev_out_value to form
    // output_value; origin_mode selects between the two GRU formulations.
    detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
                                    /* is_batch= */ true,
                                    T><<<grid, threads, 0, stream>>>(
        detail::forward::gru_finalOutput<T>(), value.gate_value,
        value.prev_out_value, value.output_value, frame_size, batch_size,
        active_node, origin_mode);
  }
};

template <typename T>
struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
  // Backward pass of one GRU step on the GPU.
  //
  // Consumes grad.output_grad and accumulates gradients into
  // grad.gate_grad, grad.reset_output_grad, grad.prev_out_grad, and — when
  // the corresponding pointers are non-null — grad.gate_weight_grad and
  // grad.state_weight_grad. All work is enqueued on the context's stream.
  static void compute(const platform::CUDADeviceContext &context,
                      GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                      int frame_size, int batch_size,
                      const detail::ActivationType active_node,
                      const detail::ActivationType active_gate,
                      bool origin_mode) {
    auto stream = context.stream();
    dim3 threads;
    dim3 grid;
    if (batch_size == 1) {
      // Single-step path: 1-D blocks of up to 1024 threads tiling
      // frame_size.
      int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
      int frame_blocks = (frame_size + 1024 - 1) / 1024;
      threads = dim3(frame_per_block, 1);
      grid = dim3(frame_blocks, 1);
    } else {
      // Batched path: 32x32 thread tiles over (frame_size, batch_size).
      threads = dim3(32, 32);
      grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
    }

    // Candidate-state gradient; is_batch selects the indexing scheme to
    // match the launch configuration chosen above.
    if (batch_size == 1) {
      detail::KeGruBackwardStateGrad<
          detail::backward::gru_stateGrad<T>,
          /* is_batch= */ false><<<grid, threads, 0, stream>>>(
          detail::backward::gru_stateGrad<T>(), value.gate_value,
          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
          grad.output_grad, frame_size, batch_size, active_node, origin_mode);
    } else {
      detail::KeGruBackwardStateGrad<
          detail::backward::gru_stateGrad<T>,
          /* is_batch= */ true><<<grid, threads, 0, stream>>>(
          detail::backward::gru_stateGrad<T>(), value.gate_value,
          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
          grad.output_grad, frame_size, batch_size, active_node, origin_mode);
    }

    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);

    if (value.prev_out_value && grad.prev_out_grad) {
      // reset_output_grad = candidate_grad * state_weight^T; beta == 0
      // overwrites the buffer. candidate_grad is gate_grad's last
      // frame_size columns (offset frame_size * 2, ld frame_size * 3).
      // NOTE(review): assumes math::GetBlas GEMM uses row-major
      // (transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc)
      // semantics — confirm against paddle/fluid/operators/math/blas.h.
      blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
                grad.gate_grad + frame_size * 2, frame_size * 3,
                value.state_weight, frame_size, 0, grad.reset_output_grad,
                frame_size);

      if (grad.state_weight_grad) {
        // state_weight_grad += reset_output_value^T * candidate_grad
        blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
                  value.reset_output_value, frame_size,
                  grad.gate_grad + frame_size * 2, frame_size * 3, 1,
                  grad.state_weight_grad, frame_size);
      }
    }

    // Update/reset gate gradients, fed by reset_output_grad.
    if (batch_size == 1) {
      detail::KeGruBackwardResetGrad<
          detail::backward::gru_resetGrad<T>,
          /* is_batch= */ false><<<grid, threads, 0, stream>>>(
          detail::backward::gru_resetGrad<T>(), value.gate_value,
          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
          grad.reset_output_grad, frame_size, batch_size, active_gate);
    } else {
      detail::KeGruBackwardResetGrad<
          detail::backward::gru_resetGrad<T>,
          /* is_batch= */ true><<<grid, threads, 0, stream>>>(
          detail::backward::gru_resetGrad<T>(), value.gate_value,
          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
          grad.reset_output_grad, frame_size, batch_size, active_gate);
    }

    if (grad.prev_out_grad && value.prev_out_value) {
      // prev_out_grad += gate_grad[:, 0:2*frame_size] * gate_weight^T
      blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
                grad.gate_grad, frame_size * 3, value.gate_weight,
                frame_size * 2, 1, grad.prev_out_grad, frame_size);

      if (grad.gate_weight_grad) {
        // gate_weight_grad += prev_out_value^T * gate_grad[:, 0:2*frame_size]
        blas.GEMM(true, false, frame_size, frame_size * 2, batch_size, 1,
                  value.prev_out_value, frame_size, grad.gate_grad,
                  frame_size * 3, 1, grad.gate_weight_grad, frame_size * 2);
      }
    }
  }
};

// Explicit instantiations for the element types used by the GRU operators.
template struct GRUUnitFunctor<platform::CUDADeviceContext, float>;
template struct GRUUnitFunctor<platform::CUDADeviceContext, double>;
template struct GRUUnitGradFunctor<platform::CUDADeviceContext, float>;
template struct GRUUnitGradFunctor<platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle