/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/detail/gru_gpu_kernel.h"
#include "paddle/operators/math/detail/gru_kernel.h"
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
  // Forward pass of one GRU step on the GPU: two GEMMs fold the previous
  // hidden state into the gate buffer, and two kernels apply the gate /
  // candidate activations. All work is enqueued on the context's stream.
  static void compute(const platform::CUDADeviceContext &context,
                      GRUMetaValue<T> value, int frame_size, int batch_size,
                      ActivationType active_node, ActivationType active_gate) {
    auto cuda_stream = context.stream();

    // Launch configuration: a 1-D layout over frames when the batch holds a
    // single sequence element, otherwise a 32x32 tile over (frame, batch).
    const bool single_batch = (batch_size == 1);
    dim3 block_dims;
    dim3 grid_dims;
    if (single_batch) {
      int threads_x = frame_size <= 1024 ? frame_size : 1024;
      int blocks_x = (frame_size + 1024 - 1) / 1024;
      block_dims = dim3(threads_x, 1);
      grid_dims = dim3(blocks_x, 1);
    } else {
      block_dims = dim3(32, 32);
      grid_dims = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
    }

    // gate[:, 0:2*frame_size] += prev_out * gate_weight
    // (beta == 1 accumulates into the pre-computed input projection).
    if (value.prev_out_value) {
      math::gemm<platform::CUDADeviceContext, T>(
          context, false, false, batch_size, frame_size * 2, frame_size, 1,
          value.prev_out_value, frame_size, value.gate_weight, frame_size * 2,
          1, value.gate_value, frame_size * 3);
    }

    // Apply the gate activation and produce reset_output_value.
    if (single_batch) {
      detail::KeGruForwardResetOutput<
          detail::forward::gru_resetOutput<T>, /* is_batch= */ false,
          T><<<grid_dims, block_dims, 0, cuda_stream>>>(
          detail::forward::gru_resetOutput<T>(), value.gate_value,
          value.reset_output_value, value.prev_out_value, frame_size,
          batch_size, active_gate);
    } else {
      detail::KeGruForwardResetOutput<
          detail::forward::gru_resetOutput<T>, /* is_batch= */ true,
          T><<<grid_dims, block_dims, 0, cuda_stream>>>(
          detail::forward::gru_resetOutput<T>(), value.gate_value,
          value.reset_output_value, value.prev_out_value, frame_size,
          batch_size, active_gate);
    }

    // gate[:, 2*frame_size:3*frame_size] += reset_output * state_weight.
    if (value.prev_out_value) {
      math::gemm<platform::CUDADeviceContext, T>(
          context, false, false, batch_size, frame_size, frame_size, 1,
          value.reset_output_value, frame_size, value.state_weight, frame_size,
          1, value.gate_value + frame_size * 2, frame_size * 3);
    }

    // Apply the candidate activation and write the step output.
    if (single_batch) {
      detail::KeGruForwardFinalOutput<
          detail::forward::gru_finalOutput<T>, /* is_batch= */ false,
          T><<<grid_dims, block_dims, 0, cuda_stream>>>(
          detail::forward::gru_finalOutput<T>(), value.gate_value,
          value.prev_out_value, value.output_value, frame_size, batch_size,
          active_node);
    } else {
      detail::KeGruForwardFinalOutput<
          detail::forward::gru_finalOutput<T>, /* is_batch= */ true,
          T><<<grid_dims, block_dims, 0, cuda_stream>>>(
          detail::forward::gru_finalOutput<T>(), value.gate_value,
          value.prev_out_value, value.output_value, frame_size, batch_size,
          active_node);
    }
  }
};

template <typename T>
struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
  // Backward pass of one GRU step on the GPU. Kernels propagate gradients
  // through the activations; GEMMs accumulate weight gradients and the
  // gradient w.r.t. the previous hidden state. Work runs on the context's
  // stream in the same order as the forward pass, reversed.
  static void compute(const platform::CUDADeviceContext &context,
                      GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                      int frame_size, int batch_size,
                      ActivationType active_node, ActivationType active_gate) {
    auto cuda_stream = context.stream();

    // Same launch-configuration policy as the forward functor.
    const bool single_batch = (batch_size == 1);
    dim3 block_dims;
    dim3 grid_dims;
    if (single_batch) {
      int threads_x = frame_size <= 1024 ? frame_size : 1024;
      int blocks_x = (frame_size + 1024 - 1) / 1024;
      block_dims = dim3(threads_x, 1);
      grid_dims = dim3(blocks_x, 1);
    } else {
      block_dims = dim3(32, 32);
      grid_dims = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
    }

    // Backprop through the candidate-state activation into gate_grad.
    if (single_batch) {
      detail::KeGruBackwardStateGrad<
          detail::backward::gru_stateGrad<T>,
          /* is_batch= */ false><<<grid_dims, block_dims, 0, cuda_stream>>>(
          detail::backward::gru_stateGrad<T>(), value.gate_value,
          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
          grad.output_grad, frame_size, batch_size, active_node);
    } else {
      detail::KeGruBackwardStateGrad<
          detail::backward::gru_stateGrad<T>,
          /* is_batch= */ true><<<grid_dims, block_dims, 0, cuda_stream>>>(
          detail::backward::gru_stateGrad<T>(), value.gate_value,
          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
          grad.output_grad, frame_size, batch_size, active_node);
    }

    if (value.prev_out_value && grad.prev_out_grad) {
      // reset_output_grad = candidate_gate_grad * state_weight^T
      // (beta == 0: overwrites, does not accumulate).
      math::gemm<platform::CUDADeviceContext, T>(
          context, false, true, batch_size, frame_size, frame_size, 1,
          grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight,
          frame_size, 0, grad.reset_output_grad, frame_size);

      // state_weight_grad += reset_output^T * candidate_gate_grad.
      if (grad.state_weight_grad) {
        math::gemm<platform::CUDADeviceContext, T>(
            context, true, false, frame_size, frame_size, batch_size, 1,
            value.reset_output_value, frame_size,
            grad.gate_grad + frame_size * 2, frame_size * 3, 1,
            grad.state_weight_grad, frame_size);
      }
    }

    // Backprop through the reset-gate activation.
    if (single_batch) {
      detail::KeGruBackwardResetGrad<
          detail::backward::gru_resetGrad<T>,
          /* is_batch= */ false><<<grid_dims, block_dims, 0, cuda_stream>>>(
          detail::backward::gru_resetGrad<T>(), value.gate_value,
          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
          grad.reset_output_grad, frame_size, batch_size, active_gate);
    } else {
      detail::KeGruBackwardResetGrad<
          detail::backward::gru_resetGrad<T>,
          /* is_batch= */ true><<<grid_dims, block_dims, 0, cuda_stream>>>(
          detail::backward::gru_resetGrad<T>(), value.gate_value,
          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
          grad.reset_output_grad, frame_size, batch_size, active_gate);
    }

    if (grad.prev_out_grad && value.prev_out_value) {
      // prev_out_grad += gate_grad[:, 0:2*frame_size] * gate_weight^T.
      math::gemm<platform::CUDADeviceContext, T>(
          context, false, true, batch_size, frame_size, frame_size * 2, 1,
          grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1,
          grad.prev_out_grad, frame_size);

      // gate_weight_grad += prev_out^T * gate_grad[:, 0:2*frame_size].
      if (grad.gate_weight_grad) {
        math::gemm<platform::CUDADeviceContext, T>(
            context, true, false, frame_size, frame_size * 2, batch_size, 1,
            value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3, 1,
            grad.gate_weight_grad, frame_size * 2);
      }
    }
  }
};

// Explicitly instantiate the forward and backward GRU functors for the
// floating-point element types the GRU operator supports on CUDA devices.
template struct GRUUnitFunctor<platform::CUDADeviceContext, float>;
template struct GRUUnitFunctor<platform::CUDADeviceContext, double>;
template struct GRUUnitGradFunctor<platform::CUDADeviceContext, float>;
template struct GRUUnitGradFunctor<platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle