/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#ifdef PADDLE_WITH_HETERPS

#if defined(PADDLE_WITH_CUDA)
#include <curand_kernel.h>
#endif
#include <vector>

#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_CUDA)
// Base class for the GPU sparse-table optimizers in this header.
//
// It stores the feature-value accessor, which the subclasses below use to
// locate fields (show, click, slot, embedding weights, optimizer state) inside
// a feature value and a pushed gradient record. It also stores the two
// dimensions each subclass sets in its own constructor.
// The update methods are non-virtual: subclasses hide them, so callers must
// use the concrete optimizer type.
class Optimizer {
 public:
  __host__ Optimizer(CommonFeatureValueAccessor feature_value_accessor) {
    feature_value_accessor_ = feature_value_accessor;
    // Subclass constructors overwrite both dimensions. Give them a defined
    // value here so a bare base object never carries indeterminate sizes.
    _embedding_dim = 0;
    _lr_embedding_dim = 0;
  }
  __host__ ~Optimizer() {}

  // Legacy scalar entry point. It is not implemented and only prints a warning.
  __device__ void update_value(const OptimizerConfig& optimizer_config,
                               float& val,  // NOLINT
                               const float& grad) {
    printf(
        "Warning: update_value will not used. Please use dy_mf_update_value\n");
  }

  // No-op in the base class. Subclasses hide it with the real update rule
  // that applies gradient record `grad` to feature value `ptr`.
  __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config,
                                     float* ptr,
                                     const float* grad) {}

  CommonFeatureValueAccessor feature_value_accessor_;

  // Dimension of the embedx (mf) weight group. Every subclass in this header
  // sets it to common_feature_value.EmbedWDim().
  size_t _embedding_dim;
  // Dimension of the lr (embed) weight group. Every subclass here sets it to 1.
  size_t _lr_embedding_dim;
};

// Sparse Adagrad for GPU parameter-server feature values.
//
// Each weight group (the 1-element lr/embed weight and the mf_dim-element
// embedx/mf weights) keeps a single shared g2sum accumulator at state index 0.
class SparseAdagradOptimizer : public Optimizer {
 public:
  __host__ SparseAdagradOptimizer(
      CommonFeatureValueAccessor feature_value_accessor)
      : Optimizer(feature_value_accessor) {
    _lr_embedding_dim = 1;
    _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim();
  }

  // One Adagrad step over n weights that share one g2sum accumulator.
  //   w     : n weights, updated in place and clamped to
  //           [mf_min_bound, mf_max_bound]
  //   sgd   : optimizer state; sgd[G2SumIndex()] is the shared g2sum
  //   g     : n gradients, each divided by `scale` (callers in this class
  //           pass the pushed show count)
  // The learning-rate ratio is computed from g2sum *before* this step. After
  // the step, g2sum grows by the mean squared scaled gradient.
  // NOTE(review): the weight is *increased* by the scaled gradient, which
  // presumes the pushed gradient is already negated upstream. Confirm against
  // the push path.
  __device__ void update_value_work(const OptimizerConfig& optimizer_config,
                                    int n,
                                    float* w,
                                    float* sgd,  // NOLINT
                                    const float* g,
                                    float scale) {
    float& g2sum = sgd[G2SumIndex()];
    double add_g2sum = 0;
    double ratio = optimizer_config.mf_learning_rate *
                   sqrt(optimizer_config.mf_initial_g2sum /
                        (optimizer_config.mf_initial_g2sum + g2sum));
    for (int i = 0; i < n; ++i) {
      double scaled_grad = g[i] / scale;

      w[i] += scaled_grad * ratio;

      if (w[i] < optimizer_config.mf_min_bound)
        w[i] = optimizer_config.mf_min_bound;
      if (w[i] > optimizer_config.mf_max_bound)
        w[i] = optimizer_config.mf_max_bound;
      add_g2sum += scaled_grad * scaled_grad;
    }

    // With n == 0 the mean would be 0.0 / 0 = NaN. That NaN would be stored
    // in the accumulator permanently, so only average when there is data.
    if (n > 0) {
      g2sum += add_g2sum / n;
    }
  }

  // Legacy scalar entry point. It is not implemented and only prints a warning.
  __device__ void update_value(const OptimizerConfig& optimizer_config,
                               float& val,  // NOLINT
                               const float& grad) {
    printf(
        "Warning: update_value will not used. Please use dy_mf_update_value\n");
  }

  // Applies one pushed gradient record `grad` to feature value `ptr`:
  //   1. copies the slot id and accumulates show, click and delta score;
  //   2. runs Adagrad on the lr (embed) weight (n = 1);
  //   3. if the embedx block does not exist yet (MfSize == 0), it is created
  //      once the accumulated nonclk/clk-weighted score reaches
  //      mf_create_thresholds, and its mf_dim weights are filled with random
  //      values scaled by mf_initial_range; otherwise Adagrad runs on the
  //      mf_dim embedx weights.
  // Gradients are divided by the pushed show count (g_show).
  __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config,
                                     float* ptr,
                                     const float* grad) {
    float g_show = grad[feature_value_accessor_.common_push_value.ShowIndex()];
    float g_click =
        grad[feature_value_accessor_.common_push_value.ClickIndex()];

    ptr[feature_value_accessor_.common_feature_value.SlotIndex()] =
        grad[feature_value_accessor_.common_push_value.SlotIndex()];
    ptr[feature_value_accessor_.common_feature_value.ShowIndex()] += g_show;
    ptr[feature_value_accessor_.common_feature_value.ClickIndex()] += g_click;
    ptr[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] +=
        optimizer_config.nonclk_coeff * (g_show - g_click) +
        optimizer_config.clk_coeff * g_click;

    update_value_work(
        optimizer_config,
        1,
        ptr + feature_value_accessor_.common_feature_value.EmbedWIndex(),
        ptr + feature_value_accessor_.common_feature_value.EmbedG2SumIndex(),
        grad + feature_value_accessor_.common_push_value.EmbedGIndex(),
        g_show);

    int mf_dim = static_cast<int>(
        ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]);
    if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) {
      if (optimizer_config.mf_create_thresholds <=
          optimizer_config.nonclk_coeff *
                  (ptr[feature_value_accessor_.common_feature_value
                           .ShowIndex()] -
                   ptr[feature_value_accessor_.common_feature_value
                           .ClickIndex()]) +
              optimizer_config.clk_coeff *
                  ptr[feature_value_accessor_.common_feature_value
                          .ClickIndex()]) {
        ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] =
            feature_value_accessor_.common_feature_value.MFSize(mf_dim) /
            sizeof(float);

        // The seed comes from clock64() with the global thread index as the
        // subsequence, so initial embedx values differ from run to run.
        int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
        curandState state;
        curand_init(clock64(), tid_x, 0, &state);
        for (int i = 0; i < mf_dim; ++i) {
          ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] =
              (curand_uniform(&state)) * optimizer_config.mf_initial_range;
        }
      }
    } else {
      update_value_work(
          optimizer_config,
          mf_dim,
          ptr + feature_value_accessor_.common_feature_value.EmbedxWIndex(),
          ptr + feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(),
          grad + feature_value_accessor_.common_push_value.EmbedxGIndex(),
          g_show);
    }
  }

  __host__ __device__ size_t Dim() { return EmbedDim() + EmbedxDim(); }
  __host__ __device__ size_t EmbedDim() { return _lr_embedding_dim; }
  __host__ __device__ size_t EmbedxDim() { return _embedding_dim; }
  __host__ __device__ size_t G2SumIndex() { return 0; }
  __host__ __device__ size_t EmbedxG2SumIndex() { return 0; }
};

// Sparse Adam for GPU parameter-server feature values, with per-weight first
// and second moments.
//
// Optimizer-state layout for a group of D weights, relative to the state
// pointer handed to update_lr()/update_mf():
//   [0, D)    gsum  (first moment)
//   [D, 2D)   g2sum (second moment)
//   2D        beta1_pow
//   2D + 1    beta2_pow
// D is _lr_embedding_dim for the lr (embed) weight and _embedding_dim for the
// embedx (mf) weights, hence EmbedDim()/EmbedxDim() = 2 * D + 2.
class SparseAdamOptimizer : public Optimizer {
 public:
  __host__ SparseAdamOptimizer(
      CommonFeatureValueAccessor feature_value_accessor)
      : Optimizer(feature_value_accessor) {
    _lr_embedding_dim = 1;
    _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim();
  }

  // Adam step on the lr (embed) weights w[0..n), with optimizer state at `sgd`.
  __device__ void update_lr(const OptimizerConfig& optimizer_config,
                            int n,
                            float* w,
                            float* sgd,
                            const float* g,
                            float scale) {
    adam_update_work(optimizer_config,
                     n,
                     w,
                     sgd + GSumIndex(),
                     sgd + G2SumIndex(),
                     sgd + Beta1PowIndex(),
                     sgd + Beta2PowIndex(),
                     g,
                     scale);
  }

  // Adam step on the embedx (mf) weights w[0..n), with optimizer state at
  // `sgd`.
  __device__ void update_mf(const OptimizerConfig& optimizer_config,
                            int n,
                            float* w,
                            float* sgd,
                            const float* g,
                            float scale) {
    adam_update_work(optimizer_config,
                     n,
                     w,
                     sgd + EmbedxGSumIndex(),
                     sgd + EmbedxG2SumIndex(),
                     sgd + EmbedxBeta1PowIndex(),
                     sgd + EmbedxBeta2PowIndex(),
                     g,
                     scale);
  }

  // Legacy scalar entry point. It is not implemented and only prints a warning.
  __device__ void update_value(const OptimizerConfig& optimizer_config,
                               float& val,  // NOLINT
                               const float& grad) {
    printf(
        "Warning: update_value will not used. Please use dy_mf_update_value\n");
  }

  // Applies one pushed gradient record `grad` to feature value `ptr`:
  //   1. copies the slot id and accumulates show, click and delta score;
  //   2. runs Adam on the lr (embed) weight (n = 1);
  //   3. if the embedx block does not exist yet (MfSize == 0), it is created
  //      once the accumulated nonclk/clk-weighted score reaches
  //      mf_create_thresholds: weights get random values scaled by
  //      mf_initial_range, and the embedx beta pows are seeded with the decay
  //      rates. Otherwise Adam runs on the mf_dim embedx weights.
  // Gradients are divided by the pushed show count (g_show).
  __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config,
                                     float* ptr,
                                     const float* grad) {
    float g_show = grad[feature_value_accessor_.common_push_value.ShowIndex()];
    float g_click =
        grad[feature_value_accessor_.common_push_value.ClickIndex()];

    ptr[feature_value_accessor_.common_feature_value.SlotIndex()] =
        grad[feature_value_accessor_.common_push_value.SlotIndex()];
    ptr[feature_value_accessor_.common_feature_value.ShowIndex()] += g_show;
    ptr[feature_value_accessor_.common_feature_value.ClickIndex()] += g_click;
    ptr[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] +=
        optimizer_config.nonclk_coeff * (g_show - g_click) +
        optimizer_config.clk_coeff * g_click;

    update_lr(
        optimizer_config,
        1,
        ptr + feature_value_accessor_.common_feature_value.EmbedWIndex(),
        ptr + feature_value_accessor_.common_feature_value.EmbedG2SumIndex(),
        grad + feature_value_accessor_.common_push_value.EmbedGIndex(),
        g_show);

    int mf_dim = static_cast<int>(
        ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]);
    if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) {
      if (optimizer_config.mf_create_thresholds <=
          optimizer_config.nonclk_coeff *
                  (ptr[feature_value_accessor_.common_feature_value
                           .ShowIndex()] -
                   ptr[feature_value_accessor_.common_feature_value
                           .ClickIndex()]) +
              optimizer_config.clk_coeff *
                  ptr[feature_value_accessor_.common_feature_value
                          .ClickIndex()]) {
        ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] =
            feature_value_accessor_.common_feature_value.MFSize(mf_dim) /
            sizeof(float);

        // The seed comes from clock64() with the global thread index as the
        // subsequence, so initial embedx values differ from run to run.
        int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
        curandState state;
        curand_init(clock64(), tid_x, 0, &state);
        for (int i = 0; i < mf_dim; ++i) {
          ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] =
              (curand_uniform(&state)) * optimizer_config.mf_initial_range;
        }
        // The feature value's EmbedxG2SumIndex() is the absolute start of the
        // embedx optimizer state in `ptr`, which is the same pointer update_mf()
        // receives. EmbedxBeta{1,2}PowIndex() are this optimizer's offsets
        // inside that state.
        float* mf_sgd =
            ptr + feature_value_accessor_.common_feature_value.EmbedxG2SumIndex();
        mf_sgd[EmbedxBeta1PowIndex()] = optimizer_config.beta1_decay_rate;
        mf_sgd[EmbedxBeta2PowIndex()] = optimizer_config.beta2_decay_rate;
      }
    } else {
      update_mf(
          optimizer_config,
          mf_dim,
          ptr + feature_value_accessor_.common_feature_value.EmbedxWIndex(),
          ptr + feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(),
          grad + feature_value_accessor_.common_push_value.EmbedxGIndex(),
          g_show);
    }
  }

  __host__ __device__ size_t Dim() { return EmbedDim() + EmbedxDim(); }
  __host__ __device__ size_t EmbedDim() { return _lr_embedding_dim * 2 + 2; }
  __host__ __device__ size_t EmbedxDim() { return _embedding_dim * 2 + 2; }
  __host__ __device__ size_t GSumIndex() { return 0; }
  __host__ __device__ size_t G2SumIndex() {
    return GSumIndex() + _lr_embedding_dim;
  }
  __host__ __device__ size_t Beta1PowIndex() {
    return G2SumIndex() + _lr_embedding_dim;
  }
  __host__ __device__ size_t Beta2PowIndex() { return Beta1PowIndex() + 1; }
  __host__ __device__ size_t EmbedxGSumIndex() { return 0; }
  __host__ __device__ size_t EmbedxG2SumIndex() {
    return EmbedxGSumIndex() + _embedding_dim;
  }
  __host__ __device__ size_t EmbedxBeta1PowIndex() {
    return EmbedxG2SumIndex() + _embedding_dim;
  }
  __host__ __device__ size_t EmbedxBeta2PowIndex() {
    return EmbedxBeta1PowIndex() + 1;
  }

 private:
  // One Adam step shared by update_lr() and update_mf().
  //   w                   : n weights, updated in place and clamped to
  //                         [mf_min_bound, mf_max_bound]
  //   moment1, moment2    : n per-weight first/second moments, updated in place
  //   beta1_pow, beta2_pow: scalar bias-correction powers. They are read for
  //                         this step's ratio, then multiplied by the decay
  //                         rates.
  //   g                   : n gradients, each divided by `scale`
  // NOTE(review): the weight is *increased* by the step, which presumes the
  // pushed gradient is already negated upstream. Confirm against the push
  // path.
  __device__ void adam_update_work(const OptimizerConfig& optimizer_config,
                                   int n,
                                   float* w,
                                   float* moment1,
                                   float* moment2,
                                   float* beta1_pow,
                                   float* beta2_pow,
                                   const float* g,
                                   float scale) {
    float beta1_pow_ = *beta1_pow;
    float beta2_pow_ = *beta2_pow;

    float epsilon = 1e-08;
    double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) /
                   (1.0 - beta1_pow_);
    for (int i = 0; i < n; ++i) {
      double scaled_grad = g[i] / scale;

      double new_moment1 =
          optimizer_config.beta1_decay_rate * moment1[i] +
          (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad;
      double new_moment2 =
          optimizer_config.beta2_decay_rate * moment2[i] +
          (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad;
      w[i] += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon));

      if (w[i] < optimizer_config.mf_min_bound)
        w[i] = optimizer_config.mf_min_bound;
      if (w[i] > optimizer_config.mf_max_bound)
        w[i] = optimizer_config.mf_max_bound;

      moment1[i] = new_moment1;
      moment2[i] = new_moment2;
    }
    (*beta1_pow) *= optimizer_config.beta1_decay_rate;
    (*beta2_pow) *= optimizer_config.beta2_decay_rate;
  }
};

// Sparse Adam variant that keeps ONE shared first/second moment per weight
// group instead of one per weight.
//
// State layout, relative to the state pointer passed to update_value_work():
//   [0] gsum, [1] g2sum, [2] beta1_pow, [3] beta2_pow
// so EmbedDim() == EmbedxDim() == 4.
class SparseAdamSharedOptimizer : public Optimizer {
 public:
  __host__ SparseAdamSharedOptimizer(
      CommonFeatureValueAccessor feature_value_accessor)
      : Optimizer(feature_value_accessor) {
    _lr_embedding_dim = 1;
    _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim();
  }

  // One shared-moment Adam step over n weights.
  // Every weight uses the same stored moments (moment1_, moment2_) when its
  // own new moments are computed. Afterwards the stored moments are replaced
  // by the mean of the n new moments, and both beta pows are multiplied by
  // their decay rates. Weights are clamped to [mf_min_bound, mf_max_bound],
  // and gradients are divided by `scale`.
  // NOTE(review): the weight is *increased* by the step, which presumes the
  // pushed gradient is already negated upstream. Confirm against the push
  // path.
  __device__ void update_value_work(const OptimizerConfig& optimizer_config,
                                    int n,
                                    float* w,
                                    float* sgd,
                                    const float* g,
                                    float scale) {
    float* moment1 = sgd + GSumIndex();
    float* moment2 = sgd + G2SumIndex();
    float* beta1_pow = sgd + Beta1PowIndex();
    float* beta2_pow = sgd + Beta2PowIndex();

    float beta1_pow_ = *beta1_pow;
    float beta2_pow_ = *beta2_pow;
    float moment1_ = *moment1;
    float moment2_ = *moment2;
    float epsilon = 1e-08;
    double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) /
                   (1.0 - beta1_pow_);

    double sum_mom1 = 0.0;
    double sum_mom2 = 0.0;
    for (int i = 0; i < n; ++i) {
      double scaled_grad = g[i] / scale;

      double new_moment1 =
          optimizer_config.beta1_decay_rate * moment1_ +
          (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad;
      double new_moment2 =
          optimizer_config.beta2_decay_rate * moment2_ +
          (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad;
      w[i] += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon));

      if (w[i] < optimizer_config.mf_min_bound)
        w[i] = optimizer_config.mf_min_bound;
      if (w[i] > optimizer_config.mf_max_bound)
        w[i] = optimizer_config.mf_max_bound;

      sum_mom1 += new_moment1;
      sum_mom2 += new_moment2;
    }

    // With n == 0 the means would be 0.0 / 0 = NaN. Writing that into the
    // shared moments would poison every later update of this group, so the
    // stored moments are only replaced when there was at least one weight.
    if (n > 0) {
      (*moment1) = sum_mom1 / n;
      (*moment2) = sum_mom2 / n;
    }
    (*beta1_pow) *= optimizer_config.beta1_decay_rate;
    (*beta2_pow) *= optimizer_config.beta2_decay_rate;
  }

  // Legacy scalar entry point. It is not implemented and only prints a warning.
  __device__ void update_value(const OptimizerConfig& optimizer_config,
                               float& val,  // NOLINT
                               const float& grad) {
    printf(
        "Warning: update_value will not used. Please use dy_mf_update_value\n");
  }

  // Applies one pushed gradient record `grad` to feature value `ptr`:
  //   1. copies the slot id and accumulates show, click and delta score;
  //   2. runs the shared-moment Adam step on the lr (embed) weight (n = 1);
  //   3. if the embedx block does not exist yet (MfSize == 0), it is created
  //      once the accumulated nonclk/clk-weighted score reaches
  //      mf_create_thresholds: weights get random values scaled by
  //      mf_initial_range, and the embedx beta pows are seeded with the decay
  //      rates. Otherwise the step runs on the mf_dim embedx weights.
  // Gradients are divided by the pushed show count (g_show).
  __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config,
                                     float* ptr,
                                     const float* grad) {
    float g_show = grad[feature_value_accessor_.common_push_value.ShowIndex()];
    float g_click =
        grad[feature_value_accessor_.common_push_value.ClickIndex()];

    ptr[feature_value_accessor_.common_feature_value.SlotIndex()] =
        grad[feature_value_accessor_.common_push_value.SlotIndex()];
    ptr[feature_value_accessor_.common_feature_value.ShowIndex()] += g_show;
    ptr[feature_value_accessor_.common_feature_value.ClickIndex()] += g_click;
    ptr[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] +=
        optimizer_config.nonclk_coeff * (g_show - g_click) +
        optimizer_config.clk_coeff * g_click;

    update_value_work(
        optimizer_config,
        1,
        ptr + feature_value_accessor_.common_feature_value.EmbedWIndex(),
        ptr + feature_value_accessor_.common_feature_value.EmbedG2SumIndex(),
        grad + feature_value_accessor_.common_push_value.EmbedGIndex(),
        g_show);
    int mf_dim = static_cast<int>(
        ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]);
    if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) {
      if (optimizer_config.mf_create_thresholds <=
          optimizer_config.nonclk_coeff *
                  (ptr[feature_value_accessor_.common_feature_value
                           .ShowIndex()] -
                   ptr[feature_value_accessor_.common_feature_value
                           .ClickIndex()]) +
              optimizer_config.clk_coeff *
                  ptr[feature_value_accessor_.common_feature_value
                          .ClickIndex()]) {
        ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] =
            feature_value_accessor_.common_feature_value.MFSize(mf_dim) /
            sizeof(float);

        // The seed comes from clock64() with the global thread index as the
        // subsequence, so initial embedx values differ from run to run.
        int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
        curandState state;
        curand_init(clock64(), tid_x, 0, &state);
        for (int i = 0; i < mf_dim; ++i) {
          ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] =
              (curand_uniform(&state)) * optimizer_config.mf_initial_range;
        }
        // The feature value's EmbedxG2SumIndex() is the absolute start of the
        // embedx optimizer state in `ptr`. EmbedxBeta{1,2}PowIndex() (2 and 3)
        // are the offsets inside it that update_value_work() reads as beta pows.
        float* mf_sgd =
            ptr + feature_value_accessor_.common_feature_value.EmbedxG2SumIndex();
        mf_sgd[EmbedxBeta1PowIndex()] = optimizer_config.beta1_decay_rate;
        mf_sgd[EmbedxBeta2PowIndex()] = optimizer_config.beta2_decay_rate;
      }
    } else {
      update_value_work(
          optimizer_config,
          mf_dim,
          ptr + feature_value_accessor_.common_feature_value.EmbedxWIndex(),
          ptr + feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(),
          grad + feature_value_accessor_.common_push_value.EmbedxGIndex(),
          g_show);
    }
  }

  __host__ __device__ size_t Dim() { return EmbedDim() + EmbedxDim(); }
  __host__ __device__ size_t EmbedDim() { return 4; }
  __host__ __device__ size_t EmbedxDim() { return 4; }
  __host__ __device__ size_t GSumIndex() { return 0; }
  __host__ __device__ size_t G2SumIndex() { return GSumIndex() + 1; }
  __host__ __device__ size_t Beta1PowIndex() { return G2SumIndex() + 1; }
  __host__ __device__ size_t Beta2PowIndex() { return Beta1PowIndex() + 1; }
  __host__ __device__ size_t EmbedxGSumIndex() { return 0; }
  __host__ __device__ size_t EmbedxG2SumIndex() {
    return EmbedxGSumIndex() + 1;
  }
  __host__ __device__ size_t EmbedxBeta1PowIndex() {
    return EmbedxG2SumIndex() + 1;
  }
  __host__ __device__ size_t EmbedxBeta2PowIndex() {
    return EmbedxBeta1PowIndex() + 1;
  }
};

#endif
}  // end namespace framework
}  // end namespace paddle
#endif