/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
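// GPU kernels and host-side wrappers used by PSGPUWrapper in the HeterPS
// (heterogeneous parameter-server) path: gathering sparse feature keys from
// per-slot buffers and packing per-slot gradients for the push phase.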

#ifdef PADDLE_WITH_HETERPS
#include <algorithm>
#include <ctime>
#include <memory>
#include <numeric>

#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

namespace paddle {
namespace framework {

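// Kernel-launch helpers: GET_BLOCK computes the 1-D grid size for N work
// items, and CUDA_BLOCK expands to the <grid, block, shared-mem> triple used
// inside the <<<...>>> launch configuration.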
const int CUDA_NUM_THREADS = platform::PADDLE_CUDA_NUM_THREADS;
#define GET_BLOCK(N) (((N) + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS)
#define CUDA_BLOCK(N) GET_BLOCK(N), CUDA_NUM_THREADS, 0

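// Gathers keys from per-slot buffers into one contiguous array. `len` holds
// the cumulative (prefix-sum) number of keys per slot; each thread binary
// searches it to find its slot x and in-slot offset y, then copies the key.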
__global__ void CopyKeysKernel(uint64_t** src_keys,
                               uint64_t* dest_total_keys,
                               const int64_t* len,
                               int slot_num,
                               int total_len) {
  CUDA_KERNEL_LOOP(i, total_len) {
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    int x = low;
    int y = i - (x ? len[x - 1] : 0);
    dest_total_keys[i] = src_keys[x][y];
  }
}

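// Packs per-slot gradient rows into FeaturePushValue records for push.
// Each source row is laid out as [show, clk, lr_g, mf_g...]; show/clk are
// copied as-is, while the lr and embedx gradients are negated and scaled by
// the batch size `bs`. `len` again holds cumulative per-slot key counts and
// `slot_vector` maps the slot index back to the original slot id.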
__global__ void PushCopy(FeaturePushValue* dest,
                         float** src,
                         int64_t* len,
                         int hidden,
                         int slot_num,
                         int total_len,
                         int bs,
                         int* slot_vector) {
  CUDA_KERNEL_LOOP(i, total_len) {
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    int x = low;
    int y = i - (x ? len[low - 1] : 0);
    (dest + i)->slot = slot_vector[x];
    (dest + i)->show = *(src[x] + y * hidden);
    (dest + i)->clk = *(src[x] + y * hidden + 1);
    (dest + i)->lr_g = *(src[x] + y * hidden + 2) * -1. * bs;
    for (int j = 0; j < hidden - 3; j++) {
      (dest + i)->mf_g[j] = *(src[x] + y * hidden + 3 + j) * -1. * bs;
    }
  }
}

PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; }

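// Host wrapper: launches CopyKeysKernel on the device stream bound to
// `place` and blocks until the flattened key buffer is ready.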
void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
                            uint64_t** origin_keys,
                            uint64_t* total_keys,
                            const int64_t* gpu_len,
                            int slot_num,
                            int total_len) {
  auto stream = dynamic_cast<phi::GPUContext*>(
                    platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  CopyKeysKernel<<<(total_len + 1024 - 1) / 1024, 1024, 0, stream>>>(
      origin_keys, total_keys, gpu_len, slot_num, total_len);
  cudaStreamSynchronize(stream);
}

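// Variant of CopyKeysKernel used by the CopyKeys overload below: `slot_lens`
// holds the exclusive prefix sum of per-slot key counts (starting at 0), and
// the slot index of every key is additionally recorded in `key2slots`.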
__global__ void CopyKeysKernel2(const int total_len,
                                uint64_t** src_keys,
                                uint64_t* dest_total_keys,
                                const int slot_num,
                                const int64_t* slot_lens,
                                int* key2slots) {
  CUDA_KERNEL_LOOP(i, total_len) {
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < slot_lens[mid + 1]) {
        high = mid;
      } else {
        low = mid + 1;
      }
    }
    key2slots[i] = low;
    int y = i - slot_lens[low];
    dest_total_keys[i] = src_keys[low][y];
  }
}

void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
                            uint64_t** origin_keys,
                            uint64_t* total_keys,
                            const int64_t* slot_lens,
                            int slot_num,
                            int total_len,
                            int* key2slot) {
  auto stream = dynamic_cast<phi::GPUContext*>(
                    platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  CopyKeysKernel2<<<CUDA_BLOCK(total_len), stream>>>(
      total_len, origin_keys, total_keys, slot_num, slot_lens, key2slot);
  cudaStreamSynchronize(stream);
}

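// SetSparseSGD and SetEmbedxSGD forward the sparse-table (lr) and
// embedding-table (embedx/mf) optimizer hyperparameters into
// optimizer_config_ for use by the GPU optimizer.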
void PSGPUWrapper::SetSparseSGD(float nonclk_coeff,
                                float clk_coeff,
                                float min_bound,
                                float max_bound,
                                float learning_rate,
                                float initial_g2sum,
                                float initial_range,
                                float beta1_decay_rate,
                                float beta2_decay_rate,
                                float ada_epsilon) {
  optimizer_config_.set_sparse_sgd(nonclk_coeff,
                                   clk_coeff,
                                   min_bound,
                                   max_bound,
                                   learning_rate,
                                   initial_g2sum,
                                   initial_range,
                                   beta1_decay_rate,
                                   beta2_decay_rate,
                                   ada_epsilon);
}

void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds,
                                float mf_learning_rate,
                                float mf_initial_g2sum,
                                float mf_initial_range,
                                float mf_min_bound,
                                float mf_max_bound,
                                float mf_beta1_decay_rate,
                                float mf_beta2_decay_rate,
                                float mf_ada_epsilon,
                                float nodeid_slot,
                                float feature_learning_rate) {
  optimizer_config_.set_embedx_sgd(mf_create_thresholds,
                                   mf_learning_rate,
                                   mf_initial_g2sum,
                                   mf_initial_range,
                                   mf_min_bound,
                                   mf_max_bound,
                                   mf_beta1_decay_rate,
                                   mf_beta2_decay_rate,
                                   mf_ada_epsilon,
                                   nodeid_slot,
                                   feature_learning_rate);
}

}  // end namespace framework
}  // end namespace paddle
#endif