/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_HETERPS
#include <thread>

#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_CUDA)

template <typename value_type>
struct ReplaceOp {
  __host__ __device__ value_type operator()(value_type new_value,
                                            value_type old_value) {
    return new_value;
  }
};

template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              size_t len,
                              uint64_t* global_num) {
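  // Count how many keys this kernel inserts: each block accumulates its own
  // count in shared memory, then thread 0 adds the block total to *global_num.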
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  __shared__ uint64_t local_num;

  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.x == 0) {
    local_num = 0;
  }
  __syncthreads();

  if (i < len) {
    kv.first = keys[i];
    kv.second = 1;  // fake value
    auto it = table->insert(kv, op, &local_num);
    assert(it != table->end() && "error: insert fails: table is full");
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    atomicAdd(global_num, local_num);
  }
}

template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const typename Table::mapped_type* const vals,
                              size_t len) {
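  // One thread per key: insert the (keys[i], vals[i]) pair, overwriting any
  // existing value via ReplaceOp.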
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    kv.first = keys[i];
    kv.second = vals[i];
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              size_t len,
                              char* pool,
                              size_t feature_value_size,
                              int start_index) {
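  // Values live in an externally managed memory pool: the table stores a
  // pointer to slot (start_index + i), each slot being feature_value_size
  // bytes wide.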
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i < len) {
    kv.first = keys[i];
    uint64_t offset = uint64_t(start_index + i) * feature_value_size;
    kv.second = (Table::mapped_type)(pool + offset);
    auto it = table->insert(kv, op);
    if (it == table->end()) {
      printf("error: insert fails: table is full");
    }
  }
}

template <typename Table>
__global__ void search_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              typename Table::mapped_type* const vals,
                              size_t len) {
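  // One thread per key: copy the mapped value into vals[i] if the key exists;
  // vals[i] is left untouched on a miss.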
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      vals[i] = it->second;
    }
  }
}

template <typename Table, typename GPUAccessor>
__global__ void dy_mf_search_kernel_fill(
    Table* table,
    const typename Table::key_type* const keys,
    char* vals,
    size_t len,
    size_t pull_feature_value_size,
    GPUAccessor gpu_accessor) {
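  // Variable-length ("dy_mf") lookup used in infer mode: hits are copied out
  // through the accessor, misses are zero-filled instead of reported.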
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      uint64_t offset = i * pull_feature_value_size;
      float* cur = reinterpret_cast<float*>(vals + offset);
      float* input = it->second;
      gpu_accessor.PullValueFill(cur, input);
    } else {
      float* cur = reinterpret_cast<float*>(&vals[i * pull_feature_value_size]);
      gpu_accessor.PullZeroValue(cur);
    }
  }
}

template <typename Table, typename GPUAccessor>
__global__ void dy_mf_search_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    char* vals,
                                    size_t len,
                                    size_t pull_feature_value_size,
                                    GPUAccessor gpu_accessor) {
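  // Variable-length lookup for the non-infer path: hits are copied out through
  // the accessor, misses only trigger a warning.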
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      uint64_t offset = i * pull_feature_value_size;
      float* cur = reinterpret_cast<float*>(vals + offset);
      float* input = it->second;
      gpu_accessor.PullValueFill(cur, input);
    } else {
      printf("warning: pull miss key: %lu", keys[i]);
    }
  }
}

template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
                              const OptimizerConfig& optimizer_config,
                              const typename Table::key_type* const keys,
                              const GradType* const grads,
                              size_t len,
                              Sgd sgd) {
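  // Apply the optimizer update to the stored value of every key found in the
  // table; keys that are not present are silently skipped.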
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      sgd.update_value(optimizer_config, (it.getter())->second, grads[i]);
    }
  }
}

template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
                                    const OptimizerConfig& optimizer_config,
                                    const typename Table::key_type* const keys,
                                    const char* const grads,
                                    size_t len,
                                    Sgd sgd,
                                    size_t grad_value_size) {
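  // Variable-length gradient update: each gradient occupies grad_value_size
  // bytes in d_grads; missing keys only trigger a warning.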
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      const float* cur =
          reinterpret_cast<const float*>(grads + i * grad_value_size);
      sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, cur);
    } else {
      printf("warning: push miss key: %lu", keys[i]);
    }
  }
}

template <typename Table>
__global__ void get_keys_kernel(Table* table,
                                typename Table::key_type* d_out,
                                uint64_t* global_cursor,
                                uint64_t unused_key) {
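  // Compact all live keys into d_out: each block gathers its keys into shared
  // memory, reserves a contiguous output range by atomicAdd on *global_cursor,
  // then writes the gathered keys into that range.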
  extern __shared__ typename Table::key_type local_key[];
  __shared__ uint64_t local_num;
  __shared__ uint64_t global_num;

  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.x == 0) {
    local_num = 0;
  }
  __syncthreads();
  uint64_t len = table->size();
  if (idx < len) {
    typename Table::value_type val = *(table->data() + idx);
    if (val.first != unused_key) {
      uint64_t dst = atomicAdd(&local_num, 1);
      local_key[dst] = val.first;
    }
  }

  __syncthreads();

  if (threadIdx.x == 0) {
    global_num = atomicAdd(global_cursor, local_num);
  }
  __syncthreads();

  if (threadIdx.x < local_num) {
    d_out[global_num + threadIdx.x] = local_key[threadIdx.x];
  }
}

template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  container_ = new TableContainer<KeyType, ValType>(capacity);
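  // Keep a device-side copy of the optimizer configuration so update kernels
  // can read it directly.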
  CUDA_RT_CALL(cudaMalloc(&device_optimizer_config_, sizeof(OptimizerConfig)));
  CUDA_RT_CALL(cudaMemcpy(device_optimizer_config_,
                          &host_optimizer_config_,
                          sizeof(OptimizerConfig),
                          cudaMemcpyHostToDevice));
  rwlock_.reset(new phi::RWLock);
}

template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  delete container_;
  cudaFree(device_optimizer_config_);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_sparse_sgd(
    const OptimizerConfig& optimizer_config) {
  host_optimizer_config_.set_sparse_sgd(optimizer_config);
  cudaMemcpy(device_optimizer_config_,
             &host_optimizer_config_,
             sizeof(OptimizerConfig),
             cudaMemcpyHostToDevice);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_embedx_sgd(
    const OptimizerConfig& optimizer_config) {
  host_optimizer_config_.set_embedx_sgd(optimizer_config);
  cudaMemcpy(device_optimizer_config_,
             &host_optimizer_config_,
             sizeof(OptimizerConfig),
             cudaMemcpyHostToDevice);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  container_->print();
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys,
                                      ValType* d_vals,
                                      size_t len,
                                      StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_vals, len);
}

template <typename KeyType, typename ValType>
template <typename StreamType, typename GPUAccessor>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys,
                                      char* d_vals,
                                      size_t len,
                                      StreamType stream,
                                      const GPUAccessor& fv_accessor) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  // In infer mode, missing keys are zero-filled instead of reported as misses.
  if (infer_mode_) {
    dy_mf_search_kernel_fill<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
        container_, d_keys, d_vals, len, pull_feature_value_size_, fv_accessor);
  } else {
    dy_mf_search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
        container_, d_keys, d_vals, len, pull_feature_value_size_, fv_accessor);
  }
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         size_t len,
                                         uint64_t* global_num,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, len, global_num);
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals,
                                         size_t len,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_vals, len);
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get_keys(KeyType* d_out,
                                           uint64_t* global_cursor,
                                           StreamType stream) {
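  // Dump every key currently stored in the table into d_out; shared memory is
  // sized for one key per thread.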
  size_t len = container_->size();
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  KeyType unused_key = std::numeric_limits<KeyType>::max();
  size_t shared_mem_size = sizeof(KeyType) * BLOCK_SIZE_;
  get_keys_kernel<<<grid_size, BLOCK_SIZE_, shared_mem_size, stream>>>(
      container_, d_out, global_cursor, unused_key);
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         size_t len,
                                         char* pool,
                                         size_t feature_value_size,
                                         size_t start_index,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  if (pool == NULL) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, len, pool, feature_value_size, start_index);
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
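  // Prefetch the table's storage to the host (cudaCpuDeviceId) on the given
  // stream so it can be read on the CPU.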
  container_->prefetch(cudaCpuDeviceId, stream);
}

template <typename KeyType, typename ValType>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const float* d_grads,
                                         size_t len,
                                         Sgd sgd,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, *device_optimizer_config_, d_keys, d_grads, len, sgd);
}

template <typename KeyType, typename ValType>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads,
                                         size_t len,
                                         Sgd sgd,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_,
      *device_optimizer_config_,
      d_keys,
      d_grads,
      len,
      sgd,
      push_grad_value_size_);
}

template class HashTable<uint64_t, float>;
template class HashTable<uint64_t, float*>;
template class HashTable<int64_t, int>;
template class HashTable<uint64_t, int>;
template class HashTable<uint64_t, uint64_t>;
template class HashTable<uint64_t, uint64_t*>;
template class HashTable<uint64_t, int64_t>;
template class HashTable<uint64_t, int64_t*>;
template class HashTable<int64_t, int64_t>;
template class HashTable<int64_t, uint64_t>;
template class HashTable<int64_t, unsigned int>;

template void HashTable<uint64_t, float>::get<cudaStream_t>(
    const uint64_t* d_keys, float* d_vals, size_t len, cudaStream_t stream);

template void
HashTable<uint64_t, float*>::get<cudaStream_t, CommonFeatureValueAccessor>(
    const uint64_t* d_keys,
    char* d_vals,
    size_t len,
    cudaStream_t stream,
    const CommonFeatureValueAccessor& fv_accessor);

template void HashTable<int64_t, int>::get<cudaStream_t>(const int64_t* d_keys,
                                                         int* d_vals,
                                                         size_t len,
                                                         cudaStream_t stream);

template void HashTable<uint64_t, int>::get<cudaStream_t>(
    const uint64_t* d_keys, int* d_vals, size_t len, cudaStream_t stream);
template void HashTable<uint64_t, uint64_t>::get<cudaStream_t>(
    const uint64_t* d_keys, uint64_t* d_vals, size_t len, cudaStream_t stream);
template void HashTable<uint64_t, int64_t>::get<cudaStream_t>(
    const uint64_t* d_keys, int64_t* d_vals, size_t len, cudaStream_t stream);
template void HashTable<int64_t, uint64_t>::get<cudaStream_t>(
    const int64_t* d_keys, uint64_t* d_vals, size_t len, cudaStream_t stream);
template void HashTable<int64_t, int64_t>::get<cudaStream_t>(
    const int64_t* d_keys, int64_t* d_vals, size_t len, cudaStream_t stream);
template void HashTable<int64_t, unsigned int>::get<cudaStream_t>(
    const int64_t* d_keys,
    unsigned int* d_vals,
    size_t len,
    cudaStream_t stream);
// template void
// HashTable<uint64_t, paddle::framework::FeatureValue>::get<cudaStream_t>(
//    const uint64_t* d_keys, char* d_vals, size_t len, cudaStream_t
//    stream);

template void HashTable<uint64_t, float>::insert<cudaStream_t>(
    const uint64_t* d_keys,
    const float* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<uint64_t, float*>::insert<cudaStream_t>(
    const uint64_t* d_keys,
    size_t len,
    char* pool,
    size_t feature_value_size,
    size_t start_index,
    cudaStream_t stream);

template void HashTable<int64_t, int>::insert<cudaStream_t>(
    const int64_t* d_keys, const int* d_vals, size_t len, cudaStream_t stream);
template void HashTable<int64_t, int64_t>::insert<cudaStream_t>(
    const int64_t* d_keys,
    const int64_t* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<uint64_t, int>::insert<cudaStream_t>(
    const uint64_t* d_keys, const int* d_vals, size_t len, cudaStream_t stream);

template void HashTable<uint64_t, int64_t>::insert<cudaStream_t>(
    const uint64_t* d_keys,
    const int64_t* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<int64_t, uint64_t>::insert<cudaStream_t>(
    const int64_t* d_keys,
    const uint64_t* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<int64_t, unsigned int>::insert<cudaStream_t>(
    const int64_t* d_keys,
    const unsigned int* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<uint64_t, uint64_t>::get_keys<cudaStream_t>(
    uint64_t* d_out, uint64_t* global_cursor, cudaStream_t stream);

template void HashTable<uint64_t, uint64_t>::insert<cudaStream_t>(
    const uint64_t* d_keys,
    uint64_t len,
    uint64_t* global_num,
    cudaStream_t stream);

template void HashTable<uint64_t, uint64_t>::insert<cudaStream_t>(
    const uint64_t* d_keys,
    const uint64_t* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<uint64_t, float*>::dump_to_cpu<cudaStream_t>(
    int devid, cudaStream_t stream);

template void HashTable<uint64_t, float*>::update<
    SparseAdagradOptimizer<CommonFeatureValueAccessor>,
    cudaStream_t>(const uint64_t* d_keys,
                  const char* d_grads,
                  size_t len,
                  SparseAdagradOptimizer<CommonFeatureValueAccessor> sgd,
                  cudaStream_t stream);
template void HashTable<uint64_t, float*>::update<
    SparseAdamOptimizer<CommonFeatureValueAccessor>,
    cudaStream_t>(const uint64_t* d_keys,
                  const char* d_grads,
                  size_t len,
                  SparseAdamOptimizer<CommonFeatureValueAccessor> sgd,
                  cudaStream_t stream);
template void HashTable<uint64_t, float*>::update<
    SparseAdamSharedOptimizer<CommonFeatureValueAccessor>,
    cudaStream_t>(const uint64_t* d_keys,
                  const char* d_grads,
                  size_t len,
                  SparseAdamSharedOptimizer<CommonFeatureValueAccessor> sgd,
                  cudaStream_t stream);

// template void HashTable<uint64_t,
// paddle::framework::FeatureValue>::update<
//    Optimizer<paddle::framework::FeatureValue,
//              paddle::framework::FeaturePushValue>,
//    cudaStream_t>(const uint64_t* d_keys, const char* d_grads, size_t
//    len,
//                  Optimizer<paddle::framework::FeatureValue,
//                            paddle::framework::FeaturePushValue>
//                      sgd,
//                  cudaStream_t stream);

#endif
}  // end namespace framework
}  // end namespace paddle
#endif