// ps_gpu_wrapper.kps
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_HETERPS
#include <xpu/runtime.h>  // NOLINT
#include <algorithm>
#include <ctime>
#include <memory>
#include <numeric>
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "xpu/kernel/cluster_header.h"  // NOLINT
#include "xpu/kernel/debug.h"           // NOLINT
#include "xpu/kernel/math.h"            // NOLINT
#include "xpu/kernel/simd.h"

namespace paddle {
namespace framework {

// Unpacks pulled feature values from the flat `src` buffer into the per-slot
// float buffers referenced by `dest`.
//
// `len` is read as an inclusive prefix sum of slot lengths: slot 0 owns the
// first len[0] entries of `src`, and slot i > 0 owns entries
// [len[i - 1], len[i]). Slots are handed out to cores starting at
// ncores * cluster_id() + core_id() and advancing by ncores * cluster_num().
//
// Every entry becomes `hidden` consecutive floats in dest[slot]:
//   [0] show, [1] clk, [2] lr, [3 .. hidden) mf[1 .. hidden - 2).
// An entry whose key is 0 is written as all zeros; an entry whose mf_size is
// 0 keeps show/clk/lr but has its mf part zeroed.
//
// `total_len` is not used by this kernel.
__global__ void PullCopy(float** dest, const FeatureValue* src,
                         const long long* len, int hidden, int slot_num,
                         int total_len, unsigned long long** keys) {
  int core = core_id();
  int cores_per_cluster = core_num();
  if (core >= cores_per_cluster) {
    return;
  }
  int first_slot = cores_per_cluster * cluster_id() + core;
  int slot_stride = cores_per_cluster * cluster_num();

  // Stage the prefix-sum table in local memory before using it.
  __local__ int64_t prefix_len[slot_num];
  GM2LM(len, prefix_len, slot_num * sizeof(int64_t));

  for (int slot = first_slot; slot < slot_num; slot += slot_stride) {
    // Number of entries in this slot and where they start inside `src`.
    int slot_len = 0;
    int src_offset = 0;
    if (slot) {
      slot_len = prefix_len[slot] - prefix_len[slot - 1];
      src_offset = prefix_len[slot - 1];
    } else {
      slot_len = prefix_len[0];
    }

    // max core local memory = 8KB
    // slot's max memory size = slot_len * sizeof(FeatureValue)
    int chunk_cap = min(roundup_div(1024 * 8, sizeof(FeatureValue)), slot_len);
    __local__ FeatureValue vals_lm[chunk_cap];
    __local__ float out_lm[chunk_cap * hidden];
    __local__ uint64_t keys_lm[chunk_cap];

    // Walk the slot in chunks of at most chunk_cap entries.
    for (int done = 0; done < slot_len; done += chunk_cap) {
      int chunk = min(chunk_cap, slot_len - done);
      GM2LM(src + src_offset + done, vals_lm, chunk * sizeof(FeatureValue));
      GM2LM(keys[slot] + done, keys_lm, chunk * sizeof(uint64_t));

      for (int j = 0; j < chunk; j++) {
        int base = j * hidden;
        bool has_key = keys_lm[j] != 0;

        // Leading three floats: show, clk, lr (zeroed for key 0).
        if (has_key) {
          out_lm[base] = vals_lm[j].show;
          out_lm[base + 1] = vals_lm[j].clk;
          out_lm[base + 2] = vals_lm[j].lr;
        } else {
          out_lm[base] = 0;
          out_lm[base + 1] = 0;
          out_lm[base + 2] = 0;
        }

        // Remaining hidden - 3 floats come from mf[1..]; they are zeroed
        // when the key is 0 or the entry has mf_size == 0.
        if (has_key && vals_lm[j].mf_size != 0) {
          for (int m = 3; m < hidden; m++) {
            out_lm[base + m] = vals_lm[j].mf[m - 2];
          }
        } else {
          for (int m = 3; m < hidden; m++) {
            out_lm[base + m] = 0;
          }
        }
      }

      LM2GM(out_lm, dest[slot] + done * hidden,
            chunk * hidden * sizeof(float));
    }
  }
}

// Concatenates the per-slot key arrays in `src_keys` into the single flat
// array `dest_total_keys`.
//
// `len` is read as an inclusive prefix sum of slot lengths, so slot i is
// written to dest_total_keys starting at len[i - 1] (0 for slot 0) and
// contributes len[i] - len[i - 1] keys. Slots are handed out to cores
// starting at ncores * cluster_id() + core_id() and advancing by
// ncores * cluster_num().
//
// `total_len` is not used by this kernel.
__global__ void CopyKeysKernel(unsigned long long** src_keys,
                               unsigned long long* dest_total_keys,
                               const long long* len, int slot_num,
                               int total_len) {
  int core = core_id();
  int cores_per_cluster = core_num();
  if (core >= cores_per_cluster) {
    return;
  }
  int first_slot = cores_per_cluster * cluster_id() + core;
  int slot_stride = cores_per_cluster * cluster_num();

  // Stage the prefix-sum table in local memory before using it.
  __local__ int64_t prefix_len[slot_num];
  GM2LM(len, prefix_len, slot_num * sizeof(int64_t));

  for (int slot = first_slot; slot < slot_num; slot += slot_stride) {
    // Number of keys in this slot and its start offset in the flat output.
    int slot_len = 0;
    int dest_offset = 0;
    if (slot) {
      slot_len = prefix_len[slot] - prefix_len[slot - 1];
      dest_offset = prefix_len[slot - 1];
    } else {
      slot_len = prefix_len[0];
    }

    // max core local memory = 8KB
    int chunk_cap = min(slot_len, 1024);
    __local__ uint64_t keys_lm[chunk_cap];

    // Bounce each chunk of keys through local memory.
    for (int done = 0; done < slot_len; done += chunk_cap) {
      int chunk = min(chunk_cap, slot_len - done);
      GM2LM(src_keys[slot] + done, keys_lm, chunk * sizeof(uint64_t));
      LM2GM(keys_lm, dest_total_keys + dest_offset + done,
            chunk * sizeof(uint64_t));
    }
  }
}

// Packs per-slot gradient buffers (`src`) into the flat FeaturePushValue
// array `dest`.
//
// `len` is read as an inclusive prefix sum of slot lengths, so slot i is
// written to dest starting at len[i - 1] (0 for slot 0) and contributes
// len[i] - len[i - 1] entries. Each entry is read from src[slot] as `hidden`
// consecutive floats:
//   [0] -> show, [1] -> clk, [2] -> lr_g, [3 .. hidden) -> mf_g[0 ..).
// lr_g and mf_g are multiplied by -1 and by `bs`; show and clk are copied
// unchanged. The `slot` field is taken from slot_vector[slot].
//
// Slots are handed out to cores starting at ncores * cluster_id() +
// core_id() and advancing by ncores * cluster_num(). `total_len` is not used
// by this kernel.
__global__ void PushCopy(FeaturePushValue* dest, float** src, long long* len,
                         int hidden, int slot_num, int total_len, int bs,
                         int* slot_vector) {
  int core = core_id();
  int cores_per_cluster = core_num();
  if (core >= cores_per_cluster) {
    return;
  }
  int first_slot = cores_per_cluster * cluster_id() + core;
  int slot_stride = cores_per_cluster * cluster_num();

  // Stage the prefix-sum table and the slot ids in local memory.
  __local__ int64_t prefix_len[slot_num];
  __local__ int slot_ids[slot_num];
  GM2LM(len, prefix_len, slot_num * sizeof(int64_t));
  GM2LM(slot_vector, slot_ids, slot_num * sizeof(int));

  for (int slot = first_slot; slot < slot_num; slot += slot_stride) {
    // Number of entries in this slot and its start offset in `dest`.
    int slot_len = 0;
    int dest_offset = 0;
    if (slot) {
      slot_len = prefix_len[slot] - prefix_len[slot - 1];
      dest_offset = prefix_len[slot - 1];
    } else {
      slot_len = prefix_len[0];
    }

    // max core local memory = 8KB
    // slot's max memory size = slot_len * hidden * 8
    int chunk_cap = min(roundup_div(1024, hidden), slot_len);
    __local__ float grads_lm[chunk_cap * hidden];
    __local__ FeaturePushValue push_lm[chunk_cap];

    // Walk the slot in chunks of at most chunk_cap entries.
    for (int done = 0; done < slot_len; done += chunk_cap) {
      int chunk = min(chunk_cap, slot_len - done);
      GM2LM(src[slot] + done * hidden, grads_lm,
            chunk * hidden * sizeof(float));

      // Convert each row of `hidden` floats into one FeaturePushValue.
      for (int j = 0; j < chunk; j++) {
        int base = j * hidden;
        push_lm[j].slot = slot_ids[slot];
        push_lm[j].show = grads_lm[base];
        push_lm[j].clk = grads_lm[base + 1];
        push_lm[j].lr_g = grads_lm[base + 2] * -1. * bs;
        for (int m = 0; m < hidden - 3; m++) {
          push_lm[j].mf_g[m] = grads_lm[base + 3 + m] * -1. * bs;
        }
      }

      LM2GM(push_lm, dest + dest_offset + done,
            chunk * sizeof(FeaturePushValue));
    }
  }
}

164
PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; }
165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257

// Launches PullCopy on the XPU stream of `place` to unpack the flat
// `total_values_gpu` buffer into the per-slot output buffers in `values`.
//
// place            - device whose XPU stream is used for the launch.
// gpu_keys         - per-slot key arrays, forwarded to the kernel as `keys`.
// values           - per-slot output pointers; the pointer table itself is
//                    copied to the device before the launch.
// total_values_gpu - flat source buffer of FeatureValue entries.
// gpu_len          - slot length table forwarded to the kernel as `len`.
// slot_num         - number of slots.
// hidden_size      - number of floats written per entry.
// total_length     - forwarded to the kernel as `total_len`.
//
// Blocks until the kernel has finished (xpu_wait) and then releases the
// temporary device-side pointer table.
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
                               uint64_t** gpu_keys,
                               const std::vector<float*>& values,
                               const FeatureValue* total_values_gpu,
                               const int64_t* gpu_len, const int slot_num,
                               const int hidden_size,
                               const int64_t total_length) {
  XPUStream stream = nullptr;
  auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
  stream = static_cast<platform::XPUDeviceContext*>(dev_ctx)
               ->x_context()
               ->xpu_stream;

  // Device-side copy of the per-slot output pointer table. xpu_malloc stores
  // the device address directly in gpu_values; the kernel must receive that
  // address, not the address of a host-side local variable.
  float** gpu_values = nullptr;
  xpu_malloc(reinterpret_cast<void**>(&gpu_values),
             values.size() * sizeof(float*));
  xpu_memcpy(gpu_values, values.data(), values.size() * sizeof(float*),
             XPU_HOST_TO_DEVICE);

  unsigned long long** c_keys = (unsigned long long**)gpu_keys;
  const long long* c_len = (const long long*)gpu_len;
  PullCopy<<<2, 64, stream>>>(gpu_values, total_values_gpu, c_len, hidden_size,
                              slot_num, total_length, c_keys);

  xpu_wait(stream);

  // The kernel is done with the pointer table once the stream has drained.
  xpu_free(gpu_values);
}

// Launches CopyKeysKernel on the XPU stream of `place` to gather the
// per-slot key arrays in `origin_keys` into the flat `total_keys` array, and
// blocks until the kernel has finished.
//
// gpu_len, slot_num and total_len are forwarded to the kernel unchanged.
void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
                            uint64_t** origin_keys, uint64_t* total_keys,
                            const int64_t* gpu_len, int slot_num,
                            int total_len) {
  // Resolve the XPU stream that belongs to `place`.
  auto* xpu_ctx = static_cast<platform::XPUDeviceContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  XPUStream stream = xpu_ctx->x_context()->xpu_stream;

  // The kernel is declared with (unsigned) long long pointers, so the
  // fixed-width pointers are reinterpreted at the call site.
  CopyKeysKernel<<<2, 64, stream>>>(
      reinterpret_cast<unsigned long long**>(origin_keys),
      reinterpret_cast<unsigned long long*>(total_keys),
      reinterpret_cast<const long long*>(gpu_len), slot_num, total_len);

  xpu_wait(stream);
}

// Launches PushCopy on the XPU stream of `place` to pack the per-slot
// gradient buffers in `grad_values` into the flat `total_grad_values_gpu`
// array.
//
// place                 - device whose XPU stream is used for the launch.
// grad_values           - per-slot gradient pointers; the pointer table is
//                         copied to the device before the launch.
// total_grad_values_gpu - flat destination array of FeaturePushValue.
// slot_lengths          - per-slot entry counts; converted here to an
//                         inclusive prefix sum before being sent to the
//                         device.
// hidden_size           - number of floats read per entry.
// total_length          - forwarded to the kernel as `total_len`.
// batch_size            - forwarded to the kernel as `bs`.
//
// Slot ids are taken from the member slot_vector_. Blocks until the kernel
// has finished (xpu_wait) and then releases the temporary device buffers.
void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
                               const std::vector<const float*>& grad_values,
                               FeaturePushValue* total_grad_values_gpu,
                               const std::vector<int64_t>& slot_lengths,
                               const int hidden_size,
                               const int64_t total_length,
                               const int batch_size) {
  XPUStream stream = nullptr;
  auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
  stream = static_cast<platform::XPUDeviceContext*>(dev_ctx)
               ->x_context()
               ->xpu_stream;

  // Turn the per-slot lengths into an inclusive prefix sum.
  auto slot_lengths_lod = slot_lengths;
  for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
    slot_lengths_lod[i] += slot_lengths_lod[i - 1];
  }

  // Temporary device buffers. xpu_malloc stores the device address directly
  // in each pointer; the kernel must receive those addresses, not the
  // address of a host-side local variable.
  float** gpu_values = nullptr;
  int64_t* gpu_len = nullptr;
  int* d_slot_vector = nullptr;

  xpu_malloc(reinterpret_cast<void**>(&gpu_values),
             grad_values.size() * sizeof(float*));
  xpu_malloc(reinterpret_cast<void**>(&gpu_len),
             slot_lengths.size() * sizeof(int64_t));
  xpu_malloc(reinterpret_cast<void**>(&d_slot_vector),
             slot_lengths_lod.size() * sizeof(int));

  xpu_memcpy(gpu_values, grad_values.data(),
             grad_values.size() * sizeof(float*), XPU_HOST_TO_DEVICE);
  xpu_memcpy(gpu_len, slot_lengths_lod.data(),
             slot_lengths.size() * sizeof(int64_t), XPU_HOST_TO_DEVICE);
  // NOTE(review): copies slot_lengths_lod.size() ints out of slot_vector_;
  // assumes slot_vector_ holds at least that many entries - confirm.
  xpu_memcpy(d_slot_vector, slot_vector_.data(),
             slot_lengths_lod.size() * sizeof(int), XPU_HOST_TO_DEVICE);

  long long* c_len = (long long*)gpu_len;
  PushCopy<<<2, 64, stream>>>(total_grad_values_gpu, gpu_values, c_len,
                              hidden_size, slot_lengths.size(), total_length,
                              batch_size, d_slot_vector);
  xpu_wait(stream);

  // The kernel is done with the staging buffers once the stream has drained.
  xpu_free(gpu_values);
  xpu_free(gpu_len);
  xpu_free(d_slot_vector);
}

void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff,
                                float min_bound, float max_bound,
                                float learning_rate, float initial_g2sum,
                                float initial_range) {
258 259 260 261 262 263 264 265 266
  OptimizerConfig optimizer_config;
  optimizer_config.nonclk_coeff = nonclk_coeff;
  optimizer_config.clk_coeff = clk_coeff;
  optimizer_config.min_bound = min_bound;
  optimizer_config.max_bound = max_bound;
  optimizer_config.learning_rate = learning_rate;
  optimizer_config.initial_g2sum = initial_g2sum;
  optimizer_config.initial_range = initial_range;
  HeterPs_->set_sparse_sgd(optimizer_config);
267 268 269 270 271 272
}

void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds,
                                float mf_learning_rate, float mf_initial_g2sum,
                                float mf_initial_range, float mf_min_bound,
                                float mf_max_bound) {
273 274 275 276 277 278 279 280
  OptimizerConfig optimizer_config;
  optimizer_config.mf_create_thresholds = mf_create_thresholds;
  optimizer_config.mf_learning_rate = mf_learning_rate;
  optimizer_config.mf_initial_g2sum = mf_initial_g2sum;
  optimizer_config.mf_initial_range = mf_initial_range;
  optimizer_config.mf_min_bound = mf_min_bound;
  optimizer_config.mf_max_bound = mf_max_bound;
  HeterPs_->set_embedx_sgd(optimizer_config);
281 282 283 284 285
}

}  // end namespace framework
}  // end namespace paddle
#endif