optimizer.cuh.h 5.0 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16 17 18
#ifdef PADDLE_WITH_HETERPS

#if defined(PADDLE_WITH_CUDA)
Y
yaoxuefeng 已提交
19
#include <curand_kernel.h>
20
#endif
T
Thunderbrook 已提交
21
#include <vector>
T
Thunderbrook 已提交
22
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
23
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
T
Thunderbrook 已提交
24 25 26 27

namespace paddle {
namespace framework {

28
#if defined(PADDLE_WITH_CUDA)
T
Thunderbrook 已提交
29 30 31 32 33 34 35 36 37
template <typename ValType, typename GradType>
class Optimizer {
 public:
  Optimizer() {}

  ~Optimizer() {}

  void initialize() {}

  // Adagrad-style update of the scalar `lr` weight of one feature.
  //
  //   ratio  = learning_rate * sqrt(initial_g2sum / (initial_g2sum + g2sum))
  //   w     += (g / scale) * ratio, then clamped to [min_bound, max_bound]
  //   g2sum += (g / scale)^2
  //
  // NOTE(review): the scaled gradient is *added* to `w`, so callers
  // presumably push an already-negated gradient — confirm on the push path.
  //
  // @param optimizer_config  learning_rate, initial_g2sum, min/max_bound.
  // @param w                 weight updated in place.
  // @param g2sum             running sum of squared scaled gradients (in/out).
  // @param g                 raw gradient before scaling.
  // @param scale             divisor applied to `g` (callers pass grad.show);
  //                          a zero value would make the update inf/NaN.
  __device__ void update_lr(const OptimizerConfig& optimizer_config,
                            float& w,               // NOLINT
                            float& g2sum, float g,  // NOLINT
                            float scale) {
    double add_g2sum = 0;
    // The step size shrinks as squared gradients accumulate in g2sum.
    double ratio = optimizer_config.learning_rate *
                   sqrt(optimizer_config.initial_g2sum /
                        (optimizer_config.initial_g2sum + g2sum));
    double scaled_grad = g / scale;

    w += scaled_grad * ratio;

    if (w < optimizer_config.min_bound) w = optimizer_config.min_bound;
    if (w > optimizer_config.max_bound) w = optimizer_config.max_bound;

    add_g2sum += scaled_grad * scaled_grad;

    g2sum += add_g2sum;
  }

  // Adagrad-style update of an n-component embedding (mf) vector whose
  // components share the single accumulator `g2sum`.
  //
  // All components use the same ratio, derived from mf_learning_rate and
  // mf_initial_g2sum. Each component is clamped to
  // [mf_min_bound, mf_max_bound]. g2sum then grows by the *mean* squared
  // scaled gradient over the n components.
  //
  // @param n      number of components in `w` and `g`; n <= 0 is a no-op.
  // @param w      first weight component, updated in place.
  // @param g2sum  shared squared-gradient accumulator (in/out).
  // @param g      raw gradients, one per component.
  // @param scale  divisor applied to every g[i] (callers pass grad.show).
  __device__ void update_mf(const OptimizerConfig& optimizer_config, int n,
                            float* w,
                            float& g2sum,  // NOLINT
                            const float* g, float scale) {
    // Without this guard the mean below would be 0.0 / 0, storing NaN in
    // g2sum. This is reachable from dy_mf_update_value for mf_dim == 0.
    if (n <= 0) return;

    double add_g2sum = 0;
    double ratio = optimizer_config.mf_learning_rate *
                   sqrt(optimizer_config.mf_initial_g2sum /
                        (optimizer_config.mf_initial_g2sum + g2sum));
    for (int i = 0; i < n; ++i) {
      double scaled_grad = g[i] / scale;

      w[i] += scaled_grad * ratio;

      if (w[i] < optimizer_config.mf_min_bound)
        w[i] = optimizer_config.mf_min_bound;
      if (w[i] > optimizer_config.mf_max_bound)
        w[i] = optimizer_config.mf_max_bound;
      add_g2sum += scaled_grad * scaled_grad;
    }

    g2sum += add_g2sum / n;
  }

  // Applies one merged gradient to a fixed-width (MF_DIM) feature value.
  //
  // - Copies the slot and accumulates show/clk.
  // - Adds nonclk_coeff * (show - clk) + clk_coeff * clk (computed on the
  //   gradient's counts) to delta_score.
  // - Updates the scalar lr weight via update_lr, scaled by grad.show.
  // - If the embedding does not exist yet (mf_size == 0), creates it once the
  //   same weighted score (computed on the value's accumulated counts)
  //   reaches mf_create_thresholds. mf[0] is the g2sum accumulator (zeroed)
  //   and mf[1..MF_DIM] are drawn from curand_uniform, i.e. in
  //   (0, mf_initial_range]. The gradient is not applied on the creating
  //   call.
  // - Otherwise, updates mf[1..MF_DIM] with mf[0] as the shared g2sum.
  //
  // RNG is seeded from clock64(), so initial embeddings are not reproducible
  // across runs. NOTE(review): the thread id assumes a 1-D launch.
  __device__ void update_value(const OptimizerConfig& optimizer_config,
                               ValType& val,  // NOLINT
                               const GradType& grad) {
    val.slot = grad.slot;
    val.show += grad.show;
    val.clk += grad.clk;
    val.delta_score += optimizer_config.nonclk_coeff * (grad.show - grad.clk) +
                       optimizer_config.clk_coeff * grad.clk;

    update_lr(optimizer_config, val.lr, val.lr_g2sum, grad.lr_g, grad.show);

    if (val.mf_size == 0) {
      if (optimizer_config.mf_create_thresholds <=
          optimizer_config.nonclk_coeff * (val.show - val.clk) +
              optimizer_config.clk_coeff * val.clk) {
        val.mf_size = MF_DIM + 1;
        val.mf[0] = 0;
        int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
        curandState state;
        curand_init(clock64(), tid_x, 0, &state);
        for (int i = 0; i < MF_DIM; ++i) {
          val.mf[i + 1] =
              (curand_uniform(&state)) * optimizer_config.mf_initial_range;
        }
      }
    } else {
      update_mf(optimizer_config, MF_DIM, &val.mf[1], val.mf[0], grad.mf_g,
                grad.show);
    }
  }

  // Same as update_value, but for values addressed through a pointer whose
  // embedding width is the per-value ptr->mf_dim rather than MF_DIM.
  //
  // On creation mf_size becomes ptr->mf_dim + 1. mf[0] (the g2sum slot) is
  // zeroed and mf[1..mf_dim] get curand_uniform values scaled by
  // mf_initial_range. Afterwards, mf[1..mf_dim] are updated via update_mf.
  // grad.mf_dim is not copied into the value; the stored ptr->mf_dim is used
  // as-is.
  __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config,
                                     ValType* ptr, const GradType& grad) {
    ptr->slot = grad.slot;
    ptr->show += grad.show;
    ptr->clk += grad.clk;
    ptr->delta_score += optimizer_config.nonclk_coeff * (grad.show - grad.clk) +
                        optimizer_config.clk_coeff * grad.clk;

    update_lr(optimizer_config, ptr->lr, ptr->lr_g2sum, grad.lr_g, grad.show);

    if (ptr->mf_size == 0) {
      if (optimizer_config.mf_create_thresholds <=
          optimizer_config.nonclk_coeff * (ptr->show - ptr->clk) +
              optimizer_config.clk_coeff * ptr->clk) {
        ptr->mf_size = ptr->mf_dim + 1;
        ptr->mf[0] = 0;
        // Seeded from clock64(): not reproducible across runs.
        // NOTE(review): the thread id assumes a 1-D launch.
        int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
        curandState state;
        curand_init(clock64(), tid_x, 0, &state);
        for (int i = 0; i < ptr->mf_dim; ++i) {
          ptr->mf[i + 1] =
              (curand_uniform(&state)) * optimizer_config.mf_initial_range;
        }
      }
    } else {
      update_mf(optimizer_config, ptr->mf_dim, &(ptr->mf[1]), ptr->mf[0],
                grad.mf_g, grad.show);
    }
  }
};

148
#endif
T
Thunderbrook 已提交
149 150 151
}  // end namespace framework
}  // end namespace paddle
#endif