/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/gru_gpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
#include "paddle/fluid/operators/math/gru_compute.h"

namespace paddle {
namespace operators {
namespace math {

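// Forward computation of a single GRU step. The buffer layout implied by the
// GEMM calls below: gate_value is batch_size x (3 * frame_size), with the two
// gate pre-activations in the first 2 * frame_size columns and the candidate
// state in the last frame_size columns (hence the leading dimension of
// frame_size * 3).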
template <typename T>
struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
  static void compute(const platform::CUDADeviceContext &context,
                      GRUMetaValue<T> value, int frame_size, int batch_size,
                      const detail::ActivationType active_node,
                      const detail::ActivationType active_gate,
                      bool origin_mode) {
    auto stream = context.stream();
    dim3 threads;
    dim3 grid;
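    // Single-sequence fast path: on compute capability 7.0 (Volta) and newer,
    // launch the fused tiled kernels and return early; otherwise fall back to
    // the generic per-element kernels configured in the else branch.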
    if (batch_size == 1) {
      if (context.GetComputeCapability() >= 70) {
        constexpr int tiled_size = 16;
        int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
        threads = dim3(tiled_size, 1);
        grid = dim3(frame_blocks, 1);
        detail::KeFastCollectiveGruGate<
            T, tiled_size><<<grid, threads, 0, stream>>>(
            value.gate_value, value.prev_out_value, value.gate_weight,
            value.reset_output_value, frame_size, active_gate);

        frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
        grid = dim3(frame_blocks, 1);
        detail::KeFastCollectiveGruOut<
            T, tiled_size><<<grid, threads, 0, stream>>>(
            value.state_weight, value.prev_out_value, value.output_value,
            value.gate_value, value.reset_output_value, frame_size, active_node,
            origin_mode);

        return;
      } else {
        int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
        int frame_blocks = (frame_size + 1024 - 1) / 1024;
        threads = dim3(frame_per_block, 1);
        grid = dim3(frame_blocks, 1);
      }
    } else {
      threads = dim3(32, 32);
      grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
    }
    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
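    // Add the previous hidden state's contribution to the gate
    // pre-activations:
    //   gate_value[:, 0:2f] += prev_out_value (batch x f) * gate_weight (f x 2f)
    // where f is frame_size.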
    if (value.prev_out_value) {
      blas.GEMM(false, false, batch_size, frame_size * 2, frame_size, 1,
                value.prev_out_value, frame_size, value.gate_weight,
                frame_size * 2, 1, value.gate_value, frame_size * 3);
    }

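    // Apply the gate activations and form the reset output, the elementwise
    // product of the reset gate with h(t-1); the is_batch template flag only
    // switches between single-sequence and batched indexing.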
    if (batch_size == 1) {
      detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
                                      /* is_batch= */ false,
                                      T><<<grid, threads, 0, stream>>>(
          detail::forward::gru_resetOutput<T>(), value.gate_value,
          value.reset_output_value, value.prev_out_value, frame_size,
          batch_size, active_gate);
    } else {
      detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
                                      /* is_batch= */ true,
                                      T><<<grid, threads, 0, stream>>>(
          detail::forward::gru_resetOutput<T>(), value.gate_value,
          value.reset_output_value, value.prev_out_value, frame_size,
          batch_size, active_gate);
    }

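    // Add the reset output's contribution to the candidate pre-activation:
    //   gate_value[:, 2f:3f] += reset_output_value (batch x f) * state_weight (f x f)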
    if (value.prev_out_value) {
      blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
                value.reset_output_value, frame_size, value.state_weight,
                frame_size, 1, value.gate_value + frame_size * 2,
                frame_size * 3);
    }

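    // Apply the candidate activation and blend it with h(t-1) via the update
    // gate; origin_mode picks which of the two standard blending conventions
    // is used.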
    if (batch_size == 1) {
      detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
                                      /* is_batch= */ false,
                                      T><<<grid, threads, 0, stream>>>(
          detail::forward::gru_finalOutput<T>(), value.gate_value,
          value.prev_out_value, value.output_value, frame_size, batch_size,
          active_node, origin_mode);
    } else {
      detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
                                      /* is_batch= */ true,
                                      T><<<grid, threads, 0, stream>>>(
          detail::forward::gru_finalOutput<T>(), value.gate_value,
          value.prev_out_value, value.output_value, frame_size, batch_size,
          active_node, origin_mode);
    }
  }
};

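// Backward computation of a single GRU step: the two kernels invert the
// final-output and reset-output nonlinearities, and the GEMMs are the
// transposed counterparts of the forward ones, yielding gradients for the
// gates, the previous hidden state, and (when requested) both weight blocks.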
template <typename T>
struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
  static void compute(const platform::CUDADeviceContext &context,
                      GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                      int frame_size, int batch_size,
                      const detail::ActivationType active_node,
                      const detail::ActivationType active_gate,
                      bool origin_mode) {
    auto stream = context.stream();
    dim3 threads;
    dim3 grid;
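    // The backward pass has no fused fast path; both branches use the same
    // generic launch configuration as the forward slow path.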
    if (batch_size == 1) {
      int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
      int frame_blocks = (frame_size + 1024 - 1) / 1024;
      threads = dim3(frame_per_block, 1);
      grid = dim3(frame_blocks, 1);
    } else {
      threads = dim3(32, 32);
      grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
    }

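    // Backpropagate output_grad through the final-output step (the
    // gru_stateGrad functor), producing the candidate and update gate
    // gradients and a first contribution to prev_out_grad.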
    if (batch_size == 1) {
      detail::KeGruBackwardStateGrad<
          detail::backward::gru_stateGrad<T>,
          /* is_batch= */ false><<<grid, threads, 0, stream>>>(
          detail::backward::gru_stateGrad<T>(), value.gate_value,
          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
          grad.output_grad, frame_size, batch_size, active_node, origin_mode);
    } else {
      detail::KeGruBackwardStateGrad<
          detail::backward::gru_stateGrad<T>,
          /* is_batch= */ true><<<grid, threads, 0, stream>>>(
          detail::backward::gru_stateGrad<T>(), value.gate_value,
          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
          grad.output_grad, frame_size, batch_size, active_node, origin_mode);
    }

    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);

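    // reset_output_grad = gate_grad[:, 2f:3f] (batch x f) * state_weight^T;
    // beta is 0 here, so the buffer is overwritten rather than accumulated.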
    if (value.prev_out_value && grad.prev_out_grad) {
      blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
                grad.gate_grad + frame_size * 2, frame_size * 3,
                value.state_weight, frame_size, 0, grad.reset_output_grad,
                frame_size);

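      // state_weight_grad += reset_output_value^T (f x batch) * gate_grad[:, 2f:3f]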
      if (grad.state_weight_grad) {
        blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
                  value.reset_output_value, frame_size,
                  grad.gate_grad + frame_size * 2, frame_size * 3, 1,
                  grad.state_weight_grad, frame_size);
      }
    }

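    // Backpropagate through the reset-output step (the gru_resetGrad functor)
    // to obtain the reset gate gradient and a further contribution to
    // prev_out_grad.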
    if (batch_size == 1) {
      detail::KeGruBackwardResetGrad<
          detail::backward::gru_resetGrad<T>,
          /* is_batch= */ false><<<grid, threads, 0, stream>>>(
          detail::backward::gru_resetGrad<T>(), value.gate_value,
          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
          grad.reset_output_grad, frame_size, batch_size, active_gate);
    } else {
      detail::KeGruBackwardResetGrad<
          detail::backward::gru_resetGrad<T>,
          /* is_batch= */ true><<<grid, threads, 0, stream>>>(
          detail::backward::gru_resetGrad<T>(), value.gate_value,
          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
          grad.reset_output_grad, frame_size, batch_size, active_gate);
    }

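    // prev_out_grad += gate_grad[:, 0:2f] (batch x 2f) * gate_weight^T and,
    // if requested, gate_weight_grad += prev_out_value^T * gate_grad[:, 0:2f].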
    if (grad.prev_out_grad && value.prev_out_value) {
      blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
                grad.gate_grad, frame_size * 3, value.gate_weight,
                frame_size * 2, 1, grad.prev_out_grad, frame_size);

      if (grad.gate_weight_grad) {
        blas.GEMM(true, false, frame_size, frame_size * 2, batch_size, 1,
                  value.prev_out_value, frame_size, grad.gate_grad,
                  frame_size * 3, 1, grad.gate_weight_grad, frame_size * 2);
      }
    }
  }
};

template struct GRUUnitFunctor<platform::CUDADeviceContext, float>;
template struct GRUUnitFunctor<platform::CUDADeviceContext, double>;
template struct GRUUnitGradFunctor<platform::CUDADeviceContext, float>;
template struct GRUUnitGradFunctor<platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle