feature_value.h 44.9 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

T
Thunderbrook 已提交
17
#ifdef PADDLE_WITH_HETERPS
T
Thunderbrook 已提交
18 19

#include <iostream>
D
danleifeng 已提交
20 21 22 23 24 25 26 27 28 29
#include <sstream>
#include <unordered_map>

#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h"
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#endif
T
Thunderbrook 已提交
30

P
pangengzheng 已提交
31 32 33 34 35
#ifdef PADDLE_WITH_PSLIB
#include "downpour_accessor.h"  // NOLINT
#include "pslib.h"              // NOLINT
#endif

T
Thunderbrook 已提交
36 37 38 39 40
namespace paddle {
namespace framework {
// Embedding (mf) dimension constant; presumably the default mf dimension —
// it is not referenced elsewhere in this section.
#define MF_DIM 8

// Key type used to identify a feature.
using FeatureKey = uint64_t;

// Rounds LEN up to the next multiple of ALIGNVAL, computed in uint64_t.
// The mask trick is only correct when ALIGNVAL is a power of two.
#define TYPEALIGN(ALIGNVAL, LEN) \
  (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))

// Sizes of the optimizer-state ("sgd") regions for each optimizer:
//   adagrad:      embed_sgd_dim=1, embedx_sgd_dim=1,     embedx_dim=n
//   adam (std):   embed_sgd_dim=4, embedx_sgd_dim=n*2+2, embedx_dim=n
//   adam shared:  embed_sgd_dim=4, embedx_sgd_dim=4,     embedx_dim=n
//
// CommonFeatureValueAccessor describes three flat float-array layouts:
//   * CommonFeatureValue - the full per-key value (weights + optimizer state),
//   * CommonPullValue    - the value handed back by a pull,
//   * CommonPushValue    - the gradient record of a push.
// It also provides copy helpers between these layouts and, in PSCORE / PSLIB
// builds, between GPU values and CPU table values.
class CommonFeatureValueAccessor {
 public:
  struct CommonFeatureValue {
    /*
      Layout (one float per slot unless noted):
      uint64_t cpu_ptr;                  // occupies 2 floats
      float delta_score;
      float show;
      float click;
      float embed_w;
      std::vector<float> embed_g2sum;    // embed_sgd_dim floats
      float slot;
      float mf_dim
      float mf_size
      std::vector<float> embedx_g2sum;   // embedx_sgd_dim floats
      std::vector<float> embedx_w;       // embedx_dim floats
    */

    // Float count for the configured embedx_dim. The constant 9 counts
    // cpu_ptr (2 floats) + delta_score, show, click, embed_w, slot, mf_dim,
    // mf_size.
    __host__ __device__ int Dim() {
      return 9 + embed_sgd_dim + embedx_sgd_dim + embedx_dim;
    }
    // Byte width of one slot; every slot is stored as a float.
    __host__ __device__ int DimSize(size_t dim, int embedx_dim) {
      return sizeof(float);
    }
    // Byte size for the configured embedx_dim, rounded up to a multiple of 8
    // (presumably to keep the uint64_t cpu_ptr aligned — confirm).
    __host__ __device__ size_t Size() {
      return TYPEALIGN(8, Dim() * sizeof(float));
    }

    // Configured dimensions. Beware the naming: EmbedDim() is embed_sgd_dim,
    // EmbedXDim() is embedx_sgd_dim and EmbedWDim() is embedx_dim.
    __host__ __device__ int EmbedDim() const { return embed_sgd_dim; }
    __host__ __device__ int EmbedXDim() const { return embedx_sgd_dim; }
    __host__ __device__ int EmbedWDim() const { return embedx_dim; }

    // Float offsets of each field inside a CommonFeatureValue.
    __host__ __device__ int CpuPtrIndex() const { return 0; }  // uint64 cpu_ptr
    __host__ __device__ int DeltaScoreIndex() const {
      return CpuPtrIndex() + 2;
    }
    __host__ __device__ int ShowIndex() const { return DeltaScoreIndex() + 1; }
    __host__ __device__ int ClickIndex() const { return ShowIndex() + 1; }
    __host__ __device__ int EmbedWIndex() const { return ClickIndex() + 1; }
    __host__ __device__ int EmbedG2SumIndex() const {
      return EmbedWIndex() + 1;
    }
    __host__ __device__ int SlotIndex() const {
      return EmbedG2SumIndex() + embed_sgd_dim;
    }
    __host__ __device__ int MfDimIndex() const { return SlotIndex() + 1; }
    __host__ __device__ int MfSizeIndex() const {
      return MfDimIndex() + 1;
    }  // actual mf size (ex. 0)
    __host__ __device__ int EmbedxG2SumIndex() const {
      return MfSizeIndex() + 1;
    }
    // Uses the configured embedx_sgd_dim, so for adam it is only correct when
    // the value's mf_dim equals embedx_dim; EmbedxWOffsetIndex(val) computes
    // the offset from the value's own mf_dim.
    __host__ __device__ int EmbedxWIndex() const {
      return EmbedxG2SumIndex() + embedx_sgd_dim;
    }

    // Float count of the embedx optimizer state for a value with `mf_dim`
    // dims: adam (3) -> mf_dim*2+2, shared_adam (4) -> 4, otherwise 1.
    __host__ __device__ int EmbedxSgdDim(int mf_dim) const {
      if (optimizer_type_ == 3) {  // adam
        return mf_dim * 2 + 2;
      }
      if (optimizer_type_ == 4) {  // shared_adam
        return 4;
      }
      return 1;
    }

    // Total float count of a value whose mf part has `mf_dim` dims.
    __host__ __device__ int Dim(int mf_dim) {
      return 9 + embed_sgd_dim + EmbedxSgdDim(mf_dim) + mf_dim;
    }

    // Total byte size of such a value, rounded up to a multiple of 8.
    __host__ __device__ size_t Size(int mf_dim) {
      return TYPEALIGN(8, Dim(mf_dim) * sizeof(float));  // cpu_ptr:2float
    }

    // Byte size of the mf part (embedx_g2sum + embedx_w) for `mf_dim` dims.
    __host__ __device__ size_t MFSize(int mf_dim) {
      return (EmbedxSgdDim(mf_dim) + mf_dim) * sizeof(float);
    }

    // Always 0; presumably the offset of embedx_g2sum within the mf part.
    __host__ __device__ int EmbedxG2SumOffsetIndex() { return 0; }
    // Absolute float offset of embedx_w inside `val`, derived from that
    // value's own mf_dim. Returns 0 when the value has no mf part.
    __host__ __device__ int EmbedxWOffsetIndex(float* val) {
      if (static_cast<int>(MfSize(val)) > 0) {  // has mf
        return EmbedxG2SumIndex() +
               EmbedxSgdDim(static_cast<int>(MfDim(val)));
      }
      return 0;  // no mf
    }

    // Field accessors on a raw CommonFeatureValue buffer.
    __host__ __device__ uint64_t CpuPtr(float* val) {
      return *(reinterpret_cast<uint64_t*>(val));
    }
    __host__ __device__ float& DeltaScore(float* val) {
      return val[DeltaScoreIndex()];
    }
    __host__ __device__ float& Show(float* val) { return val[ShowIndex()]; }
    __host__ __device__ float& Click(float* val) { return val[ClickIndex()]; }
    __host__ __device__ float& Slot(float* val) { return val[SlotIndex()]; }
    __host__ __device__ float& MfDim(float* val) { return val[MfDimIndex()]; }
    __host__ __device__ float& MfSize(float* val) { return val[MfSizeIndex()]; }
    __host__ __device__ float& EmbedW(float* val) { return val[EmbedWIndex()]; }
    __host__ __device__ float& EmbedG2Sum(float* val) {
      return val[EmbedG2SumIndex()];
    }
    __host__ __device__ float& EmbedxG2Sum(float* val) {
      return val[EmbedxG2SumIndex()];
    }
    __host__ __device__ float& EmbedxW(float* val) {
      return val[EmbedxWIndex()];
    }

    int embed_sgd_dim;
    int embedx_dim;
    int embedx_sgd_dim;
    int optimizer_type_;
  };

  struct CommonPullValue {
    /*
      float show;
      float click;
      float embed_w;
      float mf_size
      std::vector<float> embedx_w;
    */
    __host__ __device__ int Dim(int embedx_dim) { return 4 + embedx_dim; }
    __host__ __device__ int DimSize(size_t dim) { return sizeof(float); }
    __host__ __device__ int Size(int embedx_dim) {
      return Dim(embedx_dim) * sizeof(float);
    }
    __host__ __device__ int ShowIndex() { return 0; }
    __host__ __device__ int ClickIndex() { return 1; }
    __host__ __device__ int EmbedWIndex() { return 2; }
    __host__ __device__ int MfSizeIndex() {
      return 3;
    }  // actual mf size (ex. 0)
    __host__ __device__ int EmbedxWIndex() { return 4; }
  };

  struct CommonPushValue {
    /*
       float slot;
       float show;
       float click;
       float mf_dim;
       float embed_g;
       std::vector<float> embedx_g;
    */

    __host__ __device__ int Dim(int embedx_dim) const { return 5 + embedx_dim; }

    __host__ __device__ int DimSize(int dim, int embedx_dim) const {
      return sizeof(float);
    }
    __host__ __device__ int Size(int embedx_dim) const {
      return Dim(embedx_dim) * sizeof(float);
    }
    __host__ __device__ int SlotIndex() const { return 0; }
    __host__ __device__ int ShowIndex() const {
      return CommonPushValue::SlotIndex() + 1;
    }
    __host__ __device__ int ClickIndex() const {
      return CommonPushValue::ShowIndex() + 1;
    }
    __host__ __device__ int MfDimIndex() const {
      return CommonPushValue::ClickIndex() + 1;
    }
    __host__ __device__ int EmbedGIndex() const {
      return CommonPushValue::MfDimIndex() + 1;
    }
    __host__ __device__ int EmbedxGIndex() const {
      return CommonPushValue::EmbedGIndex() + 1;
    }
    __host__ __device__ float& Slot(float* val) const {
      return val[CommonPushValue::SlotIndex()];
    }
    __host__ __device__ float& Show(float* val) const {
      return val[CommonPushValue::ShowIndex()];
    }
    __host__ __device__ float& Click(float* val) const {
      return val[CommonPushValue::ClickIndex()];
    }
    __host__ __device__ float& MfDim(float* val) const {
      return val[CommonPushValue::MfDimIndex()];
    }
    __host__ __device__ float& EmbedG(float* val) const {
      return val[CommonPushValue::EmbedGIndex()];
    }
    __host__ __device__ float* EmbedxG(float* val) const {
      return val + CommonPushValue::EmbedxGIndex();
    }
  };

  __host__ __device__ CommonFeatureValueAccessor() {}
  __host__ __device__ ~CommonFeatureValueAccessor() {}

  // Derives the layout dimensions from _config. Missing keys default to
  // optimizer_type=1 and embedx_dim=8. Always returns 0.
  __host__ int Initialize() {
    auto opt_it = _config.find("optimizer_type");
    int optimizer_type = (opt_it == _config.end()) ? 1 : opt_it->second;
    auto dim_it = _config.find("embedx_dim");
    int sparse_embedx_dim = (dim_it == _config.end()) ? 8 : dim_it->second;
    if (optimizer_type == 3) {  // adam
      common_feature_value.embed_sgd_dim = 4;
      common_feature_value.embedx_sgd_dim = sparse_embedx_dim * 2 + 2;
    } else if (optimizer_type == 4) {  // shared_adam
      common_feature_value.embed_sgd_dim = 4;
      common_feature_value.embedx_sgd_dim = 4;
    } else {
      common_feature_value.embed_sgd_dim = 1;
      common_feature_value.embedx_sgd_dim = 1;
    }
    common_feature_value.optimizer_type_ = optimizer_type;
    common_feature_value.embedx_dim = sparse_embedx_dim;

    return 0;
  }

  // Stores `config` and recomputes the layout. Always returns 0.
  __host__ int Configure(const std::unordered_map<std::string, float>& config) {
    _config = config;
    Initialize();
    return 0;
  }

#ifdef PADDLE_WITH_PSCORE
  // Build stage: copy the CPU table value `cpu` into gpu_val, store the CPU
  // pointer in gpu_val and record mf_dim on both sides.
  __host__ void BuildFill(
      float* gpu_val,
      void* cpu,
      paddle::distributed::ValueAccessor* cpu_table_accessor,
      int mf_dim) {
    paddle::distributed::CtrDymfAccessor* cpu_accessor =
        dynamic_cast<paddle::distributed::CtrDymfAccessor*>(cpu_table_accessor);
    paddle::distributed::FixedFeatureValue* cpu_ptr =
        (paddle::distributed::FixedFeatureValue*)(cpu);
    float* cpu_val = cpu_ptr->data();
    size_t cpu_dim = cpu_ptr->size();

    gpu_val[common_feature_value.DeltaScoreIndex()] =
        cpu_val[cpu_accessor->common_feature_value.DeltaScoreIndex()];
    gpu_val[common_feature_value.ShowIndex()] =
        cpu_val[cpu_accessor->common_feature_value.ShowIndex()];
    gpu_val[common_feature_value.ClickIndex()] =
        cpu_val[cpu_accessor->common_feature_value.ClickIndex()];
    gpu_val[common_feature_value.SlotIndex()] =
        cpu_val[cpu_accessor->common_feature_value.SlotIndex()];
    gpu_val[common_feature_value.EmbedWIndex()] =
        cpu_val[cpu_accessor->common_feature_value.EmbedWIndex()];
    for (int i = 0; i < common_feature_value.EmbedDim(); i++) {
      gpu_val[common_feature_value.EmbedG2SumIndex() + i] =
          cpu_val[cpu_accessor->common_feature_value.EmbedG2SumIndex() + i];
    }
    *(reinterpret_cast<uint64_t*>(
        gpu_val + common_feature_value.CpuPtrIndex())) = (uint64_t)(cpu);
    cpu_val[cpu_accessor->common_feature_value.MfDimIndex()] =
        static_cast<float>(mf_dim);
    gpu_val[common_feature_value.MfDimIndex()] = mf_dim;
    if (cpu_dim > cpu_accessor->GetAccessorInfo().dim -
                      cpu_accessor->GetAccessorInfo().mf_size / sizeof(float)) {
      // The CPU value already carries an mf part: copy it over.
      gpu_val[common_feature_value.MfSizeIndex()] =
          common_feature_value.MFSize(mf_dim) / sizeof(float);

      for (size_t x = 0;
           x < (common_feature_value.MFSize(mf_dim) / sizeof(float));
           x++) {
        gpu_val[common_feature_value.EmbedxG2SumIndex() + x] =
            cpu_val[cpu_accessor->common_feature_value.EmbedxG2SumIndex() + x];
      }
    } else {
      // No mf part on the CPU side: zero everything after the header.
      gpu_val[common_feature_value.MfSizeIndex()] = 0;
      for (size_t x = common_feature_value.EmbedxG2SumIndex();
           x < (common_feature_value.Size(mf_dim) / sizeof(float));
           x++) {
        gpu_val[x] = 0;
      }
    }
  }
#endif

#ifdef PADDLE_WITH_PSLIB
  // Build stage: copy the PSLIB CPU value into gpu_val.
  template <typename ShowClickType>
  __host__ void BuildFill(float* gpu_val,
                          void* _cpu_val,
                          ::paddle::ps::ValueAccessor* _cpu_accessor,
                          int mf_dim) {
    auto* cpu_accessor =
        dynamic_cast<::paddle::ps::DownpourCtrDymfTplAccessor<ShowClickType>*>(
            _cpu_accessor);
    auto* cpu_val =
        reinterpret_cast<::paddle::ps::DownpourFixedFeatureValue*>(_cpu_val);
    float* ptr_val = cpu_val->data();
    size_t cpu_dim = cpu_val->size();

    gpu_val[common_feature_value.DeltaScoreIndex()] =
        ptr_val[cpu_accessor->get_delta_score_index()];
    gpu_val[common_feature_value.ShowIndex()] = cpu_accessor->get_show(ptr_val);
    gpu_val[common_feature_value.ClickIndex()] =
        cpu_accessor->get_click(ptr_val);

    gpu_val[common_feature_value.SlotIndex()] =
        ptr_val[cpu_accessor->get_slot_index()];

    // lr
    gpu_val[common_feature_value.EmbedWIndex()] =
        ptr_val[cpu_accessor->get_embed_w_index()];

    // cpu_ptr
    *(reinterpret_cast<uint64_t*>(
        gpu_val + common_feature_value.CpuPtrIndex())) = (uint64_t)(cpu_val);

    // lr_g2sum
    // for dymf && adagrad, embed_dim = 1
    for (int i = 0; i < common_feature_value.EmbedDim(); i++) {
      gpu_val[common_feature_value.EmbedG2SumIndex() + i] =
          ptr_val[cpu_accessor->get_embed_g2sum_index() + i];
    }

    ptr_val[cpu_accessor->get_mf_dim_index()] = static_cast<float>(mf_dim);
    gpu_val[common_feature_value.MfDimIndex()] = static_cast<float>(mf_dim);
    constexpr int n = 2 * (sizeof(ShowClickType) / sizeof(float) - 1);

    // Float count of the mf part (embedx_g2sum + embedx_w) for this mf_dim.
    const int mf_floats =
        static_cast<int>(common_feature_value.MFSize(mf_dim) / sizeof(float));
    if (cpu_dim > 8 + n) {
      gpu_val[common_feature_value.MfSizeIndex()] =
          common_feature_value.MFSize(mf_dim) / sizeof(float);

      for (int x = 0; x < mf_floats; x++) {
        gpu_val[common_feature_value.EmbedxG2SumIndex() + x] =
            ptr_val[8 + n + x];
      }
    } else {
      gpu_val[common_feature_value.MfSizeIndex()] = 0;
      // Zero exactly the mf region sized for this mf_dim. The configured
      // embedx_sgd_dim differs from the per-value one for adam when mf_dim
      // != embedx_dim, so it must not be used here.
      for (int i = 0; i < mf_floats; i++) {
        gpu_val[common_feature_value.EmbedxG2SumIndex() + i] = 0;
      }
    }
  }
#endif

#ifdef PADDLE_WITH_PSCORE
  // dump_to_cpu stage: copy gpu_val back into the CPU value it points to,
  // growing the CPU value first if the GPU side gained an mf part.
  __host__ void DumpFill(float* gpu_val,
                         paddle::distributed::ValueAccessor* cpu_table_accessor,
                         int mf_dim) {
    paddle::distributed::CtrDymfAccessor* cpu_accessor =
        dynamic_cast<paddle::distributed::CtrDymfAccessor*>(cpu_table_accessor);

    auto* downpour_value =
        (paddle::distributed::FixedFeatureValue*)(*(reinterpret_cast<uint64_t*>(
            gpu_val + common_feature_value.CpuPtrIndex())));
    size_t downpour_value_size = downpour_value->size();
    if (gpu_val[common_feature_value.MfSizeIndex()] > 0 &&
        downpour_value_size ==
            (cpu_accessor->GetAccessorInfo().dim -
             static_cast<int>(cpu_accessor->GetAccessorInfo().mf_size /
                              sizeof(float)))) {  // cpu_accessor
      downpour_value->resize(cpu_accessor->common_feature_value.Dim(mf_dim));
    }
    float* cpu_val = downpour_value->data();
    cpu_val[cpu_accessor->common_feature_value.DeltaScoreIndex()] =
        gpu_val[common_feature_value.DeltaScoreIndex()];
    cpu_val[cpu_accessor->common_feature_value.ShowIndex()] =
        gpu_val[common_feature_value.ShowIndex()];
    cpu_val[cpu_accessor->common_feature_value.ClickIndex()] =
        gpu_val[common_feature_value.ClickIndex()];
    cpu_val[cpu_accessor->common_feature_value.EmbedWIndex()] =
        gpu_val[common_feature_value.EmbedWIndex()];
    cpu_val[cpu_accessor->common_feature_value.SlotIndex()] =
        gpu_val[common_feature_value.SlotIndex()];

    for (int i = 0; i < common_feature_value.EmbedDim(); i++) {
      cpu_val[cpu_accessor->common_feature_value.EmbedG2SumIndex() + i] =
          gpu_val[common_feature_value.EmbedG2SumIndex() + i];
    }

    if (gpu_val[common_feature_value.MfSizeIndex()] > 0) {
      for (size_t x = 0;
           x < (common_feature_value.MFSize(mf_dim) / sizeof(float));
           x++) {
        cpu_val[cpu_accessor->common_feature_value.EmbedxG2SumIndex() + x] =
            gpu_val[common_feature_value.EmbedxG2SumIndex() + x];
      }
    }
  }
#endif

#ifdef PADDLE_WITH_PSLIB
  // dump_to_cpu stage: copy gpu_val back into the PSLIB CPU value.
  // gpu_val is first copied to host memory, so here it lives in host memory,
  // not in HBM.
  template <typename ShowClickType>
  __host__ void DumpFill(float* gpu_val,
                         ::paddle::ps::ValueAccessor* _cpu_accessor,
                         int mf_dim) {
    auto* cpu_accessor =
        dynamic_cast<::paddle::ps::DownpourCtrDymfTplAccessor<ShowClickType>*>(
            _cpu_accessor);
    uint64_t cpu_addr = *reinterpret_cast<uint64_t*>(
        gpu_val + common_feature_value.CpuPtrIndex());
    auto* downpour_value = (::paddle::ps::DownpourFixedFeatureValue*)cpu_addr;
    int downpour_value_size = downpour_value->size();
    constexpr int n = 2 * (sizeof(ShowClickType) / sizeof(float) - 1);
    if (static_cast<int>(gpu_val[common_feature_value.MfSizeIndex()]) > 0 &&
        downpour_value_size == 8 + n) {
      int mf_size =
          common_feature_value.MFSize(mf_dim) /
          sizeof(
              float);  // mf_size = gpu_val[common_feature_value.MfSizeIndex()];
      downpour_value->resize(downpour_value_size + mf_size);
    }
    float* cpu_val = downpour_value->data();

    cpu_val[cpu_accessor->get_delta_score_index()] =
        gpu_val[common_feature_value.DeltaScoreIndex()];
    *reinterpret_cast<ShowClickType*>(cpu_val +
                                      cpu_accessor->get_show_index()) =
        (ShowClickType)gpu_val[common_feature_value.ShowIndex()];
    *reinterpret_cast<ShowClickType*>(cpu_val +
                                      cpu_accessor->get_click_index()) =
        (ShowClickType)gpu_val[common_feature_value.ClickIndex()];
    cpu_val[cpu_accessor->get_embed_w_index()] =
        gpu_val[common_feature_value.EmbedWIndex()];
    cpu_val[cpu_accessor->get_slot_index()] =
        gpu_val[common_feature_value.SlotIndex()];

    // for dymf && adagrad, embed_dim = 1
    for (int i = 0; i < common_feature_value.EmbedDim(); i++) {
      cpu_val[cpu_accessor->get_embed_g2sum_index() + i] =
          gpu_val[common_feature_value.EmbedG2SumIndex() + i];
    }

    if (static_cast<int>(gpu_val[common_feature_value.MfSizeIndex()]) > 0) {
      for (int x = 0; x < static_cast<int>(common_feature_value.MFSize(mf_dim) /
                                           sizeof(float));
           x++) {
        cpu_val[x + 8 + n] =
            gpu_val[common_feature_value.EmbedxG2SumIndex() + x];
      }
    }
  }
#endif

  // Used by dy_mf_fill_dvals_kernel: copy a full CommonFeatureValue from
  // src_val to dest_val, overriding dest's mf_dim with `mf_dim`. Copies up to
  // Size(mf_dim) bytes.
  __host__ __device__ void FeatureValueFill(float* dest_val,
                                            float* src_val,
                                            int mf_dim) {
    *(reinterpret_cast<uint64_t*>(dest_val +
                                  common_feature_value.CpuPtrIndex())) =
        *(reinterpret_cast<uint64_t*>(src_val +
                                      common_feature_value.CpuPtrIndex()));
    dest_val[common_feature_value.DeltaScoreIndex()] =
        src_val[common_feature_value.DeltaScoreIndex()];
    dest_val[common_feature_value.ShowIndex()] =
        src_val[common_feature_value.ShowIndex()];
    dest_val[common_feature_value.ClickIndex()] =
        src_val[common_feature_value.ClickIndex()];
    dest_val[common_feature_value.EmbedWIndex()] =
        src_val[common_feature_value.EmbedWIndex()];
    for (int i = 0; i < common_feature_value.EmbedDim(); i++) {
      dest_val[common_feature_value.EmbedG2SumIndex() + i] =
          src_val[common_feature_value.EmbedG2SumIndex() + i];
    }
    dest_val[common_feature_value.SlotIndex()] =
        src_val[common_feature_value.SlotIndex()];
    dest_val[common_feature_value.MfDimIndex()] = mf_dim;
    dest_val[common_feature_value.MfSizeIndex()] =
        src_val[common_feature_value.MfSizeIndex()];

    for (size_t x = common_feature_value.EmbedxG2SumIndex();
         x < (common_feature_value.Size(mf_dim) / sizeof(float));
         x++) {
      dest_val[x] = src_val[x];
    }
  }

  // Used by dy_mf_fill_dvals_kernel / dy_mf_search_kernel: convert a
  // CommonFeatureValue (src_val) into a CommonPullValue (dest_val).
  __host__ __device__ void PullValueFill(float* dest_val, float* src_val) {
    dest_val[common_pull_value.ShowIndex()] =
        src_val[common_feature_value.ShowIndex()];
    dest_val[common_pull_value.ClickIndex()] =
        src_val[common_feature_value.ClickIndex()];
    dest_val[common_pull_value.EmbedWIndex()] =
        src_val[common_feature_value.EmbedWIndex()];

    int mf_size = static_cast<int>(src_val[common_feature_value.MfSizeIndex()]);
    if (mf_size == 0) {
      dest_val[common_pull_value.MfSizeIndex()] = 0;
      return;
    }
    // set pull value real dim size
    int mf_dim = static_cast<int>(src_val[common_feature_value.MfDimIndex()]);
    dest_val[common_pull_value.MfSizeIndex()] = mf_dim;
    // check
    if (mf_dim > mf_size) {
      printf("mf_dim[%d] <= mf_size[%d]\n", mf_dim, mf_size);
      return;
    }

    int embedx_off = common_pull_value.EmbedxWIndex();
    // The embedx_w offset depends on this value's own mf_dim (the adam
    // optimizer state has mf_dim*2+2 floats), so use the per-value offset
    // rather than the configured EmbedxWIndex().
    int value_off = common_feature_value.EmbedxWOffsetIndex(src_val);
    for (int k = 0; k < mf_dim; ++k) {
      dest_val[embedx_off + k] = src_val[value_off + k];
    }
  }

  // Writes an all-zero pull value (no mf part) for inference.
  __host__ __device__ void PullZeroValue(float* dest_val) {
    dest_val[common_pull_value.ShowIndex()] = 0.0;
    dest_val[common_pull_value.ClickIndex()] = 0.0;
    dest_val[common_pull_value.EmbedWIndex()] = 0.0;
    dest_val[common_pull_value.MfSizeIndex()] = 0;
  }

  // Used by dy_mf_fill_shard_grads_kernel / update_one: copy a push value,
  // including mf_dim embedx gradients, from src_val to dest_val.
  __host__ __device__ void PushValueFill(float* dest_val,
                                         const float* src_val) {
    dest_val[common_push_value.SlotIndex()] =
        src_val[common_push_value.SlotIndex()];
    dest_val[common_push_value.ShowIndex()] =
        src_val[common_push_value.ShowIndex()];
    dest_val[common_push_value.ClickIndex()] =
        src_val[common_push_value.ClickIndex()];
    dest_val[common_push_value.MfDimIndex()] =
        src_val[common_push_value.MfDimIndex()];
    dest_val[common_push_value.EmbedGIndex()] =
        src_val[common_push_value.EmbedGIndex()];

    // Convert the float-encoded dimension once instead of on every compare.
    const int mf_dim = static_cast<int>(src_val[common_push_value.MfDimIndex()]);
    for (int x = 0; x < mf_dim; x++) {
      dest_val[common_push_value.EmbedxGIndex() + x] =
          src_val[common_push_value.EmbedxGIndex() + x];
    }
  }

  // Used by update_basic: copy only the fixed (non-embedx) push fields.
  __host__ __device__ void PushValueFillBasic(float* dest_val,
                                              const float* src_val) const {
    dest_val[common_push_value.SlotIndex()] =
        src_val[common_push_value.SlotIndex()];
    dest_val[common_push_value.ShowIndex()] =
        src_val[common_push_value.ShowIndex()];
    dest_val[common_push_value.ClickIndex()] =
        src_val[common_push_value.ClickIndex()];
    dest_val[common_push_value.MfDimIndex()] =
        src_val[common_push_value.MfDimIndex()];
    dest_val[common_push_value.EmbedGIndex()] =
        src_val[common_push_value.EmbedGIndex()];
  }

  // Used by merge_one: accumulate src_val's push value into dest_val. The
  // number of embedx gradients is taken from dest_val's mf_dim.
  __host__ __device__ void MergePushValue(float* dest_val,
                                          const float* src_val) {
    dest_val[common_push_value.ShowIndex()] +=
        src_val[common_push_value.ShowIndex()];
    dest_val[common_push_value.ClickIndex()] +=
        src_val[common_push_value.ClickIndex()];
    dest_val[common_push_value.EmbedGIndex()] +=
        src_val[common_push_value.EmbedGIndex()];
    const int mf_dim =
        static_cast<int>(dest_val[common_push_value.MfDimIndex()]);
    for (int j = 0; j < mf_dim; j++) {
      dest_val[common_push_value.EmbedxGIndex() + j] +=
          src_val[common_push_value.EmbedxGIndex() + j];
    }
  }

  // Used by merge_basic: accumulate only show, click and embed_g.
  __host__ __device__ void MergePushValueBasic(float* dest_val,
                                               const float* src_val) const {
    dest_val[common_push_value.ShowIndex()] +=
        src_val[common_push_value.ShowIndex()];
    dest_val[common_push_value.ClickIndex()] +=
        src_val[common_push_value.ClickIndex()];
    dest_val[common_push_value.EmbedGIndex()] +=
        src_val[common_push_value.EmbedGIndex()];
  }

  // Used by PullCopy: copy a pulled value (src_val, CommonPullValue layout)
  // into the caller's output (dest_val). dest_val has no mf_size slot, so its
  // mf_dim embedx weights start right after embed_w. Key 0 (padding) and
  // values without an mf part produce zeros.
  __host__ __device__ void Select(float* dest_val,
                                  float* src_val,
                                  uint64_t* key,
                                  int mf_dim) {
    if (*key == 0) {
      *(dest_val + common_pull_value.ShowIndex()) = 0;
      *(dest_val + common_pull_value.ClickIndex()) = 0;
      *(dest_val + common_pull_value.EmbedWIndex()) = 0;
    } else {
      // src_val is in CommonPullValue layout (as produced by PullValueFill),
      // so it must be indexed with common_pull_value, consistent with the
      // embedx read below.
      *(dest_val + common_pull_value.ShowIndex()) =
          src_val[common_pull_value.ShowIndex()];
      *(dest_val + common_pull_value.ClickIndex()) =
          src_val[common_pull_value.ClickIndex()];
      *(dest_val + common_pull_value.EmbedWIndex()) =
          src_val[common_pull_value.EmbedWIndex()];
    }
    int mf_size = static_cast<int>(src_val[common_pull_value.MfSizeIndex()]);
    if (mf_size == 0 || *key == 0) {
      for (int j = 0; j < mf_dim; j++) {
        *(dest_val + 3 + j) = 0;
      }
    } else {
      for (int j = 0; j < mf_dim; j++) {
        // In common_pull_value, MfSizeIndex comes before EmbedxWIndex; the
        // output has no such slot, so the destination offset is 3 rather
        // than common_pull_value.EmbedxWIndex().
        *(dest_val + 3 + j) = src_val[common_pull_value.EmbedxWIndex() + j];
      }
    }
  }

  // Host-only debug dump of a CommonFeatureValue `v` holding `param_size`
  // floats. Uses std::stringstream, so it cannot run in device code. The mf
  // part is printed only up to min(param_size, Dim(mf_dim)) floats.
  __host__ std::string ParseToString(const float* v, int param_size) {
    /*
        uint64_t cpu_ptr; // 2float
        float delta_score;
        float show;
        float click;
        float embed_w;
        std::vector<float> embed_g2sum;
        float slot;
        float mf_dim
        float mf_size
        std::vector<float> embedx_g2sum;
        std::vector<float> embedx_w;
    */
    // The field accessors take non-const pointers but only read here.
    float* val = const_cast<float*>(v);
    std::stringstream os;
    os << "cpuptr: " << common_feature_value.CpuPtr(val)
       << " delta_score: " << v[common_feature_value.DeltaScoreIndex()]
       << " show: " << v[common_feature_value.ShowIndex()]
       << " click: " << v[common_feature_value.ClickIndex()]
       << " embed_w:" << v[common_feature_value.EmbedWIndex()]
       << " embed_g2sum:";
    for (int i = common_feature_value.EmbedG2SumIndex();
         i < common_feature_value.SlotIndex();
         i++) {
      os << " " << v[i];
    }
    int mf_dim = static_cast<int>(common_feature_value.MfDim(val));
    os << " slot: " << common_feature_value.Slot(val) << " mf_dim: " << mf_dim
       << " mf_size: " << common_feature_value.MfSize(val) << " mf: ";
    // Never read past the caller-provided buffer length.
    int end = common_feature_value.Dim(mf_dim);
    if (param_size < end) {
      end = param_size;
    }
    for (int i = common_feature_value.EmbedxG2SumIndex(); i < end; ++i) {
      os << " " << v[i];
    }
    return os.str();
  }

 public:
  std::unordered_map<std::string, float> _config;
  CommonFeatureValue common_feature_value;
  CommonPushValue common_push_value;
  CommonPullValue common_pull_value;
};

// Fixed-layout feature value followed in memory by a variable-length `mf`
// tail (zero-length array). An instance is only meaningful when it sits in a
// buffer that reserves room for that tail. The helpers below treat the tail
// as holding mf_dim + 1 floats.
struct FeatureValue {
  float delta_score;
  float show;
  float clk;
  int slot;
  float lr;
  float lr_g2sum;
  int mf_size;
  int mf_dim;
  uint64_t cpu_ptr;
  float mf[0];  // variable-length tail; storage lives past sizeof(*this)

  // Writes a one-line debug representation of `val` to `out`.
  // delta_score and lr_g2sum are not printed. Prints mf_dim + 1 tail floats.
  friend std::ostream& operator<<(std::ostream& out, const FeatureValue& val) {
    out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot
        << " lr: " << val.lr << " mf_dim: " << val.mf_dim
        << " cpuptr: " << val.cpu_ptr << " mf_size: " << val.mf_size << " mf:";
    for (int i = 0; i < val.mf_dim + 1; ++i) {
      out << " " << val.mf[i];
    }
    return out;
  }

  // Device-side copy of every scalar field plus the first in.mf_dim + 1
  // floats of the `mf` tail. Both objects must have storage for that tail.
  __device__ __forceinline__ void operator=(const FeatureValue& in) {
    delta_score = in.delta_score;
    show = in.show;
    clk = in.clk;
    slot = in.slot;
    lr = in.lr;
    lr_g2sum = in.lr_g2sum;
    mf_size = in.mf_size;
    mf_dim = in.mf_dim;
    cpu_ptr = in.cpu_ptr;
    for (int i = 0; i < mf_dim + 1; i++) {
      mf[i] = in.mf[i];
    }
  }
};

// Fixed-layout push (gradient) record followed in memory by a variable-length
// `mf_g` tail (zero-length array). The tail is treated as holding mf_dim
// floats and only exists when the record sits in a buffer that reserves it.
struct FeaturePushValue {
  float show;
  float clk;
  int slot;
  float lr_g;
  int mf_dim;
  float mf_g[0];  // variable-length tail; storage lives past sizeof(*this)

  // Returns a record whose show/clk/lr_g are the sums of *this and `a`, with
  // slot and mf_dim taken from `a`.
  //
  // The mf_g tail is intentionally NOT summed here. `out` is a by-value local
  // with no storage behind its zero-length tail, so writing out.mf_g[i] would
  // overrun the object. Returning by value copies only
  // sizeof(FeaturePushValue) bytes, so summed mf_g values could never reach
  // the caller anyway. Callers that need mf_g merged must accumulate into a
  // buffer that reserves the tail.
  __device__ __forceinline__ FeaturePushValue
  operator+(const FeaturePushValue& a) const {
    FeaturePushValue out;
    out.slot = a.slot;
    out.mf_dim = a.mf_dim;
    out.show = a.show + show;
    out.clk = a.clk + clk;
    out.lr_g = a.lr_g + lr_g;
    return out;
  }

  // Device-side copy of the scalar fields plus the first in.mf_dim floats of
  // the mf_g tail. Both objects must have storage for that tail.
  __device__ __forceinline__ void operator=(const FeaturePushValue& in) {
    show = in.show;
    clk = in.clk;
    slot = in.slot;
    lr_g = in.lr_g;
    mf_dim = in.mf_dim;
    for (int i = 0; i < mf_dim; i++) {
      mf_g[i] = in.mf_g[i];
    }
  }
};

D
danleifeng 已提交
783 784 785 786
// Type-erased interface over a GPU-side feature-value accessor.
// AccessorWrapper<GPUAccessor> implements it by forwarding to a concrete
// accessor, and GlobalAccessorFactory hands the instance out as a
// VirtualAccessor*.
class VirtualAccessor {
 public:
  // Implementations are created with `new AccessorWrapper<...>` and held
  // through a VirtualAccessor*. Without a virtual destructor, deleting
  // through that base pointer would be undefined behavior.
  virtual ~VirtualAccessor() {}

  // Applies the accessor configuration and returns the implementation's
  // result code (AccessorWrapper forwards the wrapped accessor's return).
  virtual int Configure(std::unordered_map<std::string, float> config) = 0;

  // Sizes of one stored feature value / push value / pull value for an
  // embedding dimension of `mf_dim`, as reported by the accessor's
  // Size(mf_dim). NOTE(review): presumably in bytes — confirm.
  virtual size_t GetFeatureValueSize(int& mf_dim) = 0;  // NOLINT

  virtual size_t GetPushValueSize(int& mf_dim) = 0;  // NOLINT

  virtual size_t GetPullValueSize(int& mf_dim) = 0;  // NOLINT

#ifdef PADDLE_WITH_PSCORE
  // Conversions between a GPU value and a CPU-table value (PSCORE build).
  // NOTE(review): presumably BuildFill fills gpu_val from cpu_val and
  // DumpFill writes gpu_val back — confirm in the accessor implementation.
  virtual void BuildFill(void* gpu_val,
                         void* cpu_val,
                         paddle::distributed::ValueAccessor* cpu_table_accessor,
                         int mf_dim) = 0;

  virtual void DumpFill(float* gpu_val,
                        paddle::distributed::ValueAccessor* cpu_table_accessor,
                        int mf_dim) = 0;
#endif

#ifdef PADDLE_WITH_PSLIB
  // PSLIB variants; `accessor_type` selects the CPU accessor flavour.
  virtual void BuildFill(
      float* gpu_val,
      void* cpu_val,
      paddle::ps::ValueAccessor* cpu_accessor,
      int mf_dim,
      const std::string& accessor_type = "DownpourCtrDymfAccessor") = 0;

  virtual void DumpFill(
      float* gpu_val,
      ::paddle::ps::ValueAccessor* cpu_accessor,
      int mf_dim,
      const std::string& accessor_type = "DownpourCtrDymfAccessor") = 0;
#endif

  // Pull path: copies values from total_values_gpu into per-slot outputs.
  virtual void CopyForPull(const paddle::platform::Place& place,
                           uint64_t** gpu_keys,
                           const std::vector<float*>& values,
                           const float* total_values_gpu,
                           const int64_t* gpu_len,
                           const int slot_num,
                           const int hidden_size,
                           const int64_t total_length,
                           int* gpu_dim,
                           int feature_value_size) = 0;
  // Pull path for deduplicated keys (gpu_restore_idx maps back to the
  // original key positions — presumably; confirm in the implementation).
  virtual void CopyForPull(const paddle::platform::Place& place,
                           const uint64_t* total_keys,
                           float** gpu_values,
                           const float* total_values_gpu,
                           const int64_t* slot_lens,
                           const int* key2slot,
                           const int hidden_size,
                           const int64_t total_length,
                           const int* slot_dims,
                           const uint32_t* gpu_restore_idx,
                           int pull_value_size) = 0;

  // Push path: gathers per-slot gradients into total_grad_values_gpu.
  virtual void CopyForPush(const paddle::platform::Place& place,
                           const std::vector<const float*>& grad_values,
                           float* total_grad_values_gpu,
                           const std::vector<int64_t>& slot_lengths,
                           const uint64_t total_length,
                           const int batch_size,
                           size_t grad_value_size,
                           std::vector<int>& slot_vector,              // NOLINT
                           std::vector<int>& slot_mf_dim_vector) = 0;  // NOLINT

  // Push path for deduplicated keys, restore-index variant.
  virtual void CopyForPush(const paddle::platform::Place& place,
                           const uint64_t* total_keys,
                           float** grad_values,
                           float* total_grad_values_gpu,
                           const int* slots,
                           const int64_t* slot_lens,
                           const int hidden_size,
                           const int64_t total_length,
                           const int64_t dedup_length,
                           const int batch_size,
                           const int* slot_dims,
                           const int* key2slot,
                           const uint32_t* d_restore_idx,
                           const size_t grad_value_size) = 0;

  // Push path for deduplicated keys, sort-index/offset/length variant.
  virtual void CopyForPush(const paddle::platform::Place& place,
                           const uint64_t* total_keys,
                           float** grad_values,
                           float* total_grad_values_gpu,
                           const int* slots,
                           const int64_t* slot_lens,
                           const int hidden_size,
                           const int64_t total_length,
                           const int64_t dedup_length,
                           const int batch_size,
                           const int* slot_dims,
                           const int* key2slot,
                           const uint32_t* gpu_sort_idx,
                           const uint32_t* gpu_sort_offset,
                           const uint32_t* gpu_sort_lens,
                           const size_t grad_value_size) = 0;

  // Formats the value at `v` (param_size floats) for debugging.
  virtual std::string ParseToString(const float* v, int param_size) = 0;
};

// Adapts a concrete GPU accessor type to the type-erased VirtualAccessor
// interface. Every override forwards either to the wrapped `gpu_accessor_`
// or to one of the *Impl members at the bottom of the class. Those *Impl
// members are only declared here; their definitions live out of line.
// Overrides are marked `override` so any signature drift against
// VirtualAccessor (under either PSCORE or PSLIB) fails to compile instead of
// silently creating a new, non-overriding function.
template <typename GPUAccessor>
class AccessorWrapper : public VirtualAccessor {
 public:
  AccessorWrapper() {}
  virtual ~AccessorWrapper() {}
  AccessorWrapper(const AccessorWrapper&) = delete;
  AccessorWrapper& operator=(const AccessorWrapper&) = delete;

  int Configure(std::unordered_map<std::string, float> config) override {
    return gpu_accessor_.Configure(config);
  }

  size_t GetFeatureValueSize(int& mf_dim) override {  // NOLINT
    return gpu_accessor_.common_feature_value.Size(mf_dim);
  }

  size_t GetPushValueSize(int& mf_dim) override {  // NOLINT
    return gpu_accessor_.common_push_value.Size(mf_dim);
  }

  size_t GetPullValueSize(int& mf_dim) override {  // NOLINT
    return gpu_accessor_.common_pull_value.Size(mf_dim);
  }

  // Direct access to the wrapped accessor.
  GPUAccessor* AccessorPtr() { return &gpu_accessor_; }

#ifdef PADDLE_WITH_PSCORE
  void BuildFill(void* gpu_val,
                 void* cpu_val,
                 paddle::distributed::ValueAccessor* cpu_table_accessor,
                 int mf_dim) override {
    gpu_accessor_.BuildFill(
        reinterpret_cast<float*>(gpu_val), cpu_val, cpu_table_accessor, mf_dim);
  }

  void DumpFill(float* gpu_val,
                paddle::distributed::ValueAccessor* cpu_table_accessor,
                int mf_dim) override {
    gpu_accessor_.DumpFill(gpu_val, cpu_table_accessor, mf_dim);
  }
#endif

#ifdef PADDLE_WITH_PSLIB
  // Dispatches on accessor_type: "DownpourCtrDymfAccessor" -> float CPU
  // values, "DownpourCtrDoubleDymfAccessor" -> double CPU values.
  // NOTE(review): any other accessor_type is silently ignored (nothing is
  // filled) — confirm callers never pass another type.
  void BuildFill(
      float* gpu_val,
      void* cpu_val,
      paddle::ps::ValueAccessor* cpu_accessor,
      int mf_dim,
      const std::string& accessor_type = "DownpourCtrDymfAccessor") override {
    if (accessor_type == "DownpourCtrDymfAccessor") {
      gpu_accessor_.template BuildFill<float>(
          gpu_val, cpu_val, cpu_accessor, mf_dim);
    } else if (accessor_type == "DownpourCtrDoubleDymfAccessor") {
      gpu_accessor_.template BuildFill<double>(
          gpu_val, cpu_val, cpu_accessor, mf_dim);
    }
  }

  // Same dispatch (and the same silent no-op for unknown types) as BuildFill.
  void DumpFill(
      float* gpu_val,
      paddle::ps::ValueAccessor* cpu_accessor,
      int mf_dim,
      const std::string& accessor_type = "DownpourCtrDymfAccessor") override {
    if (accessor_type == "DownpourCtrDymfAccessor") {
      gpu_accessor_.template DumpFill<float>(gpu_val, cpu_accessor, mf_dim);
    } else if (accessor_type == "DownpourCtrDoubleDymfAccessor") {
      gpu_accessor_.template DumpFill<double>(gpu_val, cpu_accessor, mf_dim);
    }
  }
#endif

  void CopyForPull(const paddle::platform::Place& place,
                   uint64_t** gpu_keys,
                   const std::vector<float*>& values,
                   const float* total_values_gpu,
                   const int64_t* gpu_len,
                   const int slot_num,
                   const int hidden_size,
                   const int64_t total_length,
                   int* gpu_dim,
                   int feature_value_size) override {
    CopyForPullImpl(place,
                    gpu_keys,
                    values,
                    total_values_gpu,
                    gpu_len,
                    slot_num,
                    hidden_size,
                    total_length,
                    gpu_dim,
                    feature_value_size);
  }

  // Dedup pull variant.
  void CopyForPull(const paddle::platform::Place& place,
                   const uint64_t* total_keys,
                   float** gpu_values,
                   const float* total_values_gpu,
                   const int64_t* slot_lens,
                   const int* key2slot,
                   const int hidden_size,
                   const int64_t total_length,
                   const int* slot_dims,
                   const uint32_t* gpu_restore_idx,
                   int pull_value_size) override {
    CopyForPullDedupImpl(place,
                         total_keys,
                         gpu_values,
                         total_values_gpu,
                         slot_lens,
                         key2slot,
                         hidden_size,
                         total_length,
                         slot_dims,
                         gpu_restore_idx,
                         pull_value_size);
  }

  void CopyForPush(const paddle::platform::Place& place,
                   const std::vector<const float*>& grad_values,
                   float* total_grad_values_gpu,
                   const std::vector<int64_t>& slot_lengths,
                   const uint64_t total_length,
                   const int batch_size,
                   size_t grad_value_size,
                   std::vector<int>& slot_vector,                    // NOLINT
                   std::vector<int>& slot_mf_dim_vector) override {  // NOLINT
    CopyForPushImpl(place,
                    grad_values,
                    total_grad_values_gpu,
                    slot_lengths,
                    total_length,
                    batch_size,
                    grad_value_size,
                    slot_vector,
                    slot_mf_dim_vector);
  }

  // Dedup push variant using a restore index.
  void CopyForPush(const paddle::platform::Place& place,
                   const uint64_t* total_keys,
                   float** grad_values,
                   float* total_grad_values_gpu,
                   const int* slots,
                   const int64_t* slot_lens,
                   const int hidden_size,
                   const int64_t total_length,
                   const int64_t dedup_length,
                   const int batch_size,
                   const int* slot_dims,
                   const int* key2slot,
                   const uint32_t* d_restore_idx,
                   const size_t grad_value_size) override {
    CopyForPushDedupImpl(place,
                         total_keys,
                         grad_values,
                         total_grad_values_gpu,
                         slots,
                         slot_lens,
                         hidden_size,
                         total_length,
                         dedup_length,
                         batch_size,
                         slot_dims,
                         key2slot,
                         d_restore_idx,
                         grad_value_size);
  }

  // Dedup push variant using sort index / offset / length arrays.
  void CopyForPush(const paddle::platform::Place& place,
                   const uint64_t* total_keys,
                   float** grad_values,
                   float* total_grad_values_gpu,
                   const int* slots,
                   const int64_t* slot_lens,
                   const int hidden_size,
                   const int64_t total_length,
                   const int64_t dedup_length,
                   const int batch_size,
                   const int* slot_dims,
                   const int* key2slot,
                   const uint32_t* gpu_sort_idx,
                   const uint32_t* gpu_sort_offset,
                   const uint32_t* gpu_sort_lens,
                   const size_t grad_value_size) override {
    CopyForPushDedupImpl(place,
                         total_keys,
                         grad_values,
                         total_grad_values_gpu,
                         slots,
                         slot_lens,
                         hidden_size,
                         total_length,
                         dedup_length,
                         batch_size,
                         slot_dims,
                         key2slot,
                         gpu_sort_idx,
                         gpu_sort_offset,
                         gpu_sort_lens,
                         grad_value_size);
  }

  // ---- Out-of-line implementations (declared only; defined elsewhere). ----

  void CopyForPullImpl(const paddle::platform::Place& place,
                       uint64_t** gpu_keys,
                       const std::vector<float*>& values,
                       const float* total_values_gpu,
                       const int64_t* gpu_len,
                       const int slot_num,
                       const int hidden_size,
                       const int64_t total_length,
                       int* gpu_dim,
                       int feature_value_size);

  void CopyForPushImpl(const paddle::platform::Place& place,
                       const std::vector<const float*>& grad_values,
                       float* total_grad_values_gpu,
                       const std::vector<int64_t>& slot_lengths,
                       const uint64_t total_length,
                       const int batch_size,
                       size_t grad_value_size,
                       std::vector<int>& slot_vector,          // NOLINT
                       std::vector<int>& slot_mf_dim_vector);  // NOLINT

  void CopyForPullDedupImpl(const paddle::platform::Place& place,
                            const uint64_t* total_keys,
                            float** gpu_values,
                            const float* total_values_gpu,
                            const int64_t* slot_lens,
                            const int* key2slot,
                            const int hidden_size,
                            const int64_t total_length,
                            const int* slot_dims,
                            const uint32_t* gpu_restore_idx,
                            int pull_value_size);

  void CopyForPushDedupImpl(const paddle::platform::Place& place,
                            const uint64_t* total_keys,
                            float** grad_values,
                            float* total_grad_values_gpu,
                            const int* slots,
                            const int64_t* slot_lens,
                            const int hidden_size,
                            const int64_t total_length,
                            const int64_t dedup_length,
                            const int batch_size,
                            const int* slot_dims,
                            const int* key2slot,
                            const uint32_t* d_restore_idx,
                            const size_t grad_value_size);

  void CopyForPushDedupImpl(const paddle::platform::Place& place,
                            const uint64_t* total_keys,
                            float** grad_values,
                            float* total_grad_values_gpu,
                            const int* slots,
                            const int64_t* slot_lens,
                            const int hidden_size,
                            const int64_t total_length,
                            const int64_t dedup_length,
                            const int batch_size,
                            const int* slot_dims,
                            const int* key2slot,
                            const uint32_t* gpu_sort_idx,
                            const uint32_t* gpu_sort_offset,
                            const uint32_t* gpu_sort_lens,
                            const size_t grad_value_size);

  std::string ParseToString(const float* v, int param_size) override {
    return gpu_accessor_.ParseToString(v, param_size);
  }

  // The wrapped concrete accessor (public: callers may reach it directly).
  GPUAccessor gpu_accessor_;
};

D
danleifeng 已提交
1160
// Holds a single, lazily created VirtualAccessor for the whole process.
// Init() builds it on the first call; every later call is a no-op.
class GlobalAccessorFactory {
 public:
  // Returns the process-wide factory (a function-local static).
  static GlobalAccessorFactory& GetInstance() {
    static GlobalAccessorFactory factory;
    return factory;
  }

  // Creates the accessor wrapper once. Only "CtrDymfAccessor" is recognized;
  // any other accessor_type is logged and then falls back to the same
  // AccessorWrapper<CommonFeatureValueAccessor>.
  void Init(std::string accessor_type) {
    // Already initialized: keep the existing wrapper untouched.
    if (accessor_wrapper_ptr_ != nullptr) return;

    const bool is_supported = (accessor_type == "CtrDymfAccessor");
    if (!is_supported) {
      VLOG(0) << "GlobalAccessorFactory Init not support accessor_type:"
              << accessor_type;
    }
    accessor_wrapper_ptr_ = new AccessorWrapper<CommonFeatureValueAccessor>();
  }

  // The wrapper built by Init(), or nullptr if Init() has not run yet.
  VirtualAccessor* GetAccessorWrapper() { return accessor_wrapper_ptr_; }

 private:
  // Allocated with `new` in Init(); this class never frees it.
  VirtualAccessor* accessor_wrapper_ptr_ = nullptr;
};

T
Thunderbrook 已提交
1184 1185
}  // end namespace framework
}  // end namespace paddle
D
danleifeng 已提交
1186

T
Thunderbrook 已提交
1187
#endif