/* hl_lstm.h
 * Committed by zhangjinchao01. */
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef HL_LSTM_H_
#define HL_LSTM_H_

#include "hl_base.h"

/**
 * @brief   LSTM forward pass, computed in parallel over a batch of sequences.
 *
 * @param[in]   gateValue           gate input values.
 * @param[out]  stateValue          cell state values.
 * @param[out]  preOutputValue      previous-step output values.
 * @param[out]  outputValue         output values.
 * @param[in]   checkIg             check bias for Ig (see note).
 * @param[in]   checkFg             check bias for Fg (see note).
 * @param[in]   checkOg             check bias for Og (see note).
 * @param[in]   weight              weight (see note).
 * @param[in]   sequence            sequence index array (see note).
 * @param[in]   frameSize           frame size.
 * @param[in]   numSequences        number of sequences.
 * @param[in]   reversed            whether the sequences are processed in
 *                                  reverse order.
 * @param[in]   active_node         activation mode for the node (cell input).
 * @param[in]   active_gate         activation mode for the gates.
 * @param[in]   active_state        activation mode for the cell state.
 *
 * NOTE(review): Ig/Fg/Og presumably denote the input/forget/output gates,
 * with the check* arrays acting as peephole weights. "weight" is presumably
 * the recurrent weight applied to the previous output. "sequence" presumably
 * holds per-sequence start offsets (numSequences + 1 entries). The
 * active_gate/active_state descriptions follow the parameter names. Confirm
 * all of these against the implementation.
 *
 * @note    Only frameSize = 32 or 64 is supported.
 */
extern void hl_lstm_parallel_forward(real *gateValue,
                                     real *stateValue,
                                     real *preOutputValue,
                                     real *outputValue,
                                     real *checkIg,
                                     real *checkFg,
                                     real *checkOg,
                                     real *weight,
                                     const int *sequence,
                                     int frameSize,
                                     int numSequences,
                                     bool reversed,
                                     hl_activation_mode_t active_node,
                                     hl_activation_mode_t active_gate,
                                     hl_activation_mode_t active_state);

/**
 * @brief   LSTM backward pass for the data (gate, state and output
 *          gradients), computed in parallel over a batch of sequences.
 *
 * @param[in]   gateValue           gate values.
 * @param[out]  gateGrad            gate gradients.
 * @param[in]   stateValue          cell state values.
 * @param[out]  stateGrad           cell state gradients.
 * @param[out]  preOutputValue      previous-step output values (see note).
 * @param[out]  preOutputGrad       previous-step output gradients.
 * @param[in]   outputGrad          output gradients.
 * @param[in]   checkIg             check bias for Ig (see note).
 * @param[out]  checkIgGrad         gradient of checkIg.
 * @param[in]   checkFg             check bias for Fg (see note).
 * @param[out]  checkFgGrad         gradient of checkFg.
 * @param[in]   checkOg             check bias for Og (see note).
 * @param[out]  checkOgGrad         gradient of checkOg.
 * @param[in]   weight              weight (see note).
 * @param[in]   sequence            sequence index array (see note).
 * @param[in]   frameSize           frame size.
 * @param[in]   numSequences        number of sequences.
 * @param[in]   reversed            whether the sequences are processed in
 *                                  reverse order.
 * @param[in]   active_node         activation mode for the node (cell input).
 * @param[in]   active_gate         activation mode for the gates.
 * @param[in]   active_state        activation mode for the cell state.
 *
 * NOTE(review): preOutputValue is tagged [out]. In a backward pass it is
 * plausibly an input produced by the forward pass; confirm the direction.
 * Ig/Fg/Og presumably denote the input/forget/output gates. "weight" is
 * presumably the recurrent weight. "sequence" presumably holds per-sequence
 * start offsets (numSequences + 1 entries). The active_gate/active_state
 * descriptions follow the parameter names. Confirm against the
 * implementation.
 *
 * @note    Only frameSize = 32 or 64 is supported.
 */
extern void hl_lstm_parallel_backward_data(real *gateValue,
                                           real *gateGrad,
                                           real *stateValue,
                                           real *stateGrad,
                                           real *preOutputValue,
                                           real *preOutputGrad,
                                           real *outputGrad,
                                           real *checkIg,
                                           real *checkIgGrad,
                                           real *checkFg,
                                           real *checkFgGrad,
                                           real *checkOg,
                                           real *checkOgGrad,
                                           real *weight,
                                           const int *sequence,
                                           int frameSize,
                                           int numSequences,
                                           bool reversed,
                                           hl_activation_mode_t active_node,
                                           hl_activation_mode_t active_gate,
                                           hl_activation_mode_t active_state);

/**
 * @brief   LSTM backward pass for the weight, computed in parallel over a
 *          batch of sequences.
 *
 * @param[out]  weightGrad          weight gradient.
 * @param[in]   outputValue         output values.
 * @param[in]   gateGrad            gate gradients.
 * @param[in]   sequence            sequence index array (see note).
 * @param[in]   frameSize           frame size.
 * @param[in]   batchSize           batch size (see note).
 * @param[in]   numSequences        number of sequences.
 * @param[in]   reversed            whether the sequences are processed in
 *                                  reverse order.
 *
 * NOTE(review): batchSize is presumably the total number of frames across
 * all sequences. "sequence" presumably holds per-sequence start offsets.
 * Whether weightGrad is overwritten or accumulated into is not visible from
 * this header. Confirm against the implementation.
 */
extern void hl_lstm_parallel_backward_weight(real *weightGrad,
                                             real *outputValue,
                                             real *gateGrad,
                                             const int *sequence,
                                             int frameSize,
                                             int batchSize,
                                             int numSequences,
                                             bool reversed);

#endif /* HL_LSTM_H_ */