TrainingAlgorithmOp.h 3.6 KB
Newer Older
H
hedaoyuan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
/**
 * TrainingAlgorithmOp.h
 *
 * Author: hedaoyuan (hedaoyuan@baidu.com)
 * Created on: 2016-06-29
 *
 * Copyright (c) Baidu.com, Inc. All Rights Reserved
 *
 */

#pragma once

#include "paddle/utils/Logging.h"
#include "BaseMatrix.h"

namespace paddle {

/**
 * \brief Sparse Momentum optimizer.
 *
 * Declaration only; the definition is not part of this header.
 *
 * Every matrix argument is taken by non-const reference, so the
 * implementation is permitted to modify any of them. Which ones it actually
 * writes cannot be determined from this declaration.
 *
 * \param value         Matrix operand; presumably the parameter values being
 *                      optimized (TODO confirm against the definition).
 * \param grad          Matrix operand; presumably the gradient of \p value.
 * \param momU          Matrix operand; presumably one of two momentum state
 *                      buffers (TODO confirm).
 * \param momV          Matrix operand; presumably the second momentum state
 *                      buffer (TODO confirm).
 * \param alpha         Scalar coefficient; its role is defined by the
 *                      implementation.
 * \param beta          Scalar coefficient; its role is defined by the
 *                      implementation.
 * \param gamma         Scalar coefficient; its role is defined by the
 *                      implementation.
 * \param tau           Scalar coefficient; its role is defined by the
 *                      implementation.
 * \param learningRate  Scalar; presumably the global learning rate.
 */
extern void sparseMomentumApply(BaseMatrix& value,
                                BaseMatrix& grad,
                                BaseMatrix& momU,
                                BaseMatrix& momV,
                                real alpha,
                                real beta,
                                real gamma,
                                real tau,
                                real learningRate);

/**
 * \brief AdaDelta optimizer.
 *
 * Declaration only; the definition is not part of this header.
 *
 * Every matrix argument is taken by non-const reference, so the
 * implementation is permitted to modify any of them. Which ones it actually
 * writes cannot be determined from this declaration.
 *
 * \param value         Matrix operand; presumably the parameter values being
 *                      optimized (TODO confirm against the definition).
 * \param grad          Matrix operand; presumably the gradient of \p value.
 * \param sum           Matrix operand; presumably an accumulator kept
 *                      between calls (TODO confirm what it accumulates).
 * \param sum1          Matrix operand; presumably a second accumulator
 *                      (TODO confirm what it accumulates).
 * \param mom           Matrix operand; presumably the momentum buffer.
 * \param lr            Matrix operand; presumably a per-element learning
 *                      rate (TODO confirm).
 * \param rou           Scalar; presumably the decay factor usually written
 *                      "rho" (TODO confirm).
 * \param epsilon       Scalar; presumably a small stabilizing constant.
 * \param learningRate  Scalar; presumably the global learning rate.
 * \param momentum      Scalar; presumably the momentum coefficient.
 * \param decayRate     Scalar; presumably a weight-decay coefficient
 *                      (TODO confirm).
 */
extern void adadeltaApply(BaseMatrix& value,
                          BaseMatrix& grad,
                          BaseMatrix& sum,
                          BaseMatrix& sum1,
                          BaseMatrix& mom,
                          BaseMatrix& lr,
                          real rou,
                          real epsilon,
                          real learningRate,
                          real momentum,
                          real decayRate);

/**
 * \brief AdaGrad optimizer.
 *
 * Declaration only; the definition is not part of this header.
 *
 * Every matrix argument is taken by non-const reference, so the
 * implementation is permitted to modify any of them. Which ones it actually
 * writes cannot be determined from this declaration.
 *
 * \param value         Matrix operand; presumably the parameter values being
 *                      optimized (TODO confirm against the definition).
 * \param grad          Matrix operand; presumably the gradient of \p value.
 * \param sum           Matrix operand; presumably an accumulator kept
 *                      between calls (TODO confirm what it accumulates).
 * \param sum1          Matrix operand; presumably a second accumulator
 *                      (TODO confirm what it accumulates).
 * \param mom           Matrix operand; presumably the momentum buffer.
 * \param lr            Matrix operand; presumably a per-element learning
 *                      rate (TODO confirm).
 * \param epsilon       Scalar; presumably a small stabilizing constant.
 * \param learningRate  Scalar; presumably the global learning rate.
 * \param momentum      Scalar; presumably the momentum coefficient.
 * \param decayRate     Scalar; presumably a weight-decay coefficient
 *                      (TODO confirm).
 */
extern void adagradApply(BaseMatrix& value,
                         BaseMatrix& grad,
                         BaseMatrix& sum,
                         BaseMatrix& sum1,
                         BaseMatrix& mom,
                         BaseMatrix& lr,
                         real epsilon,
                         real learningRate,
                         real momentum,
                         real decayRate);

/**
 * \brief RMSProp optimizer.
 *
 * Declaration only; the definition is not part of this header.
 *
 * Every matrix argument is taken by non-const reference, so the
 * implementation is permitted to modify any of them. Which ones it actually
 * writes cannot be determined from this declaration.
 *
 * \param value           Matrix operand; presumably the parameter values
 *                        being optimized (TODO confirm against the
 *                        definition).
 * \param grad            Matrix operand; presumably the gradient of
 *                        \p value.
 * \param g               Matrix operand; presumably a running gradient
 *                        statistic (TODO confirm which one).
 * \param f               Matrix operand; presumably a second running
 *                        gradient statistic (TODO confirm which one).
 * \param mom             Matrix operand; presumably the momentum buffer.
 * \param lr              Matrix operand; presumably a per-element learning
 *                        rate (TODO confirm).
 * \param accumulatedRou  Scalar; its relation to \p rou is defined by the
 *                        implementation (TODO confirm).
 * \param rou             Scalar; presumably the decay factor usually written
 *                        "rho" (TODO confirm).
 * \param epsilon         Scalar; presumably a small stabilizing constant.
 * \param learningRate    Scalar; presumably the global learning rate.
 * \param momentum        Scalar; presumably the momentum coefficient.
 * \param decayRate       Scalar; presumably a weight-decay coefficient
 *                        (TODO confirm).
 * \param firstTime       Flag; presumably marks the first update so the
 *                        running statistics can be initialized
 *                        (TODO confirm).
 */
extern void rmspropApply(BaseMatrix& value,
                         BaseMatrix& grad,
                         BaseMatrix& g,
                         BaseMatrix& f,
                         BaseMatrix& mom,
                         BaseMatrix& lr,
                         real accumulatedRou,
                         real rou,
                         real epsilon,
                         real learningRate,
                         real momentum,
                         real decayRate,
                         bool firstTime);

/**
 * \brief Decayed AdaGrad optimizer.
 *
 * Declaration only; the definition is not part of this header.
 *
 * Every matrix argument is taken by non-const reference, so the
 * implementation is permitted to modify any of them. Which ones it actually
 * writes cannot be determined from this declaration.
 *
 * \param value           Matrix operand; presumably the parameter values
 *                        being optimized (TODO confirm against the
 *                        definition).
 * \param grad            Matrix operand; presumably the gradient of
 *                        \p value.
 * \param mom             Matrix operand; presumably the momentum buffer.
 * \param accum           Matrix operand; presumably an accumulator kept
 *                        between calls (TODO confirm what it accumulates).
 * \param lr              Matrix operand; presumably a per-element learning
 *                        rate (TODO confirm).
 * \param accumulatedRou  Scalar; its relation to \p rou is defined by the
 *                        implementation (TODO confirm).
 * \param rou             Scalar; presumably the decay factor usually written
 *                        "rho" (TODO confirm).
 * \param epsilon         Scalar; presumably a small stabilizing constant.
 * \param learningRate    Scalar; presumably the global learning rate.
 * \param momentum        Scalar; presumably the momentum coefficient.
 * \param decayRate       Scalar; presumably a weight-decay coefficient
 *                        (TODO confirm).
 * \param firstTime       Flag; presumably marks the first update so the
 *                        accumulator can be initialized (TODO confirm).
 */
extern void decayedAdagradApply(BaseMatrix& value,
                                BaseMatrix& grad,
                                BaseMatrix& mom,
                                BaseMatrix& accum,
                                BaseMatrix& lr,
                                real accumulatedRou,
                                real rou,
                                real epsilon,
                                real learningRate,
                                real momentum,
                                real decayRate,
                                bool firstTime);

/**
 * \brief Adam optimizer.
 *
 * Declaration only; the definition is not part of this header.
 *
 * Every matrix argument is taken by non-const reference, so the
 * implementation is permitted to modify any of them. Which ones it actually
 * writes cannot be determined from this declaration.
 *
 * \param value         Matrix operand; presumably the parameter values being
 *                      optimized (TODO confirm against the definition).
 * \param grad          Matrix operand; presumably the gradient of \p value.
 * \param mom           Matrix operand; presumably the first-moment estimate
 *                      (TODO confirm).
 * \param v             Matrix operand; presumably the second-moment estimate
 *                      (TODO confirm).
 * \param beta1         Scalar; presumably the decay rate used for \p mom.
 * \param beta2         Scalar; presumably the decay rate used for \p v.
 * \param beta1_power   Scalar; presumably \p beta1 raised to the current
 *                      step count, supplied by the caller (TODO confirm).
 * \param beta2_power   Scalar; presumably \p beta2 raised to the current
 *                      step count, supplied by the caller (TODO confirm).
 * \param epsilon       Scalar; presumably a small stabilizing constant.
 * \param learningRate  Scalar; presumably the global learning rate.
 */
extern void adamApply(BaseMatrix& value,
                      BaseMatrix& grad,
                      BaseMatrix& mom,
                      BaseMatrix& v,
                      real beta1,
                      real beta2,
                      real beta1_power,
                      real beta2_power,
                      real epsilon,
                      real learningRate);

/**
 * \brief AdaMax optimizer.
 *
 * Declaration only; the definition is not part of this header.
 *
 * Every matrix argument is taken by non-const reference, so the
 * implementation is permitted to modify any of them. Which ones it actually
 * writes cannot be determined from this declaration.
 *
 * \param value  Matrix operand; presumably the parameter values being
 *               optimized (TODO confirm against the definition).
 * \param grad   Matrix operand; presumably the gradient of \p value.
 * \param mom    First moment.
 * \param u      Weighted infinity norm.
 * \param beta1  Scalar; presumably the decay rate used for \p mom
 *               (TODO confirm).
 * \param beta2  Scalar; presumably the decay rate used for \p u
 *               (TODO confirm).
 * \param step   Integer step counter; presumably the index of the current
 *               update (TODO confirm whether it starts at 0 or 1).
 * \param alpha  Scalar; presumably the step size / learning rate
 *               (TODO confirm).
 */
extern void adamaxApply(BaseMatrix& value,
                        BaseMatrix& grad,
                        BaseMatrix& mom,  // first moment
                        BaseMatrix& u,    // weighted infinity norm
                        real beta1,
                        real beta2,
                        int64_t step,
                        real alpha);

}  // namespace paddle