#include "src/naive/lamb/opr_impl.h"
#include <cmath>
#include <functional>
#include <numeric>
#include "src/common/utils.h"
#include "src/naive/handle.h"

using namespace megdnn;
using namespace naive;

namespace {
using Param = megdnn::LAMBUpdate::Param;

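// Single-tensor LAMB update (layer-wise adaptive moments). T is the gradient
// dtype; the moments, the current parameter and the new parameter are
// read/written as T_ACC (float by default). Steps: Adam-style moment updates,
// optional bias correction, optional weight decay added to the update
// direction, then a layer-wise trust ratio that rescales the learning rate by
// ||param|| / ||update||.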
template <typename T, typename T_ACC = float>
void update(
        _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param,
        _megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
        _megdnn_tensor_out new_param, const Param& param) {
    float beta_1 = param.beta_1;
    float beta_2 = param.beta_2;
    float step = param.step;
    float lr = param.lr;
    float weight_decay = param.weight_decay;
    float eps = param.eps;
    bool bias_correction = param.bias_correction;
    bool always_adapt = param.always_adapt;

    size_t total_elem = lamb_param.layout.total_nr_elems();
    T_ACC mt, vt, bc_1, bc_2, rt, d_norm = 0;
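    // Bias-correction denominators 1 - beta^step (1 when correction is disabled).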
    bc_1 = bias_correction ? 1 - std::pow(beta_1, step) : 1;
    bc_2 = bias_correction ? 1 - std::pow(beta_2, step) : 1;

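    // First pass: update the biased first/second moment estimates and
    // accumulate the squared L2 norm of the resulting update direction.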
    for (size_t i = 0; i < total_elem; i++) {
        mt = m_t.ptr<T_ACC>()[i] = beta_1 * m_t_1.ptr<T_ACC>()[i] +
                                   (1 - beta_1) * static_cast<T_ACC>(grad.ptr<T>()[i]);
        vt = v_t.ptr<T_ACC>()[i] =
                beta_2 * v_t_1.ptr<T_ACC>()[i] +
                (1 - beta_2) * std::pow(static_cast<T_ACC>(grad.ptr<T>()[i]), 2);
        rt = (mt / bc_1) / (sqrt(vt / bc_2) + eps);
        if (weight_decay != 0) {
            rt += lamb_param.ptr<T_ACC>()[i] * weight_decay;
        }
        d_norm += rt * rt;
    }
    d_norm = sqrt(d_norm);
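    // L2 norm over a whole tensor; used below for the parameter norm.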
    auto get_norm = [=](_megdnn_tensor_in norm) -> T_ACC {
        // Accumulate in T_ACC (not int) so the sum of squares is not truncated.
        return sqrt(std::accumulate(
                norm.ptr<T_ACC>(), norm.ptr<T_ACC>() + total_elem,
                static_cast<T_ACC>(0),
                [](T_ACC t1, T_ACC t2) -> T_ACC { return t1 + t2 * t2; }));
    };
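    // Layer-wise trust ratio ||param|| / ||update||, applied only when weight
    // decay is in effect (or always_adapt is set) and both norms are positive.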
    T_ACC p_norm = get_norm(lamb_param), trust_ratio = 1;
    if ((always_adapt || weight_decay > 0) && p_norm > 0 && d_norm > 0) {
        trust_ratio = p_norm / d_norm;
    }
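    // Second pass: rebuild the update direction from the stored moments and
    // apply it, scaled by lr and the trust ratio.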
    for (size_t i = 0; i < total_elem; i++) {
        mt = m_t.ptr<T_ACC>()[i];
        vt = v_t.ptr<T_ACC>()[i];
        rt = (mt / bc_1) / (sqrt(vt / bc_2) + eps);
        if (weight_decay != 0) {
            rt += lamb_param.ptr<T_ACC>()[i] * weight_decay;
        }
        new_param.ptr<T_ACC>()[i] = lamb_param.ptr<T_ACC>()[i] - lr * trust_ratio * rt;
    }
}

}  // namespace

namespace megdnn {
namespace naive {
void LAMBUpdateImpl::exec(
        _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param,
        _megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
        _megdnn_tensor_out new_param, _megdnn_workspace workspace) {
    check_exec(
            m_t_1.layout, v_t_1.layout, lamb_param.layout, grad.layout, m_t.layout,
            v_t.layout, new_param.layout, workspace.size);
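    // Dispatch on the gradient dtype: instantiate the kernel for each
    // supported floating-point dtype and run it on the CPU dispatcher.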
#define cb(DType)                                                               \
    if (grad.layout.dtype == DType()) {                                         \
        MEGDNN_DISPATCH_CPU_KERN_OPR(update<typename DTypeTrait<DType>::ctype>( \
                m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, param())); \
        return;                                                                 \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
    megdnn_throw("bad dtype");
}
}  // namespace naive

}  // namespace megdnn