/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#ifdef PADDLE_WITH_HETERPS

#if defined(PADDLE_WITH_CUDA)
#include <curand_kernel.h>
#endif
#include <vector>
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_CUDA)

template <typename GPUAccessor>
class SparseAdagradOptimizer {
 public:
  SparseAdagradOptimizer() {}
  explicit SparseAdagradOptimizer(GPUAccessor gpu_accessor) {
    gpu_accessor_ = gpu_accessor;
    _lr_embedding_dim = 1;
    _embedding_dim = gpu_accessor_.common_feature_value.EmbedWDim();
  }

  ~SparseAdagradOptimizer() {}

  __device__ void update_value_work(const OptimizerConfig& optimizer_config,
                                    int n,
                                    float* w,
                                    float* sgd,  // NOLINT
                                    const float* g,
                                    float scale,
                                    float slot) {
    float& g2sum = sgd[G2SumIndex()];
    double add_g2sum = 0;

    float learning_rate = optimizer_config.mf_learning_rate;
    if (slot != optimizer_config.nodeid_slot) {
      learning_rate = optimizer_config.feature_learning_rate;
    }
    double ratio =
        learning_rate * sqrt(optimizer_config.mf_initial_g2sum /
                             (optimizer_config.mf_initial_g2sum + g2sum));
    for (int i = 0; i < n; ++i) {
      double scaled_grad = g[i] / scale;

      w[i] += scaled_grad * ratio;

      if (w[i] < optimizer_config.mf_min_bound)
        w[i] = optimizer_config.mf_min_bound;
      if (w[i] > optimizer_config.mf_max_bound)
        w[i] = optimizer_config.mf_max_bound;
      add_g2sum += scaled_grad * scaled_grad;
    }

    g2sum += add_g2sum / n;
  }

  __device__ void update_value(const OptimizerConfig& optimizer_config,
                               float& val,  // NOLINT
                               const float& grad) {
    printf(
        "Warning: update_value will not be used. Please use "
        "dy_mf_update_value\n");
  }
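  // Dynamic-mf update: accumulates show/click and the delta score, runs the
  // Adagrad step on the 1-dim lr embedding, and lazily creates the mf
  // embedding once the weighted show/click score reaches
  // mf_create_thresholds, seeding its weights with uniform random values
  // scaled by mf_initial_range.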
  __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config,
                                     float* ptr,
                                     const float* grad) {
    float g_show = grad[gpu_accessor_.common_push_value.ShowIndex()];
    float g_click = grad[gpu_accessor_.common_push_value.ClickIndex()];

    ptr[gpu_accessor_.common_feature_value.SlotIndex()] =
        grad[gpu_accessor_.common_push_value.SlotIndex()];
    ptr[gpu_accessor_.common_feature_value.ShowIndex()] += g_show;
    ptr[gpu_accessor_.common_feature_value.ClickIndex()] += g_click;
    ptr[gpu_accessor_.common_feature_value.DeltaScoreIndex()] +=
        optimizer_config.nonclk_coeff * (g_show - g_click) +
        optimizer_config.clk_coeff * g_click;
    float slot = ptr[gpu_accessor_.common_feature_value.SlotIndex()];

    update_value_work(
        optimizer_config,
        1,
        ptr + gpu_accessor_.common_feature_value.EmbedWIndex(),
        ptr + gpu_accessor_.common_feature_value.EmbedG2SumIndex(),
        grad + gpu_accessor_.common_push_value.EmbedGIndex(),
        g_show,
        slot);

    int mf_dim =
        static_cast<int>(ptr[gpu_accessor_.common_feature_value.MfDimIndex()]);
    if (ptr[gpu_accessor_.common_feature_value.MfSizeIndex()] == 0) {
      if (optimizer_config.mf_create_thresholds <=
          optimizer_config.nonclk_coeff *
                  (ptr[gpu_accessor_.common_feature_value.ShowIndex()] -
                   ptr[gpu_accessor_.common_feature_value.ClickIndex()]) +
              optimizer_config.clk_coeff *
                  ptr[gpu_accessor_.common_feature_value.ClickIndex()]) {
        ptr[gpu_accessor_.common_feature_value.MfSizeIndex()] =
            gpu_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float);

        int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
        curandState state;
        curand_init(clock64(), tid_x, 0, &state);
        for (int i = 0; i < mf_dim; ++i) {
          ptr[gpu_accessor_.common_feature_value.EmbedxWIndex() + i] =
              (curand_uniform(&state)) * optimizer_config.mf_initial_range;
        }
      }
    } else {
      update_value_work(
          optimizer_config,
          mf_dim,
          ptr + gpu_accessor_.common_feature_value.EmbedxWIndex(),
          ptr + gpu_accessor_.common_feature_value.EmbedxG2SumIndex(),
          grad + gpu_accessor_.common_push_value.EmbedxGIndex(),
          g_show,
          slot);
    }
  }

  __host__ __device__ size_t Dim() { return EmbedDim() + EmbedxDim(); }
  __host__ __device__ size_t EmbedDim() { return _lr_embedding_dim; }
  __host__ __device__ size_t EmbedxDim() { return _embedding_dim; }
  __host__ __device__ size_t G2SumIndex() { return 0; }
  __host__ __device__ size_t EmbedxG2SumIndex() { return 0; }

 private:
  GPUAccessor gpu_accessor_;
  size_t _embedding_dim;
  size_t _lr_embedding_dim;
};

template <typename GPUAccessor>
class SparseAdamOptimizer {
 public:
  SparseAdamOptimizer() {}
  explicit SparseAdamOptimizer(GPUAccessor gpu_accessor) {
    gpu_accessor_ = gpu_accessor;
    _lr_embedding_dim = 1;
    _embedding_dim = gpu_accessor_.common_feature_value.EmbedWDim();
  }

  ~SparseAdamOptimizer() {}

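  // One Adam step for the 1-dim lr embedding: bias-corrected step size from
  // the stored beta powers, per-dimension moment updates, and clipping to
  // [mf_min_bound, mf_max_bound]; the beta powers decay once per call.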
  __device__ void update_lr(const OptimizerConfig& optimizer_config,
                            int n,
                            float* w,
                            float* sgd,
                            const float* g,
                            float scale) {
    float* moment1 = sgd + GSumIndex();
    float* moment2 = sgd + G2SumIndex();
    float* beta1_pow = sgd + Beta1PowIndex();
    float* beta2_pow = sgd + Beta2PowIndex();

    float beta1_pow_ = *beta1_pow;
    float beta2_pow_ = *beta2_pow;

    float epsilon = 1e-08;
    double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) /
                   (1.0 - beta1_pow_);
    for (int i = 0; i < n; ++i) {
      double scaled_grad = g[i] / scale;

      double new_moment1 =
          optimizer_config.beta1_decay_rate * moment1[i] +
          (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad;
      double new_moment2 =
          optimizer_config.beta2_decay_rate * moment2[i] +
          (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad;
      w[i] += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon));

      if (w[i] < optimizer_config.mf_min_bound)
        w[i] = optimizer_config.mf_min_bound;
      if (w[i] > optimizer_config.mf_max_bound)
        w[i] = optimizer_config.mf_max_bound;

      moment1[i] = new_moment1;
      moment2[i] = new_moment2;
    }
    (*beta1_pow) *= optimizer_config.beta1_decay_rate;
    (*beta2_pow) *= optimizer_config.beta2_decay_rate;
  }

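  // Same Adam step as update_lr, but over the n-dimensional mf embedding
  // using the Embedx* state offsets.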
  __device__ void update_mf(const OptimizerConfig& optimizer_config,
                            int n,
                            float* w,
                            float* sgd,
                            const float* g,
                            float scale) {
    float* moment1 = sgd + EmbedxGSumIndex();
    float* moment2 = sgd + EmbedxG2SumIndex();
    float* beta1_pow = sgd + EmbedxBeta1PowIndex();
    float* beta2_pow = sgd + EmbedxBeta2PowIndex();

    float beta1_pow_ = *beta1_pow;
    float beta2_pow_ = *beta2_pow;

    float epsilon = 1e-08;
    double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) /
                   (1.0 - beta1_pow_);
    for (int i = 0; i < n; ++i) {
      double scaled_grad = g[i] / scale;

      double new_moment1 =
          optimizer_config.beta1_decay_rate * moment1[i] +
          (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad;
      double new_moment2 =
          optimizer_config.beta2_decay_rate * moment2[i] +
          (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad;
      w[i] += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon));

      if (w[i] < optimizer_config.mf_min_bound)
        w[i] = optimizer_config.mf_min_bound;
      if (w[i] > optimizer_config.mf_max_bound)
        w[i] = optimizer_config.mf_max_bound;

      moment1[i] = new_moment1;
      moment2[i] = new_moment2;
    }
    (*beta1_pow) *= optimizer_config.beta1_decay_rate;
    (*beta2_pow) *= optimizer_config.beta2_decay_rate;
  }

  __device__ void update_value(const OptimizerConfig& optimizer_config,
                               float& val,  // NOLINT
                               const float& grad) {
    printf(
        "Warning: update_value will not be used. Please use "
        "dy_mf_update_value\n");
  }
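  // Dynamic-mf Adam update: accumulates show/click and the delta score, runs
  // Adam on the lr embedding, and lazily creates the mf embedding once the
  // weighted show/click score reaches mf_create_thresholds; freshly created
  // mf state seeds its beta1/beta2 powers with one step of decay.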
  __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config,
                                     float* ptr,
                                     const float* grad) {
    float g_show = grad[gpu_accessor_.common_push_value.ShowIndex()];
    float g_click = grad[gpu_accessor_.common_push_value.ClickIndex()];

    ptr[gpu_accessor_.common_feature_value.SlotIndex()] =
        grad[gpu_accessor_.common_push_value.SlotIndex()];
    ptr[gpu_accessor_.common_feature_value.ShowIndex()] += g_show;
    ptr[gpu_accessor_.common_feature_value.ClickIndex()] += g_click;
    ptr[gpu_accessor_.common_feature_value.DeltaScoreIndex()] +=
        optimizer_config.nonclk_coeff * (g_show - g_click) +
        optimizer_config.clk_coeff * g_click;

    update_lr(optimizer_config,
              1,
              ptr + gpu_accessor_.common_feature_value.EmbedWIndex(),
              ptr + gpu_accessor_.common_feature_value.EmbedG2SumIndex(),
              grad + gpu_accessor_.common_push_value.EmbedGIndex(),
              g_show);
    int mf_dim =
        static_cast<int>(ptr[gpu_accessor_.common_feature_value.MfDimIndex()]);
    if (ptr[gpu_accessor_.common_feature_value.MfSizeIndex()] == 0) {
      if (optimizer_config.mf_create_thresholds <=
          optimizer_config.nonclk_coeff *
                  (ptr[gpu_accessor_.common_feature_value.ShowIndex()] -
                   ptr[gpu_accessor_.common_feature_value.ClickIndex()]) +
              optimizer_config.clk_coeff *
                  ptr[gpu_accessor_.common_feature_value.ClickIndex()]) {
        ptr[gpu_accessor_.common_feature_value.MfSizeIndex()] =
            gpu_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float);

        int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
        curandState state;
        curand_init(clock64(), tid_x, 0, &state);
        for (int i = 0; i < mf_dim; ++i) {
          ptr[gpu_accessor_.common_feature_value.EmbedxWIndex() + i] =
              (curand_uniform(&state)) * optimizer_config.mf_initial_range;
        }
        ptr[gpu_accessor_.common_feature_value.EmbedxG2SumIndex() +
            EmbedxBeta1PowIndex()] = optimizer_config.beta1_decay_rate;
        ptr[gpu_accessor_.common_feature_value.EmbedxG2SumIndex() +
            EmbedxBeta2PowIndex()] = optimizer_config.beta2_decay_rate;
      }
    } else {
      update_mf(optimizer_config,
                mf_dim,
                ptr + gpu_accessor_.common_feature_value.EmbedxWIndex(),
                ptr + gpu_accessor_.common_feature_value.EmbedxG2SumIndex(),
                grad + gpu_accessor_.common_push_value.EmbedxGIndex(),
                g_show);
    }
  }

  __host__ __device__ size_t Dim() { return EmbedDim() + EmbedxDim(); }
  __host__ __device__ size_t EmbedDim() { return _lr_embedding_dim * 2 + 2; }
  __host__ __device__ size_t EmbedxDim() { return _embedding_dim * 2 + 2; }
  __host__ __device__ size_t GSumIndex() { return 0; }
  __host__ __device__ size_t G2SumIndex() {
    return GSumIndex() + _lr_embedding_dim;
  }
  __host__ __device__ size_t Beta1PowIndex() {
    return G2SumIndex() + _lr_embedding_dim;
  }
  __host__ __device__ size_t Beta2PowIndex() { return Beta1PowIndex() + 1; }
  __host__ __device__ size_t EmbedxGSumIndex() { return 0; }
  __host__ __device__ size_t EmbedxG2SumIndex() {
    return EmbedxGSumIndex() + _embedding_dim;
  }
  __host__ __device__ size_t EmbedxBeta1PowIndex() {
    return EmbedxG2SumIndex() + _embedding_dim;
  }
  __host__ __device__ size_t EmbedxBeta2PowIndex() {
    return EmbedxBeta1PowIndex() + 1;
  }

 private:
  GPUAccessor gpu_accessor_;
  size_t _embedding_dim;
  size_t _lr_embedding_dim;
};

template <typename GPUAccessor>
class SparseAdamSharedOptimizer {
 public:
  SparseAdamSharedOptimizer() {}
  explicit SparseAdamSharedOptimizer(GPUAccessor gpu_accessor) {
    gpu_accessor_ = gpu_accessor;
    _lr_embedding_dim = 1;
    _embedding_dim = gpu_accessor_.common_feature_value.EmbedWDim();
  }

  ~SparseAdamSharedOptimizer() {}

  __device__ void update_value_work(const OptimizerConfig& optimizer_config,
                                    int n,
                                    float* w,
                                    float* sgd,
                                    const float* g,
                                    float scale) {
    float* moment1 = sgd + GSumIndex();
    float* moment2 = sgd + G2SumIndex();
    float* beta1_pow = sgd + Beta1PowIndex();
    float* beta2_pow = sgd + Beta2PowIndex();

    float beta1_pow_ = *beta1_pow;
    float beta2_pow_ = *beta2_pow;
    float moment1_ = *moment1;
    float moment2_ = *moment2;
    float epsilon = 1e-08;
    double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) /
                   (1.0 - beta1_pow_);

    double sum_mom1 = 0.0;
    double sum_mom2 = 0.0;
    for (int i = 0; i < n; ++i) {
      double scaled_grad = g[i] / scale;

      double new_moment1 =
          optimizer_config.beta1_decay_rate * moment1_ +
          (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad;
      double new_moment2 =
          optimizer_config.beta2_decay_rate * moment2_ +
          (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad;
      w[i] += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon));

      if (w[i] < optimizer_config.mf_min_bound)
        w[i] = optimizer_config.mf_min_bound;
      if (w[i] > optimizer_config.mf_max_bound)
        w[i] = optimizer_config.mf_max_bound;

      sum_mom1 += new_moment1;
      sum_mom2 += new_moment2;
    }

    (*moment1) = sum_mom1 / n;
    (*moment2) = sum_mom2 / n;
    (*beta1_pow) *= optimizer_config.beta1_decay_rate;
    (*beta2_pow) *= optimizer_config.beta2_decay_rate;
  }

  __device__ void update_value(const OptimizerConfig& optimizer_config,
                               float& val,  // NOLINT
                               const float& grad) {
    printf(
        "Warning: update_value will not be used. Please use "
        "dy_mf_update_value\n");
  }

  __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config,
                                     float* ptr,
                                     const float* grad) {
    float g_show = grad[gpu_accessor_.common_push_value.ShowIndex()];
    float g_click = grad[gpu_accessor_.common_push_value.ClickIndex()];

    ptr[gpu_accessor_.common_feature_value.SlotIndex()] =
        grad[gpu_accessor_.common_push_value.SlotIndex()];
    ptr[gpu_accessor_.common_feature_value.ShowIndex()] += g_show;
    ptr[gpu_accessor_.common_feature_value.ClickIndex()] += g_click;
    ptr[gpu_accessor_.common_feature_value.DeltaScoreIndex()] +=
        optimizer_config.nonclk_coeff * (g_show - g_click) +
        optimizer_config.clk_coeff * g_click;

    update_value_work(
        optimizer_config,
        1,
        ptr + gpu_accessor_.common_feature_value.EmbedWIndex(),
        ptr + gpu_accessor_.common_feature_value.EmbedG2SumIndex(),
        grad + gpu_accessor_.common_push_value.EmbedGIndex(),
        g_show);
    int mf_dim =
        static_cast<int>(ptr[gpu_accessor_.common_feature_value.MfDimIndex()]);
    if (ptr[gpu_accessor_.common_feature_value.MfSizeIndex()] == 0) {
      if (optimizer_config.mf_create_thresholds <=
          optimizer_config.nonclk_coeff *
                  (ptr[gpu_accessor_.common_feature_value.ShowIndex()] -
                   ptr[gpu_accessor_.common_feature_value.ClickIndex()]) +
              optimizer_config.clk_coeff *
                  ptr[gpu_accessor_.common_feature_value.ClickIndex()]) {
        ptr[gpu_accessor_.common_feature_value.MfSizeIndex()] =
            gpu_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float);

        int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
        curandState state;
        curand_init(clock64(), tid_x, 0, &state);
        for (int i = 0; i < mf_dim; ++i) {
          ptr[gpu_accessor_.common_feature_value.EmbedxWIndex() + i] =
              (curand_uniform(&state)) * optimizer_config.mf_initial_range;
        }
        ptr[gpu_accessor_.common_feature_value.EmbedxG2SumIndex() +
            EmbedxBeta1PowIndex()] = optimizer_config.beta1_decay_rate;
        ptr[gpu_accessor_.common_feature_value.EmbedxG2SumIndex() +
            EmbedxBeta2PowIndex()] = optimizer_config.beta2_decay_rate;
      }
    } else {
      update_value_work(
          optimizer_config,
          mf_dim,
          ptr + gpu_accessor_.common_feature_value.EmbedxWIndex(),
          ptr + gpu_accessor_.common_feature_value.EmbedxG2SumIndex(),
          grad + gpu_accessor_.common_push_value.EmbedxGIndex(),
          g_show);
    }
  }

  __host__ __device__ size_t Dim() { return EmbedDim() + EmbedxDim(); }
  __host__ __device__ size_t EmbedDim() { return 4; }
  __host__ __device__ size_t EmbedxDim() { return 4; }
  __host__ __device__ size_t GSumIndex() { return 0; }
  __host__ __device__ size_t G2SumIndex() { return GSumIndex() + 1; }
  __host__ __device__ size_t Beta1PowIndex() { return G2SumIndex() + 1; }
  __host__ __device__ size_t Beta2PowIndex() { return Beta1PowIndex() + 1; }
  __host__ __device__ size_t EmbedxGSumIndex() { return 0; }
  __host__ __device__ size_t EmbedxG2SumIndex() {
    return EmbedxGSumIndex() + 1;
  }
  __host__ __device__ size_t EmbedxBeta1PowIndex() {
    return EmbedxG2SumIndex() + 1;
  }
  __host__ __device__ size_t EmbedxBeta2PowIndex() {
    return EmbedxBeta1PowIndex() + 1;
  }

 private:
  GPUAccessor gpu_accessor_;
  size_t _embedding_dim;
  size_t _lr_embedding_dim;
};

#endif
}  // end namespace framework
}  // end namespace paddle
#endif