/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/gru_compute.h"

namespace paddle {
namespace operators {
namespace math {
namespace detail {

#ifndef __NVCC__

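// Gate buffer layout used throughout this file: for each batch row,
// gate_value holds three consecutive frame_size-wide segments,
// [update gate | reset gate | frame state], so the per-row stride is
// frame_size * 3.

// Scalar (one element per iteration) kernel for the reset-output stage of
// the GRU forward pass; used whenever the AVX path below does not apply.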
template <class OpResetOutput, typename T>
void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
                                       T *gate_value, T *reset_output_value,
                                       T *prev_output_value, int frame_size,
                                       ActivationType active_gate) {
  T r_value_update_gate;
  T r_value_reset_gate;
  T r_value_reset_output;
  T r_prev_out = 0;
  T *update_gate = gate_value;
  T *reset_gate = gate_value + frame_size;

  for (int i = 0; i < frame_size; i++) {
    r_value_update_gate = update_gate[i];
    r_value_reset_gate = reset_gate[i];
    if (prev_output_value) {
      r_prev_out = prev_output_value[i];
    }

    op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
                    r_value_reset_output, active_gate);

    update_gate[i] = r_value_update_gate;
    reset_gate[i] = r_value_reset_gate;
    reset_output_value[i] = r_value_reset_output;
  }
}

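// Scalar kernel for the final-output stage of the GRU forward pass:
// combines the update gate and frame state (the candidate state) with the
// previous output, as defined by the OpFinalOutput functor.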
template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
                                       T *gate_value, T *prev_output_value,
                                       T *output_value, int frame_size,
                                       ActivationType active_node) {
  T r_value_update_gate;
  T r_value_frame_state;
  T r_prev_out = 0;
  T r_output;
  T *update_gate = gate_value;
  T *frame_state = gate_value + frame_size * 2;

  for (int i = 0; i < frame_size; i++) {
    r_value_update_gate = update_gate[i];
    r_value_frame_state = frame_state[i];
    if (prev_output_value) {
      r_prev_out = prev_output_value[i];
    }

    op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
                    r_output, active_node);

    frame_state[i] = r_value_frame_state;
    output_value[i] = r_output;
  }
}

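// AVX kernel for the reset-output stage, processing 8 floats per __m256.
// Callers guarantee frame_size is a multiple of 8 and T is a 4-byte float.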
template <class OpResetOutput, typename T>
void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
                                     T *gate_value, T *reset_output_value,
                                     T *prev_output_value, int frame_size,
                                     ActivationType active_gate) {
#ifdef __AVX__
  __m256 r_value_update_gate;
  __m256 r_value_reset_gate;
  __m256 r_value_reset_output;
  __m256 r_prev_out = _mm256_set1_ps(0.0f);
  __m256 *update_gate = (__m256 *)gate_value;
  __m256 *reset_gate = (__m256 *)(gate_value + frame_size);

  for (int i = 0; i < frame_size / 8; i++) {
    r_value_update_gate = update_gate[i];
    r_value_reset_gate = reset_gate[i];
    if (prev_output_value) {
      r_prev_out = ((__m256 *)prev_output_value)[i];
    }

    op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
                    r_value_reset_output, active_gate);

    update_gate[i] = r_value_update_gate;
    reset_gate[i] = r_value_reset_gate;
    ((__m256 *)reset_output_value)[i] = r_value_reset_output;
  }
#endif
}

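// AVX kernel for the final-output stage; same preconditions as the AVX
// reset-output kernel above.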
template <class OpFinalOutput, typename T>
void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
                                     T *gate_value, T *prev_output_value,
                                     T *output_value, int frame_size,
                                     ActivationType active_node) {
#ifdef __AVX__
  __m256 r_value_update_gate;
  __m256 r_value_frame_state;
  __m256 r_prev_out = _mm256_set1_ps(0.0f);
  __m256 r_output;
  __m256 *update_gate = (__m256 *)gate_value;
  __m256 *frame_state = (__m256 *)(gate_value + frame_size * 2);

  for (int i = 0; i < frame_size / 8; i++) {
    r_value_update_gate = update_gate[i];
    r_value_frame_state = frame_state[i];
    if (prev_output_value) {
      r_prev_out = ((__m256 *)prev_output_value)[i];
    }

    op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
                    r_output, active_node);

    frame_state[i] = r_value_frame_state;
    ((__m256 *)output_value)[i] = r_output;
  }
#endif
}

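// Batch-level dispatcher for the reset-output stage: picks the AVX kernel
// when the functor supports AVX, frame_size is a multiple of 8, and T is a
// 4-byte float; otherwise falls back to the scalar kernel. Pointers are
// advanced by the per-row strides after each batch row.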
template <class OpResetOutput, typename T>
inline void forward_reset_output(OpResetOutput op_reset_output,
                                 GRUMetaValue<T> value, int frame_size,
                                 int batch_size, ActivationType active_gate) {
  for (int b = 0; b < batch_size; b++) {
    if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
      hl_avx_gru_forward_reset_output(
          op_reset_output, value.gate_value, value.reset_output_value,
          value.prev_out_value, frame_size, active_gate);
    } else {
      hl_naive_gru_forward_reset_output(
          op_reset_output, value.gate_value, value.reset_output_value,
          value.prev_out_value, frame_size, active_gate);
    }

    value.gate_value += frame_size * 3;
    value.reset_output_value += frame_size;
    if (value.prev_out_value) {
      value.prev_out_value += frame_size;
    }
  }
}

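// Batch-level dispatcher for the final-output stage; same selection logic
// as forward_reset_output above.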
template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput op_final_output,
                                 GRUMetaValue<T> value, int frame_size,
                                 int batch_size, ActivationType active_node) {
  for (int b = 0; b < batch_size; b++) {
    if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
      hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
                                      value.prev_out_value, value.output_value,
                                      frame_size, active_node);
    } else {
      hl_naive_gru_forward_final_output(
          op_final_output, value.gate_value, value.prev_out_value,
          value.output_value, frame_size, active_node);
    }

    value.gate_value += frame_size * 3;
    value.output_value += frame_size;
    if (value.prev_out_value) {
      value.prev_out_value += frame_size;
    }
  }
}

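// Scalar kernel for the state-gradient stage of the GRU backward pass:
// given the output gradient, computes gradients of the update gate and
// frame state, and writes the previous-output gradient when a buffer is
// provided.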
template <class OpStateGrad, typename T>
void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
                                      T *gate_grad, T *prev_out_value,
                                      T *prev_out_grad, T *output_grad,
                                      int frame_size,
                                      ActivationType active_node) {
  T r_update_gate_value;
  T r_update_gate_grad;
  T r_frame_state_value;
  T r_frame_state_grad;
  T r_out_grad;
  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T *update_gate_value = gate_value;
  T *update_gate_grad = gate_grad;
  T *frame_state_value = gate_value + frame_size * 2;
  T *frame_state_grad = gate_grad + frame_size * 2;

  for (int i = 0; i < frame_size; i++) {
    r_update_gate_value = update_gate_value[i];
    r_frame_state_value = frame_state_value[i];
    r_out_grad = output_grad[i];
    if (prev_out_value) {
      r_prev_out_value = prev_out_value[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = prev_out_grad[i];
    }

    op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
                  r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
                  r_out_grad, active_node);

    update_gate_grad[i] = r_update_gate_grad;
    frame_state_grad[i] = r_frame_state_grad;
    if (prev_out_grad) {
      prev_out_grad[i] = r_prev_out_grad;
    }
  }
}

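// Scalar kernel for the reset-gradient stage of the GRU backward pass:
// computes gradients of the update and reset gates, and writes the
// previous-output gradient when a buffer is provided.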
template <class OpResetGrad, typename T>
void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
                                      T *gate_grad, T *prev_out_value,
                                      T *prev_out_grad, T *reset_output_grad,
                                      int frame_size,
                                      ActivationType active_gate) {
  T r_update_gate_value;
  T r_update_gate_grad;
  T r_reset_gate_value;
  T r_reset_gate_grad;
  T r_reset_output_grad = 0;
  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T *update_gate_value = gate_value;
  T *update_gate_grad = gate_grad;
  T *reset_gate_value = gate_value + frame_size;
  T *reset_gate_grad = gate_grad + frame_size;

  for (int i = 0; i < frame_size; i++) {
    r_update_gate_value = update_gate_value[i];
    r_update_gate_grad = update_gate_grad[i];
    r_reset_gate_value = reset_gate_value[i];

    if (prev_out_value && prev_out_grad) {
      r_reset_output_grad = reset_output_grad[i];
    }
    if (prev_out_value) {
      r_prev_out_value = prev_out_value[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = prev_out_grad[i];
    }

    op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
                  r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
                  r_reset_output_grad, active_gate);

    update_gate_grad[i] = r_update_gate_grad;
    reset_gate_grad[i] = r_reset_gate_grad;
    if (prev_out_grad) {
      prev_out_grad[i] = r_prev_out_grad;
    }
  }
}

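// AVX kernel for the state-gradient stage; same preconditions as the AVX
// forward kernels (frame_size % 8 == 0, T is a 4-byte float).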
template <class OpStateGrad, typename T>
void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
                                    T *gate_grad, T *prev_out_value,
                                    T *prev_out_grad, T *output_grad,
                                    int frame_size,
                                    ActivationType active_node) {
#ifdef __AVX__
  __m256 r_update_gate_value;
  __m256 r_update_gate_grad;
  __m256 r_frame_state_value;
  __m256 r_frame_state_grad;
  __m256 r_out_grad;
  __m256 r_prev_out_value = _mm256_set1_ps(0.0f);
  __m256 r_prev_out_grad = _mm256_set1_ps(0.0f);
  __m256 *update_gate_value = (__m256 *)gate_value;
  __m256 *update_gate_grad = (__m256 *)gate_grad;
  __m256 *frame_state_value = (__m256 *)(gate_value + frame_size * 2);
  __m256 *frame_state_grad = (__m256 *)(gate_grad + frame_size * 2);

  for (int i = 0; i < frame_size / 8; i++) {
    r_update_gate_value = update_gate_value[i];
    r_frame_state_value = frame_state_value[i];
    r_out_grad = ((__m256 *)output_grad)[i];
    if (prev_out_value) {
      r_prev_out_value = ((__m256 *)prev_out_value)[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = ((__m256 *)prev_out_grad)[i];
    }

    op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
                  r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
                  r_out_grad, active_node);

    update_gate_grad[i] = r_update_gate_grad;
    frame_state_grad[i] = r_frame_state_grad;
    if (prev_out_grad) {
      ((__m256 *)prev_out_grad)[i] = r_prev_out_grad;
    }
  }
#endif
}

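// AVX kernel for the reset-gradient stage; same preconditions as above.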
template <class OpResetGrad, typename T>
void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
                                    T *gate_grad, T *prev_out_value,
                                    T *prev_out_grad, T *reset_output_grad,
                                    int frame_size,
                                    ActivationType active_gate) {
#ifdef __AVX__
  __m256 r_update_gate_value;
  __m256 r_update_gate_grad;
  __m256 r_reset_gate_value;
  __m256 r_reset_gate_grad;
  __m256 r_reset_output_grad = _mm256_set1_ps(0.0f);
  __m256 r_prev_out_value = _mm256_set1_ps(0.0f);
  __m256 r_prev_out_grad = _mm256_set1_ps(0.0f);
  __m256 *update_gate_value = (__m256 *)gate_value;
  __m256 *update_gate_grad = (__m256 *)gate_grad;
  __m256 *reset_gate_value = (__m256 *)(gate_value + frame_size);
  __m256 *reset_gate_grad = (__m256 *)(gate_grad + frame_size);

  for (int i = 0; i < frame_size / 8; i++) {
    r_update_gate_value = update_gate_value[i];
    r_update_gate_grad = update_gate_grad[i];
    r_reset_gate_value = reset_gate_value[i];

    if (prev_out_value && prev_out_grad) {
      r_reset_output_grad = ((__m256 *)reset_output_grad)[i];
    }
    if (prev_out_value) {
      r_prev_out_value = ((__m256 *)prev_out_value)[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = ((__m256 *)prev_out_grad)[i];
    }

    op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
                  r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
                  r_reset_output_grad, active_gate);

    update_gate_grad[i] = r_update_gate_grad;
    reset_gate_grad[i] = r_reset_gate_grad;
    if (prev_out_grad) {
      ((__m256 *)prev_out_grad)[i] = r_prev_out_grad;
    }
  }
#endif
}

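// Batch-level dispatcher for the state-gradient stage; selects between the
// AVX and scalar kernels with the same criteria as the forward dispatchers.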
template <class OpStateGrad, typename T>
inline void backward_state_grad(OpStateGrad op_state_grad,
                                GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                                int frame_size, int batch_size,
                                ActivationType active_node) {
  for (int b = 0; b < batch_size; b++) {
    if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
      hl_avx_gru_backward_state_grad(
          op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
          grad.prev_out_grad, grad.output_grad, frame_size, active_node);
    } else {
      hl_naive_gru_backward_state_grad(
          op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
          grad.prev_out_grad, grad.output_grad, frame_size, active_node);
    }

    value.gate_value += frame_size * 3;
    if (value.prev_out_value) {
      value.prev_out_value += frame_size;
    }

    grad.gate_grad += frame_size * 3;
    grad.output_grad += frame_size;
    if (grad.prev_out_grad) {
      grad.prev_out_grad += frame_size;
    }
  }
}

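// Batch-level dispatcher for the reset-gradient stage; selects between the
// AVX and scalar kernels with the same criteria as the forward dispatchers.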
template <class OpResetGrad, typename T>
inline void backward_reset_grad(OpResetGrad op_reset_grad,
                                GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                                int frame_size, int batch_size,
                                ActivationType active_gate) {
  for (int b = 0; b < batch_size; b++) {
    if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
      hl_avx_gru_backward_reset_grad(
          op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
          grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
    } else {
      hl_naive_gru_backward_reset_grad(
          op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
          grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
    }

    value.gate_value += frame_size * 3;
    if (value.prev_out_value) {
      value.prev_out_value += frame_size;
    }

    grad.gate_grad += frame_size * 3;
    grad.reset_output_grad += frame_size;
    if (grad.prev_out_grad) {
      grad.prev_out_grad += frame_size;
    }
  }
}

#endif

}  // namespace detail
}  // namespace math
}  // namespace operators
}  // namespace paddle