/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#ifdef PADDLE_WITH_HETERPS

#if defined(PADDLE_WITH_CUDA)
#include <curand_kernel.h>
#endif
#include <vector>

#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_CUDA)
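// Device-side optimizer used by the GPU parameter server: applies pushed
// gradients (GradType) to stored feature values (ValType) with an
// adagrad-style rule for both the scalar lr weight and the mf embedding.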
template <typename ValType, typename GradType>
class Optimizer {
 public:
  Optimizer() {}

  ~Optimizer() {}

  void initialize() {}

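  // Adagrad-style update of the scalar lr weight: the step size decays as
  // the accumulated squared gradient g2sum grows, and the result is
  // clamped to [min_bound, max_bound].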
  __device__ void update_lr(const OptimizerConfig& optimizer_config,
                            float& w,               // NOLINT
                            float& g2sum, float g,  // NOLINT
                            float scale) {
    double add_g2sum = 0;
    double ratio = optimizer_config.learning_rate *
                   sqrt(optimizer_config.initial_g2sum /
                        (optimizer_config.initial_g2sum + g2sum));
    double scaled_grad = g / scale;

    w += scaled_grad * ratio;

    if (w < optimizer_config.min_bound) w = optimizer_config.min_bound;
    if (w > optimizer_config.max_bound) w = optimizer_config.max_bound;

    add_g2sum += scaled_grad * scaled_grad;

    g2sum += add_g2sum;
  }

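  // Adagrad-style update of the n embedding (mf) weights; one shared g2sum
  // accumulator is advanced by the mean squared scaled gradient.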
  __device__ void update_mf(const OptimizerConfig& optimizer_config, int n,
                            float* w,
                            float& g2sum,  // NOLINT
                            const float* g, float scale) {
    double add_g2sum = 0;
    double ratio = optimizer_config.mf_learning_rate *
                   sqrt(optimizer_config.mf_initial_g2sum /
                        (optimizer_config.mf_initial_g2sum + g2sum));
    for (int i = 0; i < n; ++i) {
      double scaled_grad = g[i] / scale;

      w[i] += scaled_grad * ratio;

      if (w[i] < optimizer_config.mf_min_bound)
        w[i] = optimizer_config.mf_min_bound;
      if (w[i] > optimizer_config.mf_max_bound)
        w[i] = optimizer_config.mf_max_bound;
      add_g2sum += scaled_grad * scaled_grad;
    }

    g2sum += add_g2sum / n;
  }

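  // Apply one gradient to a feature value with a fixed MF_DIM embedding:
  // refresh show/clk statistics and delta_score, update the lr weight, and
  // create the embedding lazily once the weighted show/clk score reaches
  // mf_create_thresholds.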
  __device__ void update_value(const OptimizerConfig& optimizer_config,
                               ValType& val,  // NOLINT
                               const GradType& grad) {
    val.slot = grad.slot;
    val.show += grad.show;
    val.clk += grad.clk;
    val.delta_score += optimizer_config.nonclk_coeff * (grad.show - grad.clk) +
                       optimizer_config.clk_coeff * grad.clk;

    update_lr(optimizer_config, val.lr, val.lr_g2sum, grad.lr_g, grad.show);

    if (val.mf_size == 0) {
      if (optimizer_config.mf_create_thresholds <=
          optimizer_config.nonclk_coeff * (val.show - val.clk) +
              optimizer_config.clk_coeff * val.clk) {
        val.mf_size = MF_DIM + 1;
        val.mf[0] = 0;
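        // Seed a per-thread curand state and initialize the new embedding
        // with uniform noise scaled by mf_initial_range.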
        int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
        curandState state;
        curand_init(clock64(), tid_x, 0, &state);
        for (int i = 0; i < MF_DIM; ++i) {
          val.mf[i + 1] =
              (curand_uniform(&state)) * optimizer_config.mf_initial_range;
        }
      }
    } else {
      update_mf(optimizer_config, MF_DIM, &val.mf[1], val.mf[0], grad.mf_g,
                grad.show);
    }
  }

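  // Same as update_value, but the value is accessed through a pointer and
  // uses a per-feature embedding dimension (ptr->mf_dim) instead of MF_DIM.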
  __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config,
                                     ValType* ptr, const GradType& grad) {
    ptr->slot = grad.slot;
    ptr->show += grad.show;
    ptr->clk += grad.clk;
    ptr->delta_score += optimizer_config.nonclk_coeff * (grad.show - grad.clk) +
                        optimizer_config.clk_coeff * grad.clk;

    update_lr(optimizer_config, ptr->lr, ptr->lr_g2sum, grad.lr_g, grad.show);
    // use MF_DIM temporarily
    // ptr->mf_dim = grad.mf_dim;

    if (ptr->mf_size == 0) {
      if (optimizer_config.mf_create_thresholds <=
          optimizer_config.nonclk_coeff * (ptr->show - ptr->clk) +
              optimizer_config.clk_coeff * ptr->clk) {
        ptr->mf_size = ptr->mf_dim + 1;

        // ptr->mf_size = MF_DIM + 1;
        ptr->mf[0] = 0;
        int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
        curandState state;
        curand_init(clock64(), tid_x, 0, &state);
        for (int i = 0; i < ptr->mf_dim; ++i) {
          ptr->mf[i + 1] =
              (curand_uniform(&state)) * optimizer_config.mf_initial_range;
        }
      }
    } else {
      update_mf(optimizer_config, ptr->mf_dim, &(ptr->mf[1]), ptr->mf[0],
                grad.mf_g,
                grad.show);  // for local test
    }
  }
};
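// Usage sketch (illustrative only, not part of this header): applying one
// gradient per thread. The kernel and names below are hypothetical; the
// real callers are the heter_ps hash table's update kernels.
//
//   template <typename ValType, typename GradType>
//   __global__ void update_kernel(OptimizerConfig config, ValType* vals,
//                                 const GradType* grads, int len) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i >= len) return;
//     Optimizer<ValType, GradType> opt;
//     opt.update_value(config, vals[i], grads[i]);
//   }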

#endif
}  // end namespace framework
}  // end namespace paddle
#endif