/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <curand_kernel.h>
#include <vector>
#include "optimizer_conf.h"
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"

#ifdef PADDLE_WITH_HETERPS

namespace paddle {
namespace framework {

template <typename ValType, typename GradType>
class Optimizer {
 public:
  Optimizer() {}

  ~Optimizer() {}

  // Host-side hook. This optimizer keeps no per-instance state (every
  // hyper-parameter is read from optimizer_config), so there is nothing to do.
  void initialize() {}

  // Adagrad-style update of a single scalar weight, followed by clipping.
  //
  //   w     : weight updated in place, then clamped to
  //           [optimizer_config::min_bound, optimizer_config::max_bound].
  //   g2sum : running sum of squared scaled gradients for `w`. It is read to
  //           derive the step size, then increased by scaled_grad^2.
  //   g     : gradient for `w`.
  //   scale : divisor applied to `g` before use (callers in this class pass
  //           grad.show). NOTE(review): scale == 0 yields inf/NaN; callers
  //           presumably guarantee show > 0 — confirm.
  //
  // The step is ADDED to `w`. Presumably the pushed gradient already carries
  // the descent sign — confirm against the push path before changing it.
  __device__ void update_lr(float& w, float& g2sum, float g, float scale) {
    // The effective learning rate shrinks as g2sum grows:
    //   learning_rate * sqrt(initial_g2sum / (initial_g2sum + g2sum))
    double ratio = optimizer_config::learning_rate *
                   sqrt(optimizer_config::initial_g2sum /
                        (optimizer_config::initial_g2sum + g2sum));
    double scaled_grad = g / scale;

    w += scaled_grad * ratio;

    if (w < optimizer_config::min_bound) w = optimizer_config::min_bound;
    if (w > optimizer_config::max_bound) w = optimizer_config::max_bound;

    g2sum += scaled_grad * scaled_grad;
  }

  // Adagrad-style update of an n-element embedding that shares one g2sum.
  //
  //   n     : number of elements in `w` and `g`. If n <= 0 the call is a
  //           no-op, so the mean below never divides by zero.
  //   w     : embedding weights updated in place. Each element is clamped to
  //           [optimizer_config::mf_min_bound, optimizer_config::mf_max_bound].
  //   g2sum : shared accumulator. It grows by the MEAN of the squared scaled
  //           gradients over the n elements.
  //   g     : n gradients, one per element of `w`.
  //   scale : divisor applied to every g[i] (callers in this class pass
  //           grad.show).
  __device__ void update_mf(int n, float* w, float& g2sum, const float* g,
                            float scale) {
    if (n <= 0) return;  // add_g2sum / n would be 0/0 = NaN

    double add_g2sum = 0;
    // One step size for the whole embedding, taken from the shared g2sum.
    double ratio = optimizer_config::mf_learning_rate *
                   sqrt(optimizer_config::mf_initial_g2sum /
                        (optimizer_config::mf_initial_g2sum + g2sum));
    for (int i = 0; i < n; ++i) {
      double scaled_grad = g[i] / scale;

      w[i] += scaled_grad * ratio;

      if (w[i] < optimizer_config::mf_min_bound)
        w[i] = optimizer_config::mf_min_bound;
      if (w[i] > optimizer_config::mf_max_bound)
        w[i] = optimizer_config::mf_max_bound;

      add_g2sum += scaled_grad * scaled_grad;
    }

    g2sum += add_g2sum / n;
  }

  // Applies one pushed gradient record `grad` to the feature value `val`.
  // See apply_update for the exact sequence of field updates.
  __device__ void update_value(ValType& val, const GradType& grad) {
    apply_update(val, grad);
  }

  // Pointer-taking variant used on the dynamic-mf path. It currently behaves
  // exactly like update_value: MF_DIM is used instead of a per-feature mf_dim.
  // Original TODO, kept for reference:
  //   ptr->mf_dim = grad.mf_dim;
  //   ptr->mf_size = ptr->mf_dim + 1;
  __device__ void dy_mf_update_value(ValType* ptr, const GradType& grad) {
    apply_update(*ptr, grad);
  }

 private:
  // Shared body of update_value / dy_mf_update_value. It performs, in order:
  //  1. copy slot; accumulate show and clk;
  //  2. add the weighted show/clk score of this push to delta_score;
  //  3. update the scalar lr weight via update_lr (scale = grad.show);
  //  4. if no embedding exists yet (mf_size == 0) and the accumulated score
  //     reaches mf_create_thresholds, create it. The grad.mf_g of this push is
  //     NOT applied in that case. Otherwise, if the embedding exists, update
  //     mf[1..MF_DIM] via update_mf, using mf[0] as their shared g2sum.
  __device__ void apply_update(ValType& val, const GradType& grad) {
    val.slot = grad.slot;
    val.show += grad.show;
    val.clk += grad.clk;
    val.delta_score += optimizer_config::nonclk_coeff * (grad.show - grad.clk) +
                       optimizer_config::clk_coeff * grad.clk;

    update_lr(val.lr, val.lr_g2sum, grad.lr_g, grad.show);

    if (val.mf_size == 0) {
      // The score uses the totals accumulated above, not just this push.
      if (optimizer_config::mf_create_thresholds <=
          optimizer_config::nonclk_coeff * (val.show - val.clk) +
              optimizer_config::clk_coeff * val.clk) {
        create_mf(val);
      }
    } else {
      update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show);
    }
  }

  // Sets up the embedding: mf[0] (its g2sum) = 0, and mf[1..MF_DIM] are drawn
  // uniformly from (0, mf_initial_range] (curand_uniform returns (0, 1]).
  __device__ void create_mf(ValType& val) {
    val.mf_size = MF_DIM + 1;
    val.mf[0] = 0;
    // The seed comes from clock64(), so initial values differ from run to
    // run. The subsequence is the global thread index.
    // NOTE(review): cuRAND documents that curand_init with a non-zero
    // subsequence is comparatively expensive (skip-ahead).
    int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
    curandState state;
    curand_init(clock64(), tid_x, 0, &state);
    for (int i = 0; i < MF_DIM; ++i) {
      val.mf[i + 1] =
          (curand_uniform(&state)) * optimizer_config::mf_initial_range;
    }
  }
};

}  // end namespace framework
}  // end namespace paddle
#endif