Unverified commit 3a2fb4cf, authored by zmxdream, committed by GitHub

[cherry-pick]XPUPS add support for kunlun2 (#41916)

* [XPUPS]add support for kunlun2 (#40985)


[XPUPS]add support for kunlun2
Co-authored-by: WorgenZhang <frank08081993@gmail.com>

* [XPUPS]fix hashtable_kernel.kps (#41790)

* refactor heter comm kernel

* update. test=develop

* update calc_shard_offset. test=develop

* update xpu kernel. test=develop

* update args of calc_shard_offset

* update. test=develop

* remove customGradMerger

* update. test=develop

* update. test=develop

* fix. test=develop

* update. test=develop

* update. test=develop

* update optimizer kernel

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* add optimizer kernel. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix kunlun not support size_t. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update hashtable. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* template init. test=develop

* hashtable template init. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix hashtable_kernel. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop
Co-authored-by: WorgenZhang <frank08081993@gmail.com>

* [XPUPS]modify xpu_kp.cmake with HETERPS&PSLIB (#41760)

* modify xpu_kp.cmake with HETERPS&PSLIB

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop
Co-authored-by: WorgenZhang <frank08081993@gmail.com>
Parent 8ccdb91b
......@@ -122,6 +122,12 @@ macro(compile_kernel COMPILE_ARGS)
string(REPLACE ";" " " XPU_CXX_DEFINES "${XPU_CXX_DEFINES}" )
separate_arguments(XPU_CXX_DEFINES UNIX_COMMAND "${XPU_CXX_DEFINES}")
set(ABI_VERSION "")
if(WITH_HETERPS AND WITH_PSLIB)
set(ABI_VERSION "-D_GLIBCXX_USE_CXX11_ABI=0")
else()
set(ABI_VERSION "-D_GLIBCXX_USE_CXX11_ABI=1")
endif()
add_custom_command(
OUTPUT
kernel_build/${kernel_name}.bin.o
......@@ -130,7 +136,7 @@ macro(compile_kernel COMPILE_ARGS)
COMMAND
${CMAKE_COMMAND} -E copy ${kernel_path}/${kernel_name}.kps kernel_build/${kernel_name}.xpu
COMMAND
${XPU_CLANG} --sysroot=${CXX_DIR} -std=c++11 -D_GLIBCXX_USE_CXX11_ABI=1 ${OPT_LEVEL} -fno-builtin -mcpu=xpu2 -fPIC ${XPU_CXX_DEFINES} ${XPU_CXX_FLAGS} ${XPU_CXX_INCLUDES}
${XPU_CLANG} --sysroot=${CXX_DIR} -std=c++11 ${ABI_VERSION} ${OPT_LEVEL} -fno-builtin -mcpu=xpu2 -fPIC ${XPU_CXX_DEFINES} ${XPU_CXX_FLAGS} ${XPU_CXX_INCLUDES}
-I. -o kernel_build/${kernel_name}.bin.o.sec kernel_build/${kernel_name}.xpu
--xpu-device-only -c -v
COMMAND
......@@ -153,7 +159,7 @@ macro(compile_kernel COMPILE_ARGS)
COMMAND
${CMAKE_COMMAND} -E copy ${kernel_path}/${kernel_name}.kps kernel_build/${kernel_name}.xpu
COMMAND
${XPU_CLANG} --sysroot=${CXX_DIR} -std=c++11 -D_GLIBCXX_USE_CXX11_ABI=1 ${OPT_LEVEL} -fno-builtin -mcpu=xpu2 -fPIC ${XPU_CXX_DEFINES} ${XPU_CXX_FLAGS} ${XPU_CXX_INCLUDES}
${XPU_CLANG} --sysroot=${CXX_DIR} -std=c++11 ${ABI_VERSION} ${OPT_LEVEL} -fno-builtin -mcpu=xpu2 -fPIC ${XPU_CXX_DEFINES} ${XPU_CXX_FLAGS} ${XPU_CXX_INCLUDES}
-I. -o kernel_build/${kernel_name}.host.o kernel_build/${kernel_name}.xpu
--xpu-host-only -c -v
WORKING_DIRECTORY
......
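Note on the xpu_kp.cmake hunk above: it compiles the kernels with -D_GLIBCXX_USE_CXX11_ABI=0 when both WITH_HETERPS and WITH_PSLIB are set, and with =1 otherwise. The usual motivation for such a switch is linking against prebuilt libraries (here presumably PSLIB) that were built with the pre-C++11 libstdc++ ABI; that motivation is an assumption, not something stated in the patch. A minimal sketch, independent of this patch, of why the two settings cannot be mixed in one link:

```cpp
// Illustrative only, not part of the patch: libstdc++'s dual ABI.
// Build this once with -D_GLIBCXX_USE_CXX11_ABI=1 and once with =0.
#include <iostream>
#include <string>
#include <typeinfo>

int main() {
  // With =1 the mangled name contains "__cxx11"; with =0 it does not.
  // Object files built with different settings therefore fail to link
  // wherever std::string crosses the boundary, which is exactly what the
  // ABI_VERSION switch above avoids.
  std::cout << typeid(std::string).name() << std::endl;
  return 0;
}
```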
......@@ -22,7 +22,7 @@ limitations under the License. */
#include <vector>
#ifdef PADDLE_WITH_PSLIB
#include "common_value.h" // NOLINT
#include "common/common_value.h" // NOLINT
#endif
#ifdef PADDLE_WITH_PSCORE
......
......@@ -7,7 +7,9 @@ IF(WITH_GPU)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
SET(HETERPS_DEPS ${HETERPS_DEPS} ${RPC_DEPS})
endif()
nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS ${HETERPS_DEPS})
nv_library(heter_comm_kernel SRCS heter_comm_kernel.cu feature_value.h DEPS ${HETERPS_DEPS})
nv_library(hashtable_kernel SRCS hashtable_kernel.cu feature_value.h DEPS ${HETERPS_DEPS})
nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h mem_pool.h DEPS ${HETERPS_DEPS} heter_comm_kernel hashtable_kernel)
nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm)
nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm)
if(WITH_PSCORE)
......
......@@ -52,18 +52,18 @@ struct FeaturePushValue {
float lr_g;
float mf_g[MF_DIM];
__device__ __forceinline__ FeaturePushValue
operator+(const FeaturePushValue& a) const {
FeaturePushValue out;
out.slot = a.slot;
out.show = a.show + show;
out.clk = a.clk + clk;
out.lr_g = a.lr_g + lr_g;
for (int i = 0; i < MF_DIM; ++i) {
out.mf_g[i] = a.mf_g[i] + mf_g[i];
}
return out;
}
// __device__ __forceinline__ FeaturePushValue
// operator+(const FeaturePushValue& a) const {
// FeaturePushValue out;
// out.slot = a.slot;
// out.show = a.show + show;
// out.clk = a.clk + clk;
// out.lr_g = a.lr_g + lr_g;
// for (int i = 0; i < MF_DIM; ++i) {
// out.mf_g[i] = a.mf_g[i] + mf_g[i];
// }
// return out;
// }
};
} // end namespace framework
......
......@@ -13,28 +13,38 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_HETERPS
#include <glog/logging.h>
#include <limits>
#include <memory>
#include <vector>
#ifdef PADDLE_WITH_PSLIB
#include "common_value.h" // NOLINT
#endif
#ifdef PADDLE_WITH_PSCORE
#if defined(PADDLE_WITH_PSCORE)
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#endif
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/phi/core/utils/rw_lock.h"
#include "thrust/pair.h"
// #include "cudf/concurrent_unordered_map.cuh.h"
#if defined(PADDLE_WITH_CUDA)
#include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/mem_pool.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "thrust/pair.h"
#elif defined(__xpu__)
#include <xpu/runtime.h>
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/math.h"
#include "xpu/kernel/simd.h"
#endif
namespace paddle {
namespace framework {
#if defined(PADDLE_WITH_CUDA)
template <typename KeyType, typename ValType>
class TableContainer
: public concurrent_unordered_map<KeyType, ValType,
......@@ -45,31 +55,84 @@ class TableContainer
std::numeric_limits<KeyType>::max()>(
capacity, ValType()) {}
};
#elif defined(PADDLE_WITH_XPU_KP)
template <typename KeyType, typename ValType>
class XPUCacheArray {
public:
explicit XPUCacheArray(size_t capacity) : capacity_(capacity), size_(0) {
xpu_malloc(reinterpret_cast<void**>(&keys), capacity_ * sizeof(KeyType));
xpu_malloc(reinterpret_cast<void**>(&vals), capacity_ * sizeof(ValType));
}
virtual ~XPUCacheArray() {
xpu_free(keys);
xpu_free(vals);
}
void print() {}
// ValType* find(const KeyType& key) { return NULL; }
// bool insert(const KeyType& key, const ValType& val) { return true; }
int prefetch(const int dev_id, XPUStream stream = NULL) { return 0; }
size_t size() { return size_; }
private:
long long capacity_;
long long size_;
KeyType* keys;
ValType* vals;
};
#endif
template <typename KeyType, typename ValType>
class HashTable {
public:
HashTable(size_t capacity);
explicit HashTable(size_t capacity);
virtual ~HashTable();
HashTable(const HashTable&) = delete;
HashTable& operator=(const HashTable&) = delete;
template <typename StreamType>
void insert(const KeyType* d_keys, const ValType* d_vals, size_t len,
gpuStream_t stream);
StreamType stream);
template <typename StreamType>
void insert(const KeyType* d_keys, size_t len, char* pool, size_t start_index,
gpuStream_t stream);
StreamType stream);
template <typename StreamType>
void get(const KeyType* d_keys, ValType* d_vals, size_t len,
gpuStream_t stream);
void get(const KeyType* d_keys, char* d_vals, size_t len, gpuStream_t stream);
StreamType stream);
template <typename StreamType>
void get(const KeyType* d_keys, char* d_vals, size_t len, StreamType stream);
void show();
void dump_to_cpu(int devid, cudaStream_t stream);
template <typename GradType, typename Sgd>
template <typename StreamType>
void dump_to_cpu(int devid, StreamType stream);
#if defined(PADDLE_WITH_CUDA)
template <typename GradType, typename Sgd, typename StreamType>
void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
Sgd sgd, gpuStream_t stream);
Sgd sgd, StreamType stream);
template <typename Sgd>
template <typename Sgd, typename StreamType>
void update(const KeyType* d_keys, const char* d_grads, size_t len, Sgd sgd,
gpuStream_t stream);
StreamType stream);
#elif defined(PADDLE_WITH_XPU_KP)
template <typename GradType, typename StreamType>
void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
StreamType stream);
template <typename StreamType>
void update(const KeyType* d_keys, const char* d_grads, size_t len,
StreamType stream);
#endif
int size() { return container_->size(); }
......@@ -84,7 +147,11 @@ class HashTable {
std::unique_ptr<phi::RWLock> rwlock_{nullptr};
private:
#if defined(PADDLE_WITH_CUDA)
TableContainer<KeyType, ValType>* container_;
#elif defined(PADDLE_WITH_XPU_KP)
XPUCacheArray<KeyType, ValType>* container_;
#endif
int BLOCK_SIZE_{256};
float LOAD_FACTOR{0.75f};
size_t capacity_;
......@@ -94,5 +161,4 @@ class HashTable {
};
} // end namespace framework
} // end namespace paddle
#include "hashtable_inl.h"
#endif
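In the hashtable.h changes above, the table methods are now templated on StreamType instead of taking gpuStream_t, the container type is selected per backend (TableContainer for CUDA, XPUCacheArray for XPU-KP), and the trailing #include "hashtable_inl.h" is removed, so the template definitions move into hashtable_kernel.cu/.kps and are exported from there via explicit instantiations. A self-contained sketch of that combination, using a stand-in MiniTable and an int stream type rather than Paddle's real classes:

```cpp
// Illustrative sketch, not Paddle code.
#include <cstddef>
#include <iostream>

template <typename KeyType, typename ValType>
class MiniTable {
 public:
  // As in HashTable above, the stream type is a template parameter so one
  // declaration serves cudaStream_t and XPUStream alike.
  template <typename StreamType>
  void get(const KeyType* keys, ValType* vals, std::size_t len,
           StreamType stream);
};

// Out-of-line definition; in the patch this lives in hashtable_kernel.cu/.kps.
template <typename KeyType, typename ValType>
template <typename StreamType>
void MiniTable<KeyType, ValType>::get(const KeyType*, ValType*,
                                      std::size_t len, StreamType) {
  std::cout << "get " << len << " keys" << std::endl;
}

// Explicit instantiations, mirroring the lines at the end of the kernel files;
// without them, callers in other translation units would hit unresolved
// symbols because the header no longer carries the definitions.
template class MiniTable<unsigned long, float>;
template void MiniTable<unsigned long, float>::get<int>(
    const unsigned long*, float*, std::size_t, int);

int main() {
  MiniTable<unsigned long, float> table;
  table.get<int>(nullptr, nullptr, 3, /*stream=*/0);
  return 0;
}
```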
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_HETERPS
#include <thread>
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
namespace paddle {
namespace framework {
#if defined(PADDLE_WITH_CUDA)
template <typename value_type>
struct ReplaceOp {
__host__ __device__ value_type operator()(value_type new_value,
......@@ -87,6 +92,7 @@ __global__ void dy_mf_search_kernel(Table* table,
}
}
}
template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
const typename Table::key_type* const keys,
......@@ -135,8 +141,9 @@ void HashTable<KeyType, ValType>::show() {
}
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
size_t len, gpuStream_t stream) {
size_t len, StreamType stream) {
if (len == 0) {
return;
}
......@@ -146,8 +153,9 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
}
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
size_t len, gpuStream_t stream) {
size_t len, StreamType stream) {
if (len == 0) {
return;
}
......@@ -157,9 +165,10 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
}
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
const ValType* d_vals, size_t len,
gpuStream_t stream) {
StreamType stream) {
if (len == 0) {
return;
}
......@@ -169,22 +178,24 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
}
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
char* pool, size_t start_index,
gpuStream_t stream) {
StreamType stream) {
if (len == 0) {
return;
}
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
if (pool == NULL) {
return;
}
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
pool, start_index);
}
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, cudaStream_t stream) {
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
container_->prefetch(cudaCpuDeviceId, stream);
std::vector<std::thread> threads;
size_t num = container_->size();
......@@ -260,10 +271,10 @@ void HashTable<KeyType, ValType>::dump_to_cpu(int devid, cudaStream_t stream) {
}
template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd>
template <typename GradType, typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
const GradType* d_grads, size_t len,
Sgd sgd, gpuStream_t stream) {
Sgd sgd, StreamType stream) {
if (len == 0) {
return;
}
......@@ -273,19 +284,66 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
}
template <typename KeyType, typename ValType>
template <typename Sgd>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
const char* d_grads, size_t len,
Sgd sgd, gpuStream_t stream) {
Sgd sgd, StreamType stream) {
if (len == 0) {
return;
}
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
container_, d_keys, d_grads, len, sgd, push_grad_value_size_);
}
template class HashTable<unsigned long, paddle::framework::FeatureValue>;
template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
cudaStream_t>(const unsigned long* d_keys,
paddle::framework::FeatureValue* d_vals, size_t len,
cudaStream_t stream);
// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<cudaStream_t>(
// const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t
// stream);
template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
cudaStream_t>(const unsigned long* d_keys,
const paddle::framework::FeatureValue* d_vals, size_t len,
cudaStream_t stream);
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::insert<
// cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool,
// size_t start_index, cudaStream_t stream);
template void HashTable<unsigned long, paddle::framework::FeatureValue>::
dump_to_cpu<cudaStream_t>(int devid, cudaStream_t stream);
template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
paddle::framework::FeaturePushValue,
Optimizer<paddle::framework::FeatureValue,
paddle::framework::FeaturePushValue>,
cudaStream_t>(const unsigned long* d_keys,
const paddle::framework::FeaturePushValue* d_grads,
size_t len, Optimizer<paddle::framework::FeatureValue,
paddle::framework::FeaturePushValue>
sgd,
cudaStream_t stream);
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
// Optimizer<paddle::framework::FeatureValue,
// paddle::framework::FeaturePushValue>,
// cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t
// len,
// Optimizer<paddle::framework::FeatureValue,
// paddle::framework::FeaturePushValue>
// sgd,
// cudaStream_t stream);
#endif
} // end namespace framework
} // end namespace paddle
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
namespace optimizer_config {
extern _global_ptr_ float* nonclk_coeff;
extern _global_ptr_ float* clk_coeff;
extern _global_ptr_ float* min_bound;
extern _global_ptr_ float* max_bound;
extern _global_ptr_ float* learning_rate;
extern _global_ptr_ float* initial_g2sum;
extern _global_ptr_ float* initial_range;
extern _global_ptr_ float* mf_create_thresholds;
extern _global_ptr_ float* mf_learning_rate;
extern _global_ptr_ float* mf_initial_g2sum;
extern _global_ptr_ float* mf_initial_range;
extern _global_ptr_ float* mf_min_bound;
extern _global_ptr_ float* mf_max_bound;
}
namespace paddle {
namespace framework {
#if defined(PADDLE_WITH_XPU_KP)
__device__ void update_lr(float& w, float& g2sum, float g, // NOLINT
float scale) {
__local__ float local_learning_rate;
__local__ float local_initial_g2sum;
__local__ float local_min_bound;
__local__ float local_max_bound;
GM2LM(optimizer_config::learning_rate, &local_learning_rate, sizeof(float));
GM2LM(optimizer_config::initial_g2sum, &local_initial_g2sum, sizeof(float));
GM2LM(optimizer_config::min_bound, &local_min_bound, sizeof(float));
GM2LM(optimizer_config::max_bound, &local_max_bound, sizeof(float));
double add_g2sum = 0;
double ratio = local_learning_rate *
sqrt(local_initial_g2sum / (local_initial_g2sum + g2sum));
double scaled_grad = g / scale;
w += scaled_grad * ratio;
if (w < local_min_bound) w = local_min_bound;
if (w > local_max_bound) w = local_max_bound;
add_g2sum += scaled_grad * scaled_grad;
g2sum += add_g2sum;
}
__device__ void update_mf(int n, float* w, float& g2sum, const float* g,
float scale) {
__local__ float local_mf_learning_rate;
__local__ float local_mf_initial_g2sum;
__local__ float local_mf_min_bound;
__local__ float local_mf_max_bound;
GM2LM(optimizer_config::mf_learning_rate, &local_mf_learning_rate,
sizeof(float));
GM2LM(optimizer_config::mf_initial_g2sum, &local_mf_initial_g2sum,
sizeof(float));
GM2LM(optimizer_config::mf_min_bound, &local_mf_min_bound, sizeof(float));
GM2LM(optimizer_config::mf_max_bound, &local_mf_max_bound, sizeof(float));
double add_g2sum = 0;
double ratio =
local_mf_learning_rate *
sqrt(local_mf_initial_g2sum / (local_mf_initial_g2sum + g2sum));
for (int i = 0; i < n; ++i) {
double scaled_grad = g[i] / scale;
w[i] += scaled_grad * ratio;
if (w[i] < local_mf_min_bound) w[i] = local_mf_min_bound;
if (w[i] > local_mf_max_bound) w[i] = local_mf_max_bound;
add_g2sum += scaled_grad * scaled_grad;
}
g2sum += add_g2sum / n;
}
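As read from the code, update_lr above and update_mf below it apply the same Adagrad-style rule; summarizing it in equation form (learning_rate, initial_g2sum, min_bound, max_bound are the loaded optimizer_config values):

\[
\tilde g = \frac{g}{\text{scale}}, \qquad
r = \text{learning\_rate} \cdot \sqrt{\frac{\text{initial\_g2sum}}{\text{initial\_g2sum} + \text{g2sum}}},
\]
\[
w \leftarrow \operatorname{clip}\big(w + r\,\tilde g,\ [\text{min\_bound}, \text{max\_bound}]\big), \qquad
\text{g2sum} \leftarrow \text{g2sum} + \tilde g^{\,2}.
\]

update_mf applies the same rule per embedding entry with the mf_* constants and accumulates the squared gradients averaged over n (add_g2sum / n) into g2sum.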
__device__ float xpu_rand_uniform() { return 0.1; }
template <typename ValType, typename GradType>
__device__ void update_value(ValType& val, const GradType& grad) { // NOLINT
val.slot = grad.slot;
val.show += grad.show;
val.clk += grad.clk;
__local__ float local_nonclk_coeff;
__local__ float local_clk_coeff;
__local__ float local_mf_create_thresholds;
__local__ float local_mf_initial_range;
GM2LM(optimizer_config::nonclk_coeff, &local_nonclk_coeff, sizeof(float));
GM2LM(optimizer_config::clk_coeff, &local_clk_coeff, sizeof(float));
GM2LM(optimizer_config::mf_create_thresholds, &local_mf_create_thresholds,
sizeof(float));
GM2LM(optimizer_config::mf_initial_range, &local_mf_initial_range,
sizeof(float));
val.delta_score +=
local_nonclk_coeff * (grad.show - grad.clk) + local_clk_coeff * grad.clk;
update_lr(val.lr, val.lr_g2sum, grad.lr_g, grad.show);
if (val.mf_size == 0) {
if (local_mf_create_thresholds <=
local_nonclk_coeff * (val.show - val.clk) + local_clk_coeff * val.clk) {
val.mf_size = MF_DIM + 1;
val.mf[0] = 0;
for (int i = 0; i < MF_DIM; ++i) {
val.mf[i + 1] = (xpu_rand_uniform()) * local_mf_initial_range;
}
}
} else {
update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show);
}
}
template <typename KeyType, typename ValType, typename Table>
__global__ void insert_kernel(Table* table, const KeyType* const keys,
const ValType* const vals, size_t len) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
const int buf_size = 150;
__local__ KeyType local_keys[buf_size];
__local__ ValType local_vals[buf_size];
int len_per_loop = min(buf_size, roundup_div(len, nthreads));
for (int i = thread_id * len_per_loop; i < len;
i += nthreads * len_per_loop) {
int read_len = min(len_per_loop, len - i);
GM2LM(keys + i, local_keys, read_len * sizeof(KeyType));
GM2LM(vals + i, local_vals, read_len * sizeof(ValType));
for (int k = 0; k < read_len; k++) {
// auto status = table->insert(local_keys[k], local_vals[k]);
// assert(status != false && "error: insert fails: table is full");
}
}
}
template <typename KeyType, typename ValType, typename Table>
__global__ void search_kernel(Table* table, const KeyType* const keys,
ValType* const vals, size_t len) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
const int buf_size = 150;
__local__ KeyType local_keys[buf_size];
__local__ ValType local_vals[buf_size];
int len_per_loop = min(buf_size, roundup_div(len, nthreads));
for (int i = thread_id * len_per_loop; i < len;
i += nthreads * len_per_loop) {
int read_len = min(len_per_loop, len - i);
GM2LM(keys + i, local_keys, read_len * sizeof(KeyType));
for (int k = 0; k < read_len; k++) {
// ValType* val = table->find(local_keys[k]);
// if (val != NULL) {
// local_vals[k] = *val;
// }
}
LM2GM(local_vals, vals + i, read_len * sizeof(ValType));
}
}
template <typename KeyType, typename ValType, typename Table, typename GradType>
__global__ void update_kernel(Table* table, const KeyType* const keys,
const GradType* const grads, size_t len) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
const int buf_size = 250;
__local__ KeyType local_keys[buf_size];
__local__ GradType local_grads[buf_size];
int len_per_loop = min(buf_size, roundup_div(len, nthreads));
for (int i = thread_id * len_per_loop; i < len;
i += nthreads * len_per_loop) {
int read_len = min(len_per_loop, len - i);
GM2LM(keys + i, local_keys, read_len * sizeof(KeyType));
GM2LM(grads + i, local_grads, read_len * sizeof(GradType));
for (int k = 0; k < read_len; k++) {
// ValType* val = table->find(local_keys[k]);
// if (val != NULL) {
// update_value(*val, grads[i]);
//}
}
}
}
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
auto tmp_container = XPUCacheArray<KeyType, ValType>(capacity);
xpu_malloc(reinterpret_cast<void**>(&container_),
sizeof(XPUCacheArray<KeyType, ValType>));
xpu_memcpy(container_, &tmp_container,
sizeof(XPUCacheArray<KeyType, ValType>), XPU_HOST_TO_DEVICE);
rwlock_.reset(new phi::RWLock);
}
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
xpu_free((void*)container_);
}
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
container_->print();
}
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
size_t len, StreamType stream) {
if (len == 0) {
return;
}
search_kernel<<<4, 64, stream>>>(container_, d_keys, d_vals, len);
}
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
size_t len, StreamType stream) {
if (len == 0) {
return;
}
// TODO(zhangminxu): to be implemented
}
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
const ValType* d_vals, size_t len,
StreamType stream) {
if (len == 0) {
return;
}
insert_kernel<<<4, 64, stream>>>(container_, d_keys, d_vals, len);
}
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
// TODO(zhangminxu): to be implemented
}
template <typename KeyType, typename ValType>
template <typename GradType, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
const GradType* d_grads, size_t len,
StreamType stream) {
if (len == 0) {
return;
}
update_kernel<<<4, 64, stream>>>(container_, d_keys, d_grads, len);
}
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
const char* d_grads, size_t len,
StreamType stream) {
if (len == 0) {
return;
}
// TODO(zhangminxu): to be implemented
}
template class HashTable<unsigned long, paddle::framework::FeatureValue>;
template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
XPUStream>(const unsigned long* d_keys,
paddle::framework::FeatureValue* d_vals, size_t len,
XPUStream stream);
// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<XPUStream>(
// const unsigned long* d_keys, char* d_vals, size_t len, XPUStream stream);
template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
XPUStream>(const unsigned long* d_keys,
const paddle::framework::FeatureValue* d_vals, size_t len,
XPUStream stream);
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::insert<
// XPUStream>(const unsigned long* d_keys, size_t len, char* pool,
// size_t start_index, XPUStream stream);
template void HashTable<unsigned long, paddle::framework::FeatureValue>::
dump_to_cpu<XPUStream>(int devid, XPUStream stream);
template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
paddle::framework::FeaturePushValue, XPUStream>(
const unsigned long* d_keys,
const paddle::framework::FeaturePushValue* d_grads, size_t len,
XPUStream stream);
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
// XPUStream>(const unsigned long* d_keys, const char* d_grads,
// size_t len, XPUStream stream);
#endif
} // end namespace framework
} // end namespace paddle
#endif
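Putting the pieces of this file together, an illustrative usage sketch of the XPU specialization instantiated above. It assumes an XPU device plus the HeterPS headers (XPUStream is visible through hashtable.h, as in the file above) and uses only calls that already appear in this file or in heter_comm.h; the capacity is a placeholder:

```cpp
// Illustrative sketch, not part of the patch.
#include <cstddef>
#include <xpu/runtime.h>
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"

using paddle::framework::FeatureValue;
using paddle::framework::HashTable;

void pull_once(const unsigned long* d_keys, FeatureValue* d_vals, size_t len) {
  // Placeholder capacity for illustration only.
  HashTable<unsigned long, FeatureValue> table(1 << 20);

  XPUStream stream;
  xpu_stream_create(&stream);              // as in HeterComm::create_stream
  table.get(d_keys, d_vals, len, stream);  // runs search_kernel<<<4, 64, stream>>>
  xpu_wait(stream);                        // as in HeterComm::sync_stream
  xpu_stream_destroy(stream);              // as in HeterComm::destroy_stream
}
```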
......@@ -15,39 +15,28 @@ limitations under the License. */
#pragma once
#include <thread>
#include <vector>
#include "cub/cub.cuh"
#include "cub/util_allocator.cuh"
#include "hashtable.h" // NOLINT
#include "heter_resource.h" // NOLINT
#if defined(PADDLE_WITH_CUDA)
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/place.h"
#include "thrust/pair.h"
#elif defined(PADDLE_WITH_XPU_KP)
#include <xpu/runtime.h>
#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
struct CustomGradMerger {
template <typename T>
CUB_RUNTIME_FUNCTION __forceinline__ __device__ T
operator()(const T& a, const T& b) const {
T out;
out.slot = a.slot;
out.show = a.show + b.show;
out.clk = a.clk + b.clk;
out.lr_g = a.lr_g + b.lr_g;
for (int i = 0; i < MF_DIM; ++i) {
out.mf_g[i] = a.mf_g[i] + b.mf_g[i];
}
return out;
}
};
template <typename KeyType, typename ValType, typename GradType>
class HeterComm {
public:
......@@ -67,10 +56,21 @@ class HeterComm {
void show_one_table(int gpu_num);
int get_index_by_devid(int devid);
#if defined(PADDLE_WITH_CUDA)
template <typename Sgd>
void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len,
Sgd& sgd); // NOLINT
#elif defined(PADDLE_WITH_XPU_KP)
void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len);
#endif
int log2i(int x);
template <typename DstPlace, typename SrcPlace, typename StreamType>
void memory_copy(DstPlace dst_place, void* dst, SrcPlace src_place,
const void* src, size_t count, StreamType stream = 0);
#if defined(PADDLE_WITH_CUDA)
template <typename Sgd>
void push_sparse_multi_node(int num, KeyType* d_keys, GradType* d_grads,
size_t len, Sgd& sgd); // NOLINT
......@@ -85,8 +85,6 @@ class HeterComm {
int gather_multi_node_grad(int num, KeyType* d_keys, GradType* d_grads,
int len);
int log2i(int x);
void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms,
int comm_size) {
......@@ -94,6 +92,7 @@ class HeterComm {
nccl_inter_comms_ = inter_comms;
node_size_ = comm_size;
}
#endif
bool need_transfer(int send_id, int receive_id) {
return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id);
......@@ -101,19 +100,19 @@ class HeterComm {
// void dump_to_cpu(int index);
void end_pass();
int get_transfer_devid(int send_id) { return (send_id + 4) % 8; }
void end_pass();
struct Node {
cudaStream_t in_stream;
cudaStream_t out_stream;
ppStream in_stream;
ppStream out_stream;
char* key_storage;
char* val_storage;
int sync;
int key_bytes_len;
int val_bytes_len;
int gpu_num;
int dev_num;
};
struct Path {
......@@ -133,7 +132,7 @@ class HeterComm {
alloc(size, true);
}
void alloc(int size, bool force = false) {
void alloc(size_t size, bool force = false) {
if (force || size > all_keys_mem->size()) {
all_keys_mem.reset();
all_grads_mem.reset();
......@@ -152,7 +151,11 @@ class HeterComm {
}
}
#if defined(PADDLE_WITH_CUDA)
platform::CUDAPlace place_;
#elif defined(PADDLE_WITH_XPU_KP)
platform::XPUPlace place_;
#endif
std::shared_ptr<memory::Allocation> all_keys_mem;
std::shared_ptr<memory::Allocation> all_grads_mem;
KeyType* all_keys;
......@@ -166,6 +169,33 @@ class HeterComm {
void init_path();
template <typename StreamType>
void sync_stream(const StreamType& stream) {
#if defined(PADDLE_WITH_CUDA)
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
#elif defined(PADDLE_WITH_XPU_KP)
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(stream));
#endif
}
template <typename StreamType>
void create_stream(StreamType* stream) {
#if defined(PADDLE_WITH_CUDA)
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(stream));
#elif defined(PADDLE_WITH_XPU_KP)
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(stream));
#endif
}
template <typename StreamType>
void destroy_stream(StreamType stream) {
#if defined(PADDLE_WITH_CUDA)
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream));
#elif defined(PADDLE_WITH_XPU_KP)
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(stream));
#endif
}
void create_storage(int start_index, int end_index, int keylen, int vallen);
void destroy_storage(int start_index, int end_index);
void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
......@@ -182,15 +212,18 @@ class HeterComm {
int block_size_{256};
private:
std::unique_ptr<HeterCommKernel> heter_comm_kernel_;
std::vector<LocalStorage> storage_;
CustomGradMerger merger_;
int topo_aware_{0};
int feanum_{1800 * 2048};
int multi_node_{0};
int node_size_;
#if defined(PADDLE_WITH_CUDA)
std::vector<ncclComm_t> nccl_inner_comms_;
std::vector<ncclComm_t> nccl_inter_comms_;
int node_size_;
std::vector<std::shared_ptr<cub::CachingDeviceAllocator>> allocators_;
#endif
};
} // end namespace framework
......
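The helpers introduced in heter_comm.h above (create_stream, sync_stream, destroy_stream, memory_copy) hide the CUDA/XPU differences behind one interface so the algorithm code in the .inl file below can stay backend-agnostic. A minimal, compilable sketch of that dispatch pattern, with FAKE_WITH_CUDA and an int stream standing in for the real macros and stream types:

```cpp
// Illustrative sketch of the backend-dispatch pattern, not Paddle code.
#include <cstddef>
#include <cstdio>
#include <cstring>

#define FAKE_WITH_CUDA 1  // stand-in for PADDLE_WITH_CUDA / PADDLE_WITH_XPU_KP

template <typename StreamType>
void sync_stream(const StreamType& stream) {
#if defined(FAKE_WITH_CUDA)
  std::printf("cudaStreamSynchronize(stream=%d)\n", stream);  // CUDA branch
#else
  std::printf("xpu_wait(stream=%d)\n", stream);               // XPU-KP branch
#endif
}

template <typename DstPlace, typename SrcPlace, typename StreamType>
void memory_copy(DstPlace /*dst_place*/, void* dst, SrcPlace /*src_place*/,
                 const void* src, std::size_t count, StreamType stream) {
  // The real helper dispatches to cudaMemcpyAsync or memory::Copy under the
  // same preprocessor guards; this sketch just copies on the host and logs.
  std::memcpy(dst, src, count);
  std::printf("copied %zu bytes on stream %d\n", count, stream);
}

int main() {
  int stream = 0;  // stand-in for cudaStream_t / XPUStream
  char src[8] = "heterps";
  char dst[8] = {0};
  memory_copy(/*dst_place=*/0, dst, /*src_place=*/0, src, sizeof(src), stream);
  sync_stream(stream);  // caller never names the backend
  std::printf("dst=%s\n", dst);
  return 0;
}
```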
......@@ -13,115 +13,46 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_HETERPS
//#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include <queue>
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_XPU_KP
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#endif
namespace paddle {
namespace framework {
template <typename T>
__global__ void fill_idx(T* idx, size_t len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
idx[i] = i;
}
}
template <typename T>
void show_tensor(T* input, size_t len, gpuStream_t stream, std::string name) {
T tmp[len]; // NOLINT
cudaMemcpyAsync(&tmp, input, sizeof(T) * len, cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
std::cout << name;
for (int i = 0; i < len; ++i) {
std::cout << ":" << tmp[i];
}
std::cout << std::endl;
}
template <typename T>
__global__ void calc_shard_offset(T* idx, T* left, T* right, size_t len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len - 1) {
if (idx[i] != idx[i + 1]) {
right[idx[i]] = i;
left[idx[i + 1]] = i + 1;
}
}
if (i == 0) {
left[idx[i]] = i;
}
if (i == (len - 1)) {
right[idx[i]] = i;
}
}
template <typename KeyType, typename T>
__global__ void calc_shard_index(KeyType* d_keys, size_t len, T* shard_index,
int total_gpu) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
shard_index[i] = d_keys[i] % total_gpu;
}
}
template <typename KeyType, typename T>
__global__ void fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys, T* idx,
size_t len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
d_shard_keys[i] = d_keys[idx[i]];
}
}
template <typename KeyType, typename GradType, typename T>
__global__ void fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys,
GradType* d_shard_grads, GradType* d_grads,
T* idx, size_t len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
d_shard_keys[i] = d_keys[idx[i]];
d_shard_grads[i] = d_grads[idx[i]];
}
}
template <typename ValType, typename T>
__global__ void fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx,
size_t len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
d_vals[idx[i]] = d_shard_vals[i];
}
}
template <typename KeyType, typename ValType, typename GradType>
HeterComm<KeyType, ValType, GradType>::HeterComm(
size_t capacity, std::shared_ptr<HeterPsResource> resource) {
resource_ = resource;
storage_.resize(resource_->total_gpu());
for (int i = 0; i < resource_->total_gpu(); ++i) {
storage_.resize(resource_->total_device());
for (int i = 0; i < resource_->total_device(); ++i) {
#if defined(PADDLE_WITH_CUDA)
platform::CUDADeviceGuard guard(resource_->dev_id(i));
allocators_.push_back(std::make_shared<cub::CachingDeviceAllocator>(
8, 1, (unsigned int)-1, (size_t)-1, false, false)); // NOLINT
#endif
auto table = new Table(capacity / load_factor_);
tables_.push_back(table);
if (multi_node_) {
storage_[i].init(feanum_, resource_->dev_id(i));
}
}
heter_comm_kernel_ = std::make_unique<HeterCommKernel>(block_size_);
init_path();
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::init_path() {
int total_gpu = resource_->total_gpu();
path_.resize(total_gpu);
int total_device = resource_->total_device();
path_.resize(total_device);
if (!topo_aware_) {
VLOG(0) << "init path without topo aware";
for (int i = 0; i < total_gpu; ++i) {
path_[i].resize(total_gpu);
for (int j = 0; j < total_gpu; ++j) {
for (int i = 0; i < total_device; ++i) {
path_[i].resize(total_device);
for (int j = 0; j < total_device; ++j) {
auto& nodes = path_[i][j].nodes_;
nodes.resize(1);
nodes[0].in_stream = resource_->comm_stream(i, j);
......@@ -129,17 +60,18 @@ void HeterComm<KeyType, ValType, GradType>::init_path() {
nodes[0].key_storage = NULL;
nodes[0].val_storage = NULL;
nodes[0].sync = 0;
nodes[0].gpu_num = j;
nodes[0].dev_num = j;
}
}
} else {
VLOG(0) << "init path with topo aware";
for (int i = 0; i < total_gpu; ++i) {
path_[i].resize(total_gpu);
for (int j = 0; j < total_gpu; ++j) {
for (int i = 0; i < total_device; ++i) {
path_[i].resize(total_device);
for (int j = 0; j < total_device; ++j) {
auto& nodes = path_[i][j].nodes_;
int from = resource_->dev_id(i);
int to = resource_->dev_id(j);
int transfer_id = i;
if (need_transfer(from, to)) {
transfer_id = resource_->get_index_by_devid(get_transfer_devid(from));
......@@ -150,7 +82,7 @@ void HeterComm<KeyType, ValType, GradType>::init_path() {
node.key_storage = NULL;
node.val_storage = NULL;
node.sync = 1;
node.gpu_num = transfer_id;
node.dev_num = transfer_id;
}
nodes.push_back(Node());
Node& node = nodes.back();
......@@ -159,148 +91,222 @@ void HeterComm<KeyType, ValType, GradType>::init_path() {
node.key_storage = NULL;
node.val_storage = NULL;
node.sync = 0;
node.gpu_num = j;
node.dev_num = j;
}
}
}
}
template <typename KeyType, typename ValType, typename GradType>
template <typename DstPlace, typename SrcPlace, typename StreamType>
void HeterComm<KeyType, ValType, GradType>::memory_copy(
DstPlace dst_place, void* dst, SrcPlace src_place, const void* src,
size_t count, StreamType stream) {
#if defined(PADDLE_WITH_CUDA)
cudaMemcpyAsync(dst, src, count, cudaMemcpyDefault, stream);
if (stream == 0) {
cudaStreamSynchronize(0);
}
#elif defined(PADDLE_WITH_XPU_KP)
memory::Copy(dst_place, dst, src_place, src, count);
#endif
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::create_storage(int start_index,
int end_index,
int keylen,
int vallen) {
#if defined(PADDLE_WITH_CUDA)
auto& allocator = allocators_[start_index];
auto& nodes = path_[start_index][end_index].nodes_;
for (size_t i = 0; i < nodes.size(); ++i) {
platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].gpu_num));
platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].dev_num));
allocator->DeviceAllocate(
resource_->dev_id(nodes[i].gpu_num),
resource_->dev_id(nodes[i].dev_num),
(void**)&(nodes[i].key_storage), // NOLINT
keylen, resource_->remote_stream(nodes[i].gpu_num, start_index));
keylen, resource_->remote_stream(nodes[i].dev_num, start_index));
allocator->DeviceAllocate(
resource_->dev_id(nodes[i].gpu_num),
resource_->dev_id(nodes[i].dev_num),
(void**)&(nodes[i].val_storage), // NOLINT
vallen, resource_->remote_stream(nodes[i].gpu_num, start_index));
vallen, resource_->remote_stream(nodes[i].dev_num, start_index));
nodes[i].key_bytes_len = keylen;
nodes[i].val_bytes_len = vallen;
}
#elif defined(PADDLE_WITH_XPU_KP)
auto& nodes = path_[start_index][end_index].nodes_;
for (size_t i = 0; i < nodes.size(); ++i) {
platform::XPUDeviceGuard guard(resource_->dev_id(nodes[i].dev_num));
auto place = DevPlace(resource_->dev_id(nodes[i].dev_num));
auto node_keys_mem = memory::Alloc(place, keylen);
nodes[i].key_storage = reinterpret_cast<char*>(node_keys_mem->ptr());
auto node_vals_mem = memory::Alloc(place, vallen);
nodes[i].val_storage = reinterpret_cast<char*>(node_vals_mem->ptr());
nodes[i].key_bytes_len = keylen;
nodes[i].val_bytes_len = vallen;
}
#endif
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::destroy_storage(int start_index,
int end_index) {
#if defined(PADDLE_WITH_CUDA)
auto& allocator = allocators_[start_index];
auto& nodes = path_[start_index][end_index].nodes_;
for (size_t i = 0; i < nodes.size(); ++i) {
platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].gpu_num));
platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].dev_num));
allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num),
allocator->DeviceFree(resource_->dev_id(nodes[i].dev_num),
nodes[i].key_storage);
allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num),
allocator->DeviceFree(resource_->dev_id(nodes[i].dev_num),
nodes[i].val_storage);
}
#endif
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::walk_to_dest(
int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key,
GradType* src_val) {
void HeterComm<KeyType, ValType, GradType>::walk_to_dest(int start_index,
int num, int* h_left,
int* h_right,
KeyType* src_key,
GradType* src_val) {
int need_copy_val = 0;
if (src_val) {
need_copy_val = 1;
}
std::queue<CopyTask> que;
for (int i = 0; i < gpu_num; i++) {
for (int i = 0; i < num; i++) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
int size = path_[start_index][i].nodes_.size();
// int size = path_[start_index][i].nodes_.size();
auto& node = path_[start_index][i].nodes_[0];
CopyTask t(&path_[start_index][i], 0);
que.push(t);
cudaMemcpyAsync(node.key_storage,
reinterpret_cast<char*>(src_key + h_left[i]),
node.key_bytes_len, cudaMemcpyDefault, node.in_stream);
auto src_dev_id = resource_->dev_id(start_index);
auto dst_dev_id = resource_->dev_id(i);
auto src_place = DevPlace(src_dev_id);
auto dst_place = DevPlace(dst_dev_id);
memory_copy(dst_place, node.key_storage, src_place,
reinterpret_cast<char*>(src_key + h_left[i]),
node.key_bytes_len, node.in_stream);
if (need_copy_val) {
cudaMemcpyAsync(node.val_storage,
reinterpret_cast<char*>(src_val + h_left[i]),
node.val_bytes_len, cudaMemcpyDefault, node.in_stream);
memory_copy(dst_place, node.val_storage, src_place,
reinterpret_cast<char*>(src_val + h_left[i]),
node.val_bytes_len, node.in_stream);
}
}
while (!que.empty()) {
CopyTask& cur_task = que.front();
que.pop();
if (cur_task.path->nodes_[cur_task.step].sync) {
cudaStreamSynchronize(cur_task.path->nodes_[cur_task.step].in_stream);
sync_stream(cur_task.path->nodes_[cur_task.step].in_stream);
}
if (cur_task.step != cur_task.path->nodes_.size() - 1) {
if (static_cast<size_t>(cur_task.step) !=
cur_task.path->nodes_.size() - 1) {
int cur_step = cur_task.step;
CopyTask c(cur_task.path, cur_step + 1);
que.push(c);
cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].key_storage,
cur_task.path->nodes_[cur_step].key_storage,
cur_task.path->nodes_[cur_step + 1].key_bytes_len,
cudaMemcpyDefault,
cur_task.path->nodes_[cur_step + 1].in_stream);
auto src_dev_id =
resource_->dev_id(cur_task.path->nodes_[cur_step].dev_num);
auto dst_dev_id =
resource_->dev_id(cur_task.path->nodes_[cur_step + 1].dev_num);
auto src_place = DevPlace(src_dev_id);
auto dst_place = DevPlace(dst_dev_id);
memory_copy(dst_place, cur_task.path->nodes_[cur_step + 1].key_storage,
src_place, cur_task.path->nodes_[cur_step].key_storage,
cur_task.path->nodes_[cur_step + 1].key_bytes_len,
cur_task.path->nodes_[cur_step + 1].in_stream);
if (need_copy_val) {
cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].val_storage,
cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step + 1].val_bytes_len,
cudaMemcpyDefault,
cur_task.path->nodes_[cur_step + 1].in_stream);
memory_copy(dst_place, cur_task.path->nodes_[cur_step + 1].val_storage,
src_place, cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step + 1].val_bytes_len,
cur_task.path->nodes_[cur_step + 1].in_stream);
}
}
}
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::walk_to_src(
int start_index, int gpu_num, int* h_left, int* h_right, ValType* src_val) {
void HeterComm<KeyType, ValType, GradType>::walk_to_src(int start_index,
int num, int* h_left,
int* h_right,
ValType* src_val) {
std::queue<CopyTask> que;
for (int i = 0; i < gpu_num; i++) {
for (int i = 0; i < num; i++) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
int cur_step = path_[start_index][i].nodes_.size() - 1;
auto& node = path_[start_index][i].nodes_[cur_step];
auto src_dev_id = resource_->dev_id(i);
auto src_place = DevPlace(src_dev_id);
if (cur_step == 0) {
cudaMemcpyAsync(reinterpret_cast<char*>(src_val + h_left[i]),
node.val_storage, node.val_bytes_len, cudaMemcpyDefault,
node.out_stream);
auto dst_dev_id = resource_->dev_id(start_index);
auto dst_place = DevPlace(dst_dev_id);
memory_copy(dst_place, reinterpret_cast<char*>(src_val + h_left[i]),
src_place, node.val_storage, node.val_bytes_len,
node.out_stream);
} else {
CopyTask t(&path_[start_index][i], cur_step - 1);
que.push(t);
cudaMemcpyAsync(path_[start_index][i].nodes_[cur_step - 1].val_storage,
node.val_storage,
path_[start_index][i].nodes_[cur_step - 1].val_bytes_len,
cudaMemcpyDefault,
path_[start_index][i].nodes_[cur_step - 1].out_stream);
auto dst_dev_id =
resource_->dev_id(path_[start_index][i].nodes_[cur_step - 1].dev_num);
auto dst_place = DevPlace(dst_dev_id);
memory_copy(dst_place,
path_[start_index][i].nodes_[cur_step - 1].val_storage,
src_place, node.val_storage,
path_[start_index][i].nodes_[cur_step - 1].val_bytes_len,
path_[start_index][i].nodes_[cur_step - 1].out_stream);
}
}
while (!que.empty()) {
CopyTask& cur_task = que.front();
que.pop();
int cur_step = cur_task.step;
if (cur_task.path->nodes_[cur_step].sync) {
cudaStreamSynchronize(cur_task.path->nodes_[cur_step].out_stream);
sync_stream(cur_task.path->nodes_[cur_step].out_stream);
}
auto src_dev_id =
resource_->dev_id(cur_task.path->nodes_[cur_step].dev_num);
auto src_place = DevPlace(src_dev_id);
if (cur_step > 0) {
CopyTask c(cur_task.path, cur_step - 1);
que.push(c);
cudaMemcpyAsync(cur_task.path->nodes_[cur_step - 1].val_storage,
cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step - 1].val_bytes_len,
cudaMemcpyDefault,
cur_task.path->nodes_[cur_step - 1].out_stream);
auto dst_dev_id =
resource_->dev_id(cur_task.path->nodes_[cur_step - 1].dev_num);
auto dst_place = DevPlace(dst_dev_id);
memory_copy(dst_place, cur_task.path->nodes_[cur_step - 1].val_storage,
src_place, cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step - 1].val_bytes_len,
cur_task.path->nodes_[cur_step - 1].out_stream);
} else if (cur_step == 0) {
int end_index = cur_task.path->nodes_.back().gpu_num;
cudaMemcpyAsync(reinterpret_cast<char*>(src_val + h_left[end_index]),
cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step].val_bytes_len,
cudaMemcpyDefault,
cur_task.path->nodes_[cur_step].out_stream);
int end_index = cur_task.path->nodes_.back().dev_num;
auto dst_dev_id = resource_->dev_id(end_index);
auto dst_place = DevPlace(dst_dev_id);
memory_copy(dst_place,
reinterpret_cast<char*>(src_val + h_left[end_index]),
src_place, cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step].val_bytes_len,
cur_task.path->nodes_[cur_step].out_stream);
}
}
}
......@@ -314,8 +320,8 @@ HeterComm<KeyType, ValType, GradType>::~HeterComm() {
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::show_one_table(int gpu_num) {
tables_[gpu_num]->show();
void HeterComm<KeyType, ValType, GradType>::show_one_table(int num) {
tables_[num]->show();
}
template <typename KeyType, typename ValType, typename GradType>
......@@ -333,24 +339,22 @@ int HeterComm<KeyType, ValType, GradType>::get_index_by_devid(int devid) {
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
ValType* h_vals,
size_t len,
size_t chunk_size,
int stream_num) {
void HeterComm<KeyType, ValType, GradType>::build_ps(
int dev_num, KeyType* h_keys, ValType* h_vals, size_t len,
size_t chunk_size, int stream_num) {
if (len <= 0) {
return;
}
int dev_id = resource_->dev_id(num);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
int dev_id = resource_->dev_id(dev_num);
std::vector<memory::allocation::AllocationPtr> d_key_bufs;
std::vector<memory::allocation::AllocationPtr> d_val_bufs;
gpuStream_t streams[stream_num]; // NOLINT
DevPlace place = DevPlace(dev_id);
AnyDeviceGuard guard(dev_id);
ppStream streams[stream_num]; // NOLINT
for (int i = 0; i < stream_num; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&(streams[i])));
create_stream(&(streams[i]));
auto d_k_buf = memory::Alloc(place, chunk_size * sizeof(KeyType));
auto d_v_buf = memory::Alloc(place, chunk_size * sizeof(ValType));
d_key_bufs.push_back(std::move(d_k_buf));
......@@ -360,39 +364,48 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
int cur_len = 0;
int cur_stream = 0;
while (cur_len < len) {
while (static_cast<size_t>(cur_len) < len) {
cur_stream = cur_stream % stream_num;
auto cur_use_stream = streams[cur_stream];
#if defined(PADDLE_WITH_XPU_KP)
cur_use_stream = 0;
#endif
int tmp_len = cur_len + chunk_size > len ? len - cur_len : chunk_size;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemcpyAsync(d_key_bufs[cur_stream]->ptr(), h_keys + cur_len,
sizeof(KeyType) * tmp_len, cudaMemcpyHostToDevice,
streams[cur_stream]));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemcpyAsync(d_val_bufs[cur_stream]->ptr(), h_vals + cur_len,
sizeof(ValType) * tmp_len, cudaMemcpyHostToDevice,
streams[cur_stream]));
tables_[num]->insert(
auto dst_place = place;
auto src_place = platform::CPUPlace();
memory_copy(
dst_place, reinterpret_cast<char*>(d_key_bufs[cur_stream]->ptr()),
src_place, h_keys + cur_len, sizeof(KeyType) * tmp_len, cur_use_stream);
memory_copy(
dst_place, reinterpret_cast<char*>(d_val_bufs[cur_stream]->ptr()),
src_place, h_vals + cur_len, sizeof(ValType) * tmp_len, cur_use_stream);
tables_[dev_num]->insert(
reinterpret_cast<KeyType*>(d_key_bufs[cur_stream]->ptr()),
reinterpret_cast<ValType*>(d_val_bufs[cur_stream]->ptr()), tmp_len,
streams[cur_stream]);
cur_use_stream);
cur_stream += 1;
cur_len += tmp_len;
}
for (int i = 0; i < stream_num; ++i) {
cudaStreamSynchronize(streams[i]);
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(streams[i]));
sync_stream(streams[i]);
destroy_stream(streams[i]);
}
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::merge_grad(
int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
int dev_num, KeyType* d_keys, GradType* d_grads, size_t len,
int& uniq_len) { // NOLINT
int dev_id = resource_->dev_id(gpu_num);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_num, 0);
int dev_id = resource_->dev_id(dev_num);
DevPlace place = DevPlace(dev_id);
AnyDeviceGuard guard(dev_id);
auto stream = resource_->local_stream(dev_num, 0);
size_t temp_storage_bytes;
......@@ -403,48 +416,50 @@ void HeterComm<KeyType, ValType, GradType>::merge_grad(
GradType* d_merge_grads_ptr =
reinterpret_cast<GradType*>(d_merge_grads->ptr());
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
NULL, temp_storage_bytes, d_keys, d_merge_keys_ptr, d_grads,
d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false));
heter_comm_kernel_->sort_pairs(NULL, temp_storage_bytes, d_keys,
d_merge_keys_ptr, d_grads, d_merge_grads_ptr,
len, 0, 8 * sizeof(KeyType), stream, false);
void* d_buff = NULL;
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
heter_comm_kernel_->sort_pairs(
d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr,
d_grads, d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false));
d_grads, d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false);
temp_storage_bytes = 0;
auto d_num_runs_out_mem = memory::Alloc(place, sizeof(int));
int* d_num_runs_out = reinterpret_cast<int*>(d_num_runs_out_mem->ptr());
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::ReduceByKey(
NULL, temp_storage_bytes, d_merge_keys_ptr, d_keys, d_merge_grads_ptr,
d_grads, d_num_runs_out, merger_, len, stream, false));
heter_comm_kernel_->reduce_by_key(NULL, temp_storage_bytes, d_merge_keys_ptr,
d_keys, d_merge_grads_ptr, d_grads,
d_num_runs_out, len, stream, false);
if (d_temp_storage->size() < temp_storage_bytes) {
d_temp_storage = NULL;
d_temp_storage = memory::Alloc(place, temp_storage_bytes);
}
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::ReduceByKey(
heter_comm_kernel_->reduce_by_key(
d_temp_storage->ptr(), temp_storage_bytes, d_merge_keys_ptr, d_keys,
d_merge_grads_ptr, d_grads, d_num_runs_out, merger_, len, stream, false));
d_merge_grads_ptr, d_grads, d_num_runs_out, len, stream, false);
cudaMemcpyAsync(&uniq_len, d_num_runs_out, sizeof(int),
cudaMemcpyDeviceToHost, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
auto dst_place = platform::CPUPlace();
auto src_place = place;
memory_copy(dst_place, &uniq_len, src_place, d_num_runs_out, sizeof(int),
stream);
sync_stream(stream);
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
KeyType* d_keys, int* d_idx_ptr, size_t len, int* left, int* right,
int gpu_num) {
int total_gpu = resource_->total_gpu();
int dev_id = resource_->dev_id(gpu_num);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_num, 0);
int dev_num) {
int total_device = resource_->total_device();
int dev_id = resource_->dev_id(dev_num);
DevPlace place = DevPlace(dev_id);
AnyDeviceGuard guard(dev_id);
auto stream = resource_->local_stream(dev_num, 0);
auto d_idx_tmp = memory::Alloc(place, len * sizeof(int));
int* d_idx_tmp_ptr = reinterpret_cast<int*>(d_idx_tmp->ptr());
......@@ -455,24 +470,28 @@ void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
auto d_shard_index_tmp = memory::Alloc(place, len * sizeof(int));
int* d_shard_index_tmp_ptr = reinterpret_cast<int*>(d_shard_index_tmp->ptr());
int grid_size = (len - 1) / block_size_ + 1;
fill_idx<<<grid_size, block_size_, 0, stream>>>(d_idx_tmp_ptr, len);
calc_shard_index<<<grid_size, block_size_, 0, stream>>>(
d_keys, len, d_shard_index_tmp_ptr, total_gpu);
// int grid_size = (len - 1) / block_size_ + 1;
heter_comm_kernel_->fill_idx(d_idx_tmp_ptr, len, stream);
heter_comm_kernel_->calc_shard_index(d_keys, len, d_shard_index_tmp_ptr,
total_device, stream);
size_t temp_storage_bytes;
const int num_bits = 1 + log2i(total_gpu);
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
const int num_bits = 1 + log2i(total_device);
heter_comm_kernel_->sort_pairs(
NULL, temp_storage_bytes, d_shard_index_tmp_ptr, d_shard_index_ptr,
d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream));
d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream);
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
heter_comm_kernel_->sort_pairs(
d_temp_storage->ptr(), temp_storage_bytes, d_shard_index_tmp_ptr,
d_shard_index_ptr, d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream));
calc_shard_offset<<<grid_size, block_size_, 0, stream>>>(d_shard_index_ptr,
left, right, len);
cudaStreamSynchronize(stream);
d_shard_index_ptr, d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream);
heter_comm_kernel_->calc_shard_offset(d_shard_index_ptr, left, right, len,
total_device, stream);
sync_stream(stream);
}
template <typename KeyType, typename ValType, typename GradType>
......@@ -484,25 +503,43 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
return;
}
int total_gpu = resource_->total_gpu();
int total_device = resource_->total_device();
int dev_id = resource_->dev_id(num);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
DevPlace place = DevPlace(dev_id);
AnyDeviceGuard guard(dev_id);
auto stream = resource_->local_stream(num, 0);
int grid_size = (len - 1) / block_size_ + 1;
// int grid_size = (len - 1) / block_size_ + 1;
int h_left[total_gpu]; // NOLINT
int h_right[total_gpu]; // NOLINT
int h_left[total_device]; // NOLINT
int h_right[total_device]; // NOLINT
auto d_left = memory::Alloc(place, total_gpu * sizeof(int));
auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
auto d_left = memory::Alloc(place, total_device * sizeof(int));
auto d_right = memory::Alloc(place, total_device * sizeof(int));
int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream);
cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream);
//
#if defined(PADDLE_WITH_CUDA)
cudaMemsetAsync(d_left_ptr, -1, total_device * sizeof(int), stream);
cudaMemsetAsync(d_right_ptr, -1, total_device * sizeof(int), stream);
#elif defined(PADDLE_WITH_XPU_KP)
// get XPUDeviceContext according to xpu place
paddle::platform::XPUDeviceContext xpu_dev_ctx(place);
auto xpu_context = xpu_dev_ctx.x_context();
int r = xpu::constant<int>(xpu_context, d_left_ptr, total_device, -1);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU constant kernel return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
int r2 = xpu::constant<int>(xpu_context, d_right_ptr, total_device, -1);
PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS,
platform::errors::External(
"XPU constant kernel return wrong value[%d %s]", r2,
XPUAPIErrorMsg[r2]));
#endif
auto d_idx = memory::Alloc(place, len * sizeof(int));
int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
......@@ -513,17 +550,20 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
split_input_to_shard(d_keys, d_idx_ptr, len, d_left_ptr, d_right_ptr, num);
fill_shard_key<<<grid_size, block_size_, 0, stream>>>(d_shard_keys_ptr,
d_keys, d_idx_ptr, len);
heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, d_keys, d_idx_ptr, len,
stream);
cudaStreamSynchronize(stream);
sync_stream(stream);
cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
auto dst_place = platform::CPUPlace();
auto src_place = place;
for (int i = 0; i < total_gpu; ++i) {
memory_copy(dst_place, h_left, src_place, d_left_ptr,
total_device * sizeof(int), stream);
memory_copy(dst_place, h_right, src_place, d_right_ptr,
total_device * sizeof(int), stream);
for (int i = 0; i < total_device; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (shard_len == 0) {
continue;
......@@ -532,47 +572,53 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
shard_len * sizeof(ValType));
}
walk_to_dest(num, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL);
walk_to_dest(num, total_device, h_left, h_right, d_shard_keys_ptr, NULL);
for (int i = 0; i < total_gpu; ++i) {
for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1) {
continue;
}
auto& node = path_[num][i].nodes_.back();
cudaStreamSynchronize(node.in_stream);
platform::CUDADeviceGuard guard(resource_->dev_id(i));
sync_stream(node.in_stream);
AnyDeviceGuard guard(resource_->dev_id(i));
tables_[i]->rwlock_->RDLock();
tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
reinterpret_cast<ValType*>(node.val_storage),
h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, num));
}
for (int i = 0; i < total_gpu; ++i) {
cudaStreamSynchronize(resource_->remote_stream(i, num));
for (int i = 0; i < total_device; ++i) {
sync_stream(resource_->remote_stream(i, num));
if (h_left[i] == -1) {
continue;
}
tables_[i]->rwlock_->UNLock();
}
walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr);
walk_to_src(num, total_device, h_left, h_right, d_shard_vals_ptr);
for (int i = 0; i < total_gpu; ++i) {
for (int i = 0; i < total_device; ++i) {
auto& node = path_[num][i].nodes_.front();
cudaStreamSynchronize(node.out_stream);
sync_stream(node.out_stream);
}
fill_dvals<<<grid_size, block_size_, 0, stream>>>(d_shard_vals_ptr, d_vals,
d_idx_ptr, len);
cudaStreamSynchronize(stream);
for (int i = 0; i < total_gpu; ++i) {
heter_comm_kernel_->fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len,
stream);
sync_stream(stream);
for (int i = 0; i < total_device; ++i) {
destroy_storage(num, i);
}
}
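The h_left/h_right bounds copied back to the host above decide which remote devices take part in the pull: after the shard sort every device owns one contiguous slice of the keys, and a device that received no keys keeps the -1 sentinels written by the earlier memset. A small host-side illustration with hypothetical values:

#include <cstdio>

int main() {
  const int total_device = 4;
  int h_left[total_device] = {0, 2, -1, 5};
  int h_right[total_device] = {1, 4, -1, 5};
  for (int i = 0; i < total_device; ++i) {
    if (h_left[i] == -1) continue;               // device 2 received no keys
    int shard_len = h_right[i] - h_left[i] + 1;  // contiguous slice for device i
    printf("device %d pulls %d keys\n", i, shard_len);
  }
  return 0;
}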
#if defined(PADDLE_WITH_CUDA)
template <typename KeyType, typename ValType, typename GradType>
template <typename Sgd>
void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
KeyType* d_keys,
GradType* d_grads,
size_t len,
......@@ -581,23 +627,42 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
return;
}
int total_gpu = resource_->total_gpu();
int dev_id = resource_->dev_id(gpu_num);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_num, 0);
int total_device = resource_->total_device();
int dev_id = resource_->dev_id(dev_num);
int h_left[total_gpu]; // NOLINT
int h_right[total_gpu]; // NOLINT
DevPlace place = DevPlace(dev_id);
AnyDeviceGuard guard(dev_id);
auto stream = resource_->local_stream(dev_num, 0);
auto d_left = memory::Alloc(place, total_gpu * sizeof(int));
auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
int h_left[total_device]; // NOLINT
int h_right[total_device]; // NOLINT
auto d_left = memory::Alloc(place, total_device * sizeof(int));
auto d_right = memory::Alloc(place, total_device * sizeof(int));
int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream);
cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream);
//
#if defined(PADDLE_WITH_CUDA)
cudaMemsetAsync(d_left_ptr, -1, total_device * sizeof(int), stream);
cudaMemsetAsync(d_right_ptr, -1, total_device * sizeof(int), stream);
#elif defined(PADDLE_WITH_XPU_KP)
// get XPUDeviceContext according to xpu place
paddle::platform::XPUDeviceContext xpu_dev_ctx(place);
auto xpu_context = xpu_dev_ctx.x_context();
int r = xpu::constant<int>(xpu_context, d_left_ptr, total_device, -1);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU constant kernel return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
int r2 = xpu::constant<int>(xpu_context, d_right_ptr, total_device, -1);
PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS,
platform::errors::External(
"XPU constant kernel return wrong value[%d %s]", r2,
XPUAPIErrorMsg[r2]));
#endif
auto d_idx = memory::Alloc(place, len * sizeof(int));
int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
......@@ -608,61 +673,183 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
reinterpret_cast<GradType*>(d_shard_grads->ptr());
int uniq_len = len;
merge_grad(gpu_num, d_keys, d_grads, len, uniq_len);
merge_grad(dev_num, d_keys, d_grads, len, uniq_len);
int grid_size = (uniq_len - 1) / block_size_ + 1;
// int grid_size = (uniq_len - 1) / block_size_ + 1;
split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr,
gpu_num);
dev_num);
fill_shard_grads<<<grid_size, block_size_, 0, stream>>>(
d_shard_keys_ptr, d_keys, d_shard_grads_ptr, d_grads, d_idx_ptr,
uniq_len);
heter_comm_kernel_->fill_shard_grads(d_shard_keys_ptr, d_keys,
d_shard_grads_ptr, d_grads, d_idx_ptr,
uniq_len, stream);
cudaStreamSynchronize(stream);
sync_stream(stream);
cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
auto dst_place = platform::CPUPlace();
auto src_place = place;
memory_copy(dst_place, h_left, src_place, d_left_ptr,
total_device * sizeof(int), stream);
memory_copy(dst_place, h_right, src_place, d_right_ptr,
total_device * sizeof(int), stream);
for (int i = 0; i < total_gpu; ++i) {
for (int i = 0; i < total_device; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
create_storage(gpu_num, i, shard_len * sizeof(KeyType),
create_storage(dev_num, i, shard_len * sizeof(KeyType),
shard_len * sizeof(GradType));
}
walk_to_dest(gpu_num, total_gpu, h_left, h_right, d_shard_keys_ptr,
walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr,
d_shard_grads_ptr);
for (int i = 0; i < total_gpu; ++i) {
for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
auto& node = path_[gpu_num][i].nodes_.back();
cudaStreamSynchronize(node.in_stream);
auto& node = path_[dev_num][i].nodes_.back();
sync_stream(node.in_stream);
platform::CUDADeviceGuard guard(resource_->dev_id(i));
AnyDeviceGuard guard(resource_->dev_id(i));
tables_[i]->rwlock_->WRLock();
tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
reinterpret_cast<GradType*>(node.val_storage),
h_right[i] - h_left[i] + 1, sgd,
resource_->remote_stream(i, gpu_num));
resource_->remote_stream(i, dev_num));
}
for (int i = 0; i < total_gpu; ++i) {
cudaStreamSynchronize(resource_->remote_stream(i, gpu_num));
for (int i = 0; i < total_device; ++i) {
sync_stream(resource_->remote_stream(i, dev_num));
if (h_left[i] != -1) {
tables_[i]->rwlock_->UNLock();
}
}
for (int i = 0; i < total_gpu; ++i) {
destroy_storage(gpu_num, i);
for (int i = 0; i < total_device; ++i) {
destroy_storage(dev_num, i);
}
}
#elif defined(PADDLE_WITH_XPU_KP)
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
KeyType* d_keys,
GradType* d_grads,
size_t len) {
if (len == 0) {
return;
}
int total_device = resource_->total_device();
int dev_id = resource_->dev_id(dev_num);
DevPlace place = DevPlace(dev_id);
AnyDeviceGuard guard(dev_id);
auto stream = resource_->local_stream(dev_num, 0);
int h_left[total_device]; // NOLINT
int h_right[total_device]; // NOLINT
auto d_left = memory::Alloc(place, total_device * sizeof(int));
auto d_right = memory::Alloc(place, total_device * sizeof(int));
int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
#if defined(PADDLE_WITH_CUDA)
cudaMemsetAsync(d_left_ptr, -1, total_device * sizeof(int), stream);
cudaMemsetAsync(d_right_ptr, -1, total_device * sizeof(int), stream);
#elif defined(PADDLE_WITH_XPU_KP)
// get XPUDeviceContext according to xpu place
paddle::platform::XPUDeviceContext xpu_dev_ctx(place);
auto xpu_context = xpu_dev_ctx.x_context();
int r = xpu::constant<int>(xpu_context, d_left_ptr, total_device, -1);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU constant kernel return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
int r2 = xpu::constant<int>(xpu_context, d_right_ptr, total_device, -1);
PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS,
platform::errors::External(
"XPU constant kernel return wrong value[%d %s]", r2,
XPUAPIErrorMsg[r2]));
#endif
auto d_idx = memory::Alloc(place, len * sizeof(int));
int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
auto d_shard_grads = memory::Alloc(place, len * sizeof(GradType));
GradType* d_shard_grads_ptr =
reinterpret_cast<GradType*>(d_shard_grads->ptr());
int uniq_len = len;
merge_grad(dev_num, d_keys, d_grads, len, uniq_len);
// int grid_size = (uniq_len - 1) / block_size_ + 1;
split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr,
dev_num);
heter_comm_kernel_->fill_shard_grads(d_shard_keys_ptr, d_keys,
d_shard_grads_ptr, d_grads, d_idx_ptr,
(long long)uniq_len, stream);
sync_stream(stream);
auto dst_place = platform::CPUPlace();
auto src_place = place;
memory_copy(dst_place, h_left, src_place, d_left_ptr,
total_device * sizeof(int), stream);
memory_copy(dst_place, h_right, src_place, d_right_ptr,
total_device * sizeof(int), stream);
for (int i = 0; i < total_device; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
create_storage(dev_num, i, shard_len * sizeof(KeyType),
shard_len * sizeof(GradType));
}
walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr,
d_shard_grads_ptr);
for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
auto& node = path_[dev_num][i].nodes_.back();
sync_stream(node.in_stream);
AnyDeviceGuard guard(resource_->dev_id(i));
tables_[i]->rwlock_->WRLock();
tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
reinterpret_cast<GradType*>(node.val_storage),
h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, dev_num));
}
for (int i = 0; i < total_device; ++i) {
sync_stream(resource_->remote_stream(i, dev_num));
if (h_left[i] != -1) {
tables_[i]->rwlock_->UNLock();
}
}
for (int i = 0; i < total_device; ++i) {
destroy_storage(dev_num, i);
}
}
#endif
#if defined(PADDLE_WITH_CUDA)
template <typename KeyType, typename ValType, typename GradType>
template <typename Sgd>
void HeterComm<KeyType, ValType, GradType>::update_one_table(
......@@ -705,7 +892,7 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse_multi_node(
template <typename KeyType, typename ValType, typename GradType>
int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
int gpu_num, KeyType* d_keys, GradType* d_grads, int len) {
int total_gpu = resource_->total_gpu();
int total_gpu = resource_->total_device();
int dev_id = resource_->dev_id(gpu_num);
auto& storage = storage_[gpu_num];
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
......@@ -725,10 +912,10 @@ int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
// allgather grad len
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllGather((const void*)(d_node_len + gpu_num),
(void*)d_node_len, 1, ncclInt, // NOLINT
nccl_inner_comm, stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
(const void*)(d_node_len + gpu_num), (void*)d_node_len, 1, // NOLINT
ncclInt, // NOLINT
nccl_inner_comm, stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
cudaMemcpy(h_node_len, d_node_len, sizeof(int) * total_gpu,
......@@ -775,11 +962,12 @@ int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
int grid_size = (h_node_len[i] - 1) / block_size_ + 1;
fill_shard_grads<<<grid_size, block_size_, 0, stream>>>(
// int grid_size = (h_node_len[i] - 1) / block_size_ + 1;
heter_comm_kernel_->fill_shard_grads(
storage.local_keys + merge_num, storage.all_keys + index,
storage.local_grads + merge_num, storage.all_grads + index,
d_idx_ptr + h_left[gpu_num], h_right[gpu_num] - h_left[gpu_num] + 1);
d_idx_ptr + h_left[gpu_num], h_right[gpu_num] - h_left[gpu_num] + 1,
stream);
merge_num = merge_num + h_right[gpu_num] - h_left[gpu_num] + 1;
}
......@@ -848,19 +1036,21 @@ int HeterComm<KeyType, ValType, GradType>::gather_multi_node_grad(
return ret;
}
#endif
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::end_pass() {
int total_gpu = resource_->total_gpu();
int total_device = resource_->total_device();
std::vector<std::thread> threads;
auto dump_to_cpu_func = [this](int index) {
auto stream = resource_->local_stream(index, 0);
int dev_id = resource_->dev_id(index);
platform::CUDADeviceGuard guard(dev_id);
AnyDeviceGuard guard(dev_id);
tables_[index]->dump_to_cpu(dev_id, stream);
};
for (int i = 0; i < total_gpu; ++i) {
for (int i = 0; i < total_device; ++i) {
threads.push_back(std::thread(dump_to_cpu_func, i));
}
for (auto& t : threads) {
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"
namespace paddle {
namespace framework {
#ifdef PADDLE_WITH_CUDA
struct GPUCustomGradMerger {
template <typename T>
CUB_RUNTIME_FUNCTION __forceinline__ __device__ T
operator()(const T& a, const T& b) const {
T out;
out.slot = a.slot;
out.show = a.show + b.show;
out.clk = a.clk + b.clk;
out.lr_g = a.lr_g + b.lr_g;
for (int i = 0; i < MF_DIM; ++i) {
out.mf_g[i] = a.mf_g[i] + b.mf_g[i];
}
return out;
}
} gpu_merger;
template <typename T>
__global__ void fill_idx_kernel(T* idx, size_t len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
idx[i] = i;
}
}
// template <typename T>
// void show_tensor(T* input, size_t len, gpuStream_t stream, std::string
// name)
// {
// T tmp[len]; // NOLINT
// cudaMemcpyAsync(&tmp, input, sizeof(T) * len, cudaMemcpyDeviceToHost,
// stream);
// cudaStreamSynchronize(stream);
// std::cout << name;
// for (int i = 0; i < len; ++i) {
// std::cout << ":" << tmp[i];
// }
// std::cout << std::endl;
//}
template <typename T>
__global__ void calc_shard_offset_kernel(T* idx, T* left, T* right,
size_t len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len - 1) {
if (idx[i] != idx[i + 1]) {
right[idx[i]] = i;
left[idx[i + 1]] = i + 1;
}
}
if (i == 0) {
left[idx[i]] = i;
}
if (i == (len - 1)) {
right[idx[i]] = i;
}
}
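The kernel above records, for every shard id, the first and last position it occupies in the sorted index array. A host-side reference (a sketch, not the device code) computing the same bounds: for idx = {0, 0, 1, 1, 1, 3} and four shards it yields left = {0, 2, -1, 5} and right = {1, 4, -1, 5}.

#include <vector>

void calc_shard_offset_cpu(const std::vector<int>& idx, int total_devs,
                           std::vector<int>* left, std::vector<int>* right) {
  left->assign(total_devs, -1);   // -1 marks a shard that owns no keys
  right->assign(total_devs, -1);
  for (size_t i = 0; i < idx.size(); ++i) {
    if ((*left)[idx[i]] == -1) (*left)[idx[i]] = static_cast<int>(i);
    (*right)[idx[i]] = static_cast<int>(i);  // last occurrence wins
  }
}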
template <typename KeyType, typename T>
__global__ void calc_shard_index_kernel(KeyType* d_keys, size_t len,
T* shard_index, int total_gpu) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
shard_index[i] = d_keys[i] % total_gpu;
}
}
template <typename KeyType, typename T>
__global__ void fill_shard_key_kernel(KeyType* d_shard_keys, KeyType* d_keys,
T* idx, size_t len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
d_shard_keys[i] = d_keys[idx[i]];
}
}
template <typename KeyType, typename GradType, typename T>
__global__ void fill_shard_grads_kernel(KeyType* d_shard_keys, KeyType* d_keys,
GradType* d_shard_grads,
GradType* d_grads, T* idx, size_t len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
d_shard_keys[i] = d_keys[idx[i]];
d_shard_grads[i] = d_grads[idx[i]];
}
}
template <typename ValType, typename T>
__global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals,
T* idx, size_t len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
d_vals[idx[i]] = d_shard_vals[i];
}
}
// cuda implemention of heter_comm_kernel.h
template <typename T, typename StreamType>
void HeterCommKernel::fill_idx(T* idx, long long len,
const StreamType& stream) {
int grid_size = (len - 1) / block_size_ + 1;
size_t c_len = (size_t)len;
fill_idx_kernel<<<grid_size, block_size_, 0, stream>>>(idx, c_len);
}
template <typename T, typename StreamType>
void HeterCommKernel::calc_shard_offset(T* idx, T* left, T* right,
long long len, int total_devs,
const StreamType& stream) {
int grid_size = (len - 1) / block_size_ + 1;
size_t c_len = (size_t)len;
calc_shard_offset_kernel<<<grid_size, block_size_, 0, stream>>>(idx, left,
right, c_len);
}
template <typename KeyType, typename T, typename StreamType>
void HeterCommKernel::calc_shard_index(KeyType* d_keys, long long len,
T* shard_index, int total_gpu,
const StreamType& stream) {
int grid_size = (len - 1) / block_size_ + 1;
size_t c_len = (size_t)len;
calc_shard_index_kernel<<<grid_size, block_size_, 0, stream>>>(
d_keys, c_len, shard_index, total_gpu);
}
template <typename KeyType, typename T, typename StreamType>
void HeterCommKernel::fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys,
T* idx, long long len,
const StreamType& stream) {
int grid_size = (len - 1) / block_size_ + 1;
size_t c_len = (size_t)len;
fill_shard_key_kernel<<<grid_size, block_size_, 0, stream>>>(
d_shard_keys, d_keys, idx, c_len);
}
template <typename KeyType, typename GradType, typename T, typename StreamType>
void HeterCommKernel::fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys,
GradType* d_shard_grads,
GradType* d_grads, T* idx, long long len,
const StreamType& stream) {
int grid_size = (len - 1) / block_size_ + 1;
size_t c_len = (size_t)len;
fill_shard_grads_kernel<<<grid_size, block_size_, 0, stream>>>(
d_shard_keys, d_keys, d_shard_grads, d_grads, idx, c_len);
}
template <typename ValType, typename T, typename StreamType>
void HeterCommKernel::fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx,
long long len, const StreamType& stream) {
int grid_size = (len - 1) / block_size_ + 1;
size_t c_len = (size_t)len;
fill_dvals_kernel<<<grid_size, block_size_, 0, stream>>>(d_shard_vals, d_vals,
idx, c_len);
}
template <typename KeyT, typename ValueT, typename StreamType>
void HeterCommKernel::sort_pairs(void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
const KeyT* d_keys_in, // NOLINT
KeyT* d_keys_out, const ValueT* d_values_in,
ValueT* d_values_out, int num_items,
int begin_bit, int end_bit, StreamType stream,
bool debug_synchronous) {
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, d_values_in,
d_values_out, num_items, begin_bit, end_bit, stream, debug_synchronous));
}
template <typename KeysInputIteratorT, typename UniqueOutputIteratorT,
typename ValuesInputIteratorT, typename AggregatesOutputIteratorT,
typename NumRunsOutputIteratorT, typename StreamType>
void HeterCommKernel::reduce_by_key(void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
KeysInputIteratorT d_keys_in,
UniqueOutputIteratorT d_unique_out,
ValuesInputIteratorT d_values_in,
AggregatesOutputIteratorT d_aggregates_out,
NumRunsOutputIteratorT d_num_runs_out,
int num_items, StreamType stream,
bool debug_synchronous) {
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::ReduceByKey(
d_temp_storage, temp_storage_bytes, d_keys_in, d_unique_out, d_values_in,
d_aggregates_out, d_num_runs_out, gpu_merger, num_items, stream,
debug_synchronous));
}
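This wrapper is what gradient merging builds on: once duplicated keys are adjacent after the radix sort, gpu_merger collapses each run of equal keys into one record whose fields are summed. A host-side sketch of that semantics with a hypothetical two-field gradient type:

#include <cstdio>
#include <vector>

struct DemoGrad { float show, clk; };  // stand-in for FeaturePushValue

int main() {
  std::vector<unsigned long> keys = {7, 7, 9};           // already sorted
  std::vector<DemoGrad> grads = {{1, 0}, {2, 1}, {1, 1}};
  std::vector<unsigned long> uniq_keys;
  std::vector<DemoGrad> merged;
  for (size_t i = 0; i < keys.size(); ++i) {
    if (uniq_keys.empty() || uniq_keys.back() != keys[i]) {
      uniq_keys.push_back(keys[i]);
      merged.push_back(grads[i]);
    } else {                                             // same key: sum fields
      merged.back().show += grads[i].show;
      merged.back().clk += grads[i].clk;
    }
  }
  printf("key %lu -> show=%.0f clk=%.0f\n", uniq_keys[0], merged[0].show,
         merged[0].clk);                                 // key 7 -> show=3 clk=1
  return 0;
}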
template void HeterCommKernel::fill_idx<int, cudaStream_t>(
int* idx, long long len, const cudaStream_t& stream);
template void HeterCommKernel::calc_shard_offset<int, cudaStream_t>(
int* idx, int* left, int* right, long long len, int total_devs,
const cudaStream_t& stream);
template void HeterCommKernel::calc_shard_index<
unsigned long, int, cudaStream_t>(unsigned long* d_keys, long long len,
int* shard_index, int total_devs,
const cudaStream_t& stream);
template void HeterCommKernel::fill_shard_key<unsigned long, int, cudaStream_t>(
unsigned long* d_shard_keys, unsigned long* d_keys, int* idx, long long len,
const cudaStream_t& stream);
template void HeterCommKernel::fill_shard_grads<
unsigned long, paddle::framework::FeaturePushValue, int, cudaStream_t>(
unsigned long* d_shard_keys, unsigned long* d_keys,
paddle::framework::FeaturePushValue* d_shard_grads,
paddle::framework::FeaturePushValue* d_grads, int* idx, long long len,
const cudaStream_t& stream);
template void
HeterCommKernel::fill_dvals<paddle::framework::FeatureValue, int, cudaStream_t>(
paddle::framework::FeatureValue* d_shard_vals,
paddle::framework::FeatureValue* d_vals, int* idx, long long len,
const cudaStream_t& stream);
template void HeterCommKernel::sort_pairs<
unsigned long, paddle::framework::FeaturePushValue, cudaStream_t>(
void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
const unsigned long* d_keys_in, // NOLINT
unsigned long* d_keys_out,
const paddle::framework::FeaturePushValue* d_values_in,
paddle::framework::FeaturePushValue* d_values_out, int num_items,
int begin_bit, int end_bit, cudaStream_t stream, bool debug_synchronous);
template void HeterCommKernel::sort_pairs<int, int, cudaStream_t>(
void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
const int* d_keys_in, // NOLINT
int* d_keys_out, const int* d_values_in, int* d_values_out, int num_items,
int begin_bit, int end_bit, cudaStream_t stream, bool debug_synchronous);
template void HeterCommKernel::reduce_by_key<
unsigned long*, unsigned long*, paddle::framework::FeaturePushValue*,
paddle::framework::FeaturePushValue*, int*, cudaStream_t>(
void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
unsigned long* d_keys_in, unsigned long* d_unique_out,
paddle::framework::FeaturePushValue* d_values_in,
paddle::framework::FeaturePushValue* d_aggregates_out, int* d_num_runs_out,
int num_items, cudaStream_t stream, bool debug_synchronous);
#endif
} // namespace framework
} // namespace paddle
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#if defined(PADDLE_WITH_CUDA)
#include "cub/cub.cuh"
#include "cub/util_allocator.cuh"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/enforce.h"
#endif
namespace paddle {
namespace framework {
class HeterCommKernel {
public:
HeterCommKernel() {}
explicit HeterCommKernel(const int block_size) : block_size_(block_size) {}
template <typename T, typename StreamType>
void fill_idx(T* idx, long long len, const StreamType& stream);
template <typename T, typename StreamType>
void calc_shard_offset(T* idx, T* left, T* right, long long len,
int total_devs, const StreamType& stream);
template <typename KeyType, typename T, typename StreamType>
void calc_shard_index(KeyType* d_keys, long long len, T* shard_index,
int total_devs, const StreamType& stream);
template <typename KeyType, typename T, typename StreamType>
void fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys, T* idx,
long long len, const StreamType& stream);
template <typename KeyType, typename GradType, typename T,
typename StreamType>
void fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys,
GradType* d_shard_grads, GradType* d_grads, T* idx,
long long len, const StreamType& stream);
template <typename ValType, typename T, typename StreamType>
void fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx, long long len,
const StreamType& stream);
template <typename KeyT, typename ValueT, typename StreamType>
void sort_pairs(void* d_temp_storage, size_t& temp_storage_bytes, // NOLINT
const KeyT* d_keys_in, KeyT* d_keys_out,
const ValueT* d_values_in, ValueT* d_values_out,
int num_items, int begin_bit = 0,
int end_bit = sizeof(KeyT) * 8, StreamType stream = NULL,
bool debug_synchronous = false);
template <typename KeysInputIteratorT, typename UniqueOutputIteratorT,
typename ValuesInputIteratorT, typename AggregatesOutputIteratorT,
typename NumRunsOutputIteratorT, typename StreamType>
void reduce_by_key(void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
KeysInputIteratorT d_keys_in,
UniqueOutputIteratorT d_unique_out,
ValuesInputIteratorT d_values_in,
AggregatesOutputIteratorT d_aggregates_out,
NumRunsOutputIteratorT d_num_runs_out, int num_items,
StreamType stream = NULL, bool debug_synchronous = false);
private:
int block_size_{256};
};
} // end namespace framework
} // end namespace paddle
#endif
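A minimal usage sketch of the wrapper declared above, assuming a CUDA build; the helper name and buffer are hypothetical. The point of the class is that callers only provide a stream and a length and never compute grid sizes themselves:

#include <cuda_runtime.h>
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"

void demo_fill_idx(cudaStream_t stream) {
  paddle::framework::HeterCommKernel kernel(256);  // block size; grid derived per call
  const long long len = 1 << 20;
  int* d_idx = NULL;
  cudaMalloc(&d_idx, len * sizeof(int));
  kernel.fill_idx(d_idx, len, stream);   // writes idx[i] = i on the device
  cudaStreamSynchronize(stream);
  cudaFree(d_idx);
}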
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"
#if defined(PADDLE_WITH_XPU_KP)
#include <xpu/runtime.h>
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/math.h"
#include "xpu/kernel/simd.h"
#endif
namespace paddle {
namespace framework {
#if defined(PADDLE_WITH_XPU_KP)
struct XPUCustomGradMerger {
template <typename T>
__device__ T operator()(const T& a, const T& b) const {
T out;
out.slot = a.slot;
out.show = a.show + b.show;
out.clk = a.clk + b.clk;
out.lr_g = a.lr_g + b.lr_g;
for (int i = 0; i < MF_DIM; ++i) {
out.mf_g[i] = a.mf_g[i] + b.mf_g[i];
}
return out;
}
} xpu_merger;
template <typename T>
__global__ void fill_idx_kernel(T* idx, long long len) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
const int buf_size = 1024;
__local__ T local_idx[buf_size];
int len_per_loop = min(buf_size, roundup_div(len, nthreads));
for (int i = thread_id * len_per_loop; i < len;
i += nthreads * len_per_loop) {
int read_len = min(len_per_loop, len - i);
for (int k = 0; k < read_len; k++) {
int real_idx = i + k;
local_idx[k] = real_idx;
}
LM2GM(local_idx, idx + i, read_len * sizeof(T));
}
}
template <typename T>
__global__ void calc_shard_offset_kernel(T* idx, T* left, T* right,
long long len, const int total_xpu) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
const int buf_size = 1024;
__local__ T local_idx[buf_size];
__local__ T local_left[total_xpu];
__local__ T local_right[total_xpu];
for (int i = 0; i < total_xpu; i++) {
local_left[i] = -1;
local_right[i] = -1;
}
int len_per_loop = min(buf_size, roundup_div(len, nthreads));
for (int i = thread_id * len_per_loop; i < len;
i += nthreads * len_per_loop) {
// read batch from GM will boost performance
int read_len = min(len_per_loop, len - i);
GM2LM(idx + i, local_idx, read_len * sizeof(T));
for (int k = 0; k < read_len; k++) {
if (local_idx[k] != local_idx[k + 1]) {
int real_idx = i + k;
local_right[local_idx[k]] = real_idx;
local_left[local_idx[k + 1]] = real_idx + 1;
}
}
if (i == 0) {
local_left[local_idx[i]] = i;
}
if (i + read_len == len) {
local_right[local_idx[len - 1]] = len - 1;
}
}
// to be optimized: call LM2GM too frequently
// all_reduce between threads to get global left & global right && LM2GM
for (int i = 0; i < total_xpu; i++) {
if (local_left[i] != -1) LM2GM(local_left + i, left + i, sizeof(T));
if (local_right[i] != -1) LM2GM(local_right + i, right + i, sizeof(T));
}
}
template <typename KeyType, typename T>
__global__ void calc_shard_index_kernel(KeyType* d_keys, long long len,
T* shard_index, int total_xpu) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
const int buf_size = 512;
__local__ KeyType local_keys[buf_size];
__local__ T local_shard_index[buf_size];
int len_per_loop = min(buf_size, roundup_div(len, nthreads));
for (int i = thread_id * len_per_loop; i < len;
i += nthreads * len_per_loop) {
// read batch from GM will boost performance
int read_len = min(len_per_loop, len - i);
GM2LM(d_keys + i, local_keys, read_len * sizeof(KeyType));
for (int k = 0; k < read_len; k++) {
local_shard_index[k] = local_keys[k] % total_xpu;
}
LM2GM(local_shard_index, shard_index + i, read_len * sizeof(T));
}
}
template <typename KeyType, typename T>
__global__ void fill_shard_key_kernel(KeyType* d_shard_keys, KeyType* d_keys,
T* idx, long long len) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
const int buf_size = 400;
__local__ KeyType local_keys[buf_size];
__local__ KeyType local_shard_keys[buf_size];
__local__ T local_idx[buf_size];
int len_per_loop = min(buf_size, roundup_div(len, nthreads));
for (int i = thread_id * len_per_loop; i < len;
i += nthreads * len_per_loop) {
// read batch from GM will boost performance
int read_len = min(len_per_loop, len - i);
GM2LM(d_keys + i, local_keys, read_len * sizeof(KeyType));
GM2LM(idx + i, local_idx, read_len * sizeof(T));
for (int k = 0; k < read_len; k++) {
local_shard_keys[k] = local_keys[local_idx[k]];
}
LM2GM(local_shard_keys, d_shard_keys + i, read_len * sizeof(KeyType));
}
}
// local mem too large, cause compile error
template <typename KeyType, typename GradType, typename T>
__global__ void fill_shard_grads_kernel(KeyType* d_shard_keys, KeyType* d_keys,
GradType* d_shard_grads,
GradType* d_grads, T* idx,
long long len) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
const int buf_size = 100;
__local__ KeyType local_keys[buf_size];
__local__ GradType local_grads[buf_size];
__local__ KeyType local_shard_keys[buf_size];
__local__ GradType local_shard_grads[buf_size];
__local__ T local_idx[buf_size];
int len_per_loop = min(buf_size, roundup_div(len, nthreads));
for (int i = thread_id * len_per_loop; i < len;
i += nthreads * len_per_loop) {
// read batch from GM will boost performance
int read_len = min(len_per_loop, len - i);
GM2LM(d_keys + i, local_keys, read_len * sizeof(KeyType));
GM2LM(d_grads + i, local_grads, read_len * sizeof(GradType));
GM2LM(idx + i, local_idx, read_len * sizeof(T));
for (int k = 0; k < read_len; k++) {
local_shard_keys[k] = local_keys[local_idx[k]];
local_shard_grads[k] = local_grads[local_idx[k]];
}
LM2GM(local_shard_keys, d_shard_keys + i, read_len * sizeof(KeyType));
LM2GM(local_shard_grads, d_shard_grads + i, read_len * sizeof(GradType));
}
}
template <typename ValType, typename T>
__global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals,
T* idx, long long len) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
const int buf_size = 50;
__local__ ValType local_vals[buf_size];
__local__ ValType local_shard_vals[buf_size];
__local__ T local_idx[buf_size];
int len_per_loop = min(buf_size, roundup_div(len, nthreads));
for (int i = thread_id * len_per_loop; i < len;
i += nthreads * len_per_loop) {
// read batch from GM will boost performance
int read_len = min(len_per_loop, len - i);
GM2LM(idx + i, local_idx, read_len * sizeof(T));
GM2LM(d_shard_vals + i, local_shard_vals, read_len * sizeof(ValType));
for (int k = 0; k < read_len; k++) {
local_vals[local_idx[k]] = local_shard_vals[k];
}
LM2GM(local_vals, d_vals + i, read_len * sizeof(ValType));
}
}
// xpu implementation of heter_comm_kernel.h
template <typename T, typename StreamType>
void HeterCommKernel::fill_idx(T* idx, long long len,
const StreamType& stream) {
fill_idx_kernel<T><<<4, 64, stream>>>(idx, len);
}
template <typename T, typename StreamType>
void HeterCommKernel::calc_shard_offset(T* idx, T* left, T* right,
long long len, int total_devs,
const StreamType& stream) {
calc_shard_offset_kernel<T><<<4, 64, stream>>>(idx, left, right, len,
total_devs);
}
template <typename KeyType, typename T, typename StreamType>
void HeterCommKernel::calc_shard_index(KeyType* d_keys, long long len,
T* shard_index, int total_devs,
const StreamType& stream) {
calc_shard_index_kernel<KeyType, T><<<4, 64, stream>>>(
d_keys, len, shard_index, total_devs);
}
template <typename KeyType, typename T, typename StreamType>
void HeterCommKernel::fill_shard_key(KeyType* d_shard_keys, KeyType* d_keys,
T* idx, long long len,
const StreamType& stream) {
fill_shard_key_kernel<KeyType, T><<<4, 64, stream>>>(d_shard_keys, d_keys,
idx, len);
}
template <typename KeyType, typename GradType, typename T, typename StreamType>
void HeterCommKernel::fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys,
GradType* d_shard_grads,
GradType* d_grads, T* idx, long long len,
const StreamType& stream) {
fill_shard_grads_kernel<KeyType, GradType, T><<<4, 64, stream>>>(
d_shard_keys, d_keys, d_shard_grads, d_grads, idx, len);
}
template <typename ValType, typename T, typename StreamType>
void HeterCommKernel::fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx,
long long len, const StreamType& stream) {
fill_dvals_kernel<ValType, T><<<4, 64, stream>>>(d_shard_vals, d_vals, idx,
len);
}
template <typename KeyT, typename ValueT, typename StreamType>
void HeterCommKernel::sort_pairs(void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
const KeyT* d_keys_in, // NOLINT
KeyT* d_keys_out, const ValueT* d_values_in,
ValueT* d_values_out, int num_items,
int begin_bit, int end_bit, StreamType stream,
bool debug_synchronous) {}
template <typename KeysInputIteratorT, typename UniqueOutputIteratorT,
          typename ValuesInputIteratorT, typename AggregatesOutputIteratorT,
          typename NumRunsOutputIteratorT, typename StreamType>
void HeterCommKernel::reduce_by_key(
void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
KeysInputIteratorT d_keys_in, UniqueOutputIteratorT d_unique_out,
ValuesInputIteratorT d_values_in,
AggregatesOutputIteratorT d_aggregates_out,
NumRunsOutputIteratorT d_num_runs_out, int num_items,
StreamType stream, bool debug_synchronous) {}
template void HeterCommKernel::fill_idx<int, XPUStream>(
int* idx, long long len, const XPUStream& stream);
template void HeterCommKernel::calc_shard_offset<int, XPUStream>(
int* idx, int* left, int* right, long long len, int total_devs,
const XPUStream& stream);
template void HeterCommKernel::calc_shard_index<unsigned long, int, XPUStream>(
unsigned long* d_keys, long long len, int* shard_index, int total_devs,
const XPUStream& stream);
template void HeterCommKernel::fill_shard_key<unsigned long, int, XPUStream>(
unsigned long* d_shard_keys, unsigned long* d_keys, int* idx, long long len,
const XPUStream& stream);
template void HeterCommKernel::fill_shard_grads<
unsigned long, paddle::framework::FeaturePushValue, int, XPUStream>(
unsigned long* d_shard_keys, unsigned long* d_keys,
paddle::framework::FeaturePushValue* d_shard_grads,
paddle::framework::FeaturePushValue* d_grads, int* idx, long long len,
const XPUStream& stream);
template void
HeterCommKernel::fill_dvals<paddle::framework::FeatureValue, int, XPUStream>(
paddle::framework::FeatureValue* d_shard_vals,
paddle::framework::FeatureValue* d_vals, int* idx, long long len,
const XPUStream& stream);
template void HeterCommKernel::sort_pairs<
unsigned long, paddle::framework::FeaturePushValue, XPUStream>(
void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
const unsigned long* d_keys_in, // NOLINT
unsigned long* d_keys_out,
const paddle::framework::FeaturePushValue* d_values_in,
paddle::framework::FeaturePushValue* d_values_out, int num_items,
int begin_bit, int end_bit, XPUStream stream, bool debug_synchronous);
template void HeterCommKernel::sort_pairs<int, int, XPUStream>(
void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
const int* d_keys_in, // NOLINT
int* d_keys_out, const int* d_values_in, int* d_values_out, int num_items,
int begin_bit, int end_bit, XPUStream stream, bool debug_synchronous);
template void HeterCommKernel::reduce_by_key<
unsigned long*, unsigned long*, paddle::framework::FeaturePushValue*,
paddle::framework::FeaturePushValue*, int*, XPUStream>(
void* d_temp_storage,
size_t& temp_storage_bytes, // NOLINT
unsigned long* d_keys_in, unsigned long* d_unique_out,
paddle::framework::FeaturePushValue* d_values_in,
paddle::framework::FeaturePushValue* d_aggregates_out,
    int* d_num_runs_out, int num_items, XPUStream stream,
bool debug_synchronous);
#endif
} // end namespace framework
} // end namespace paddle
#endif
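All of the XPU kernels in this file share one structure: each hardware thread claims a bounded tile, stages it from global memory (GM) into on-core local memory (LM) with GM2LM, processes it, and writes the tile back with LM2GM. A stripped-down sketch of that pattern, assuming the same KP intrinsics used above (core_id, cluster_id, GM2LM, LM2GM, roundup_div); the kernel is illustrative rather than one taken from this patch:

template <typename T>
__global__ void scale_kernel(T* data, long long len, T factor) {
  int cid = core_id();
  int ncores = core_num();
  if (cid >= ncores) {
    return;
  }
  int thread_id = ncores * cluster_id() + cid;
  int nthreads = ncores * cluster_num();
  const int buf_size = 512;
  __local__ T local_buf[buf_size];
  int len_per_loop = min(buf_size, roundup_div(len, nthreads));
  for (int i = thread_id * len_per_loop; i < len;
       i += nthreads * len_per_loop) {
    int read_len = min(len_per_loop, len - i);
    GM2LM(data + i, local_buf, read_len * sizeof(T));  // global -> local tile
    for (int k = 0; k < read_len; k++) {
      local_buf[k] = local_buf[k] * factor;
    }
    LM2GM(local_buf, data + i, read_len * sizeof(T));  // local tile -> global
  }
}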
......@@ -29,7 +29,9 @@ HeterPs::HeterPs(size_t capacity, std::shared_ptr<HeterPsResource> resource) {
comm_ =
std::make_shared<HeterComm<FeatureKey, FeatureValue, FeaturePushValue>>(
capacity, resource);
#if defined(PADDLE_WITH_CUDA)
opt_ = Optimizer<FeatureValue, FeaturePushValue>();
#endif
}
HeterPs::~HeterPs() {}
......@@ -54,15 +56,21 @@ void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); }
void HeterPs::push_sparse(int num, FeatureKey* d_keys,
FeaturePushValue* d_grads, size_t len) {
#if defined(PADDLE_WITH_CUDA)
comm_->push_sparse(num, d_keys, d_grads, len, opt_);
#elif defined(PADDLE_WITH_XPU_KP)
comm_->push_sparse(num, d_keys, d_grads, len);
#endif
// comm_->push_sparse_multi_node(num, d_keys, d_grads, len, opt_);
}
#if defined(PADDLE_WITH_CUDA)
void HeterPs::set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms,
int comm_size) {
comm_->set_nccl_comm_and_size(inner_comms, inter_comms, comm_size);
}
#endif
} // end namespace framework
} // end namespace paddle
......
......@@ -16,7 +16,9 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
#if defined(PADDLE_WITH_CUDA)
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#endif
#ifdef PADDLE_WITH_HETERPS
......@@ -35,9 +37,13 @@ class HeterPs : public HeterPsBase {
size_t len) override;
virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
size_t len, size_t chunk_size, int stream_num) override;
#if defined(PADDLE_WITH_CUDA)
virtual void set_nccl_comm_and_size(
const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms, int comm_size) override;
#endif
virtual void end_pass() override;
virtual int get_index_by_devid(int devid) override;
virtual void show_one_table(int gpu_num) override;
......@@ -46,7 +52,9 @@ class HeterPs : public HeterPsBase {
private:
std::shared_ptr<HeterComm<FeatureKey, FeatureValue, FeaturePushValue>> comm_;
#if defined(PADDLE_WITH_CUDA)
Optimizer<FeatureValue, FeaturePushValue> opt_;
#endif
};
} // end namespace framework
......
......@@ -35,9 +35,11 @@ class HeterPsBase {
virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
size_t len, size_t chunk_size, int stream_num) = 0;
virtual int get_index_by_devid(int devid) = 0;
#if defined(PADDLE_WITH_CUDA)
virtual void set_nccl_comm_and_size(
const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms, int comm_size) = 0;
#endif
virtual void end_pass() = 0;
virtual void show_one_table(int gpu_num) = 0;
virtual void push_sparse(int num, FeatureKey* d_keys,
......
......@@ -13,12 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_HETERPS
#include "heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#ifdef PADDLE_WITH_XPU_KP
#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#endif
namespace paddle {
namespace framework {
#if defined(PADDLE_WITH_CUDA)
GPUResource::GPUResource(std::vector<int>& dev_ids, int index) {
index_ = index;
dev_ids_ = dev_ids;
......@@ -52,7 +61,41 @@ GPUResource::~GPUResource() {
}
}
#elif defined(PADDLE_WITH_XPU_KP)
XPUResource::XPUResource(std::vector<int>& dev_ids, int index) {
index_ = index;
dev_ids_ = dev_ids;
dev_id_ = dev_ids_[index];
platform::XPUDeviceGuard guard(dev_id_);
local_streams_.resize(dev_ids_.size());
comm_streams_.resize(dev_ids_.size(), NULL);
remote_streams_.resize(dev_ids_.size());
for (size_t i = 0; i < dev_ids_.size(); ++i) {
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&local_streams_[i]));
// PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&comm_streams_[i]));
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&remote_streams_[i]));
}
}
XPUResource::~XPUResource() {
platform::XPUDeviceGuard guard(dev_id_);
for (size_t i = 0; i < local_streams_.size(); ++i) {
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(local_streams_[i]));
}
// for (size_t i = 0; i < comm_streams_.size(); ++i) {
// PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(comm_streams_[i]));
// }
for (size_t i = 0; i < remote_streams_.size(); ++i) {
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(remote_streams_[i]));
}
}
#endif
void HeterPsResource::enable_p2p() {
#if defined(PADDLE_WITH_CUDA)
for (size_t i = 0; i < dev_ids_.size(); ++i) {
platform::CUDADeviceGuard guard(dev_ids_[i]);
for (size_t j = 0; j < dev_ids_.size(); ++j) {
......@@ -72,28 +115,28 @@ void HeterPsResource::enable_p2p() {
}
}
}
#endif
}
HeterPsResource::HeterPsResource(const std::vector<int>& dev_ids) {
dev_ids_ = dev_ids;
for (size_t i = 0; i < dev_ids_.size(); ++i) {
std::shared_ptr<GPUResource> resource =
std::make_shared<GPUResource>(dev_ids_, i);
std::shared_ptr<DevResource> resource =
std::make_shared<DevResource>(dev_ids_, i);
resources_.push_back(resource);
devid_2_index_[dev_ids_[i]] = i;
}
}
cudaStream_t HeterPsResource::comm_stream(int gpu_num, int stream_num) {
return resources_[gpu_num]->comm_stream(stream_num);
ppStream HeterPsResource::comm_stream(int dev_num, int stream_num) {
return resources_[dev_num]->comm_stream(stream_num);
}
cudaStream_t HeterPsResource::local_stream(int gpu_num, int stream_num) {
return resources_[gpu_num]->local_stream(stream_num);
ppStream HeterPsResource::local_stream(int dev_num, int stream_num) {
return resources_[dev_num]->local_stream(stream_num);
}
cudaStream_t HeterPsResource::remote_stream(int gpu_num, int stream_num) {
return resources_[gpu_num]->remote_stream(stream_num);
ppStream HeterPsResource::remote_stream(int dev_num, int stream_num) {
return resources_[dev_num]->remote_stream(stream_num);
}
int HeterPsResource::dev_id(int num) { return dev_ids_[num]; }
......@@ -102,7 +145,7 @@ int HeterPsResource::get_index_by_devid(int devid) {
return devid_2_index_[devid];
}
int HeterPsResource::total_gpu() { return dev_ids_.size(); }
int HeterPsResource::total_device() { return dev_ids_.size(); }
void HeterPsResource::set_multi_mf(int multi_mf_dim, int max_mf_dim) {
multi_mf_dim_ = multi_mf_dim;
......
......@@ -17,7 +17,16 @@ limitations under the License. */
#include <map>
#include <memory>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#ifdef PADDLE_WITH_XPU_KP
#include <xpu/runtime.h> // NOLINT
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#endif
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_HETERPS
......@@ -25,9 +34,16 @@ limitations under the License. */
namespace paddle {
namespace framework {
#if defined(PADDLE_WITH_CUDA)
using ppStream = cudaStream_t;
#elif defined(PADDLE_WITH_XPU_KP)
using ppStream = XPUStream;
#endif
#if defined(PADDLE_WITH_CUDA)
class GPUResource {
public:
GPUResource(std::vector<int>& device_id, int index);
GPUResource(std::vector<int>& device_id, int index); // NOLINT
virtual ~GPUResource();
GPUResource(const GPUResource&) = delete;
GPUResource& operator=(const GPUResource&) = delete;
......@@ -45,23 +61,55 @@ class GPUResource {
std::vector<gpuStream_t> local_streams_;
std::vector<gpuStream_t> comm_streams_;
};
#elif defined(PADDLE_WITH_XPU_KP)
class XPUResource {
public:
XPUResource(std::vector<int>& device_id, int index); // NOLINT
virtual ~XPUResource();
XPUResource(const XPUResource&) = delete;
XPUResource& operator=(const XPUResource&) = delete;
int dev_id() const { return dev_id_; }
int index() const { return index_; }
XPUStream local_stream(int num) { return local_streams_[num]; }
XPUStream remote_stream(int num) { return remote_streams_[num]; }
XPUStream comm_stream(int num) { return comm_streams_[num]; }
int dev_id_;
int index_;
std::vector<int> dev_ids_;
std::vector<XPUStream> remote_streams_;
std::vector<XPUStream> local_streams_;
std::vector<XPUStream> comm_streams_;
};
#endif
#if defined(PADDLE_WITH_CUDA)
using DevResource = GPUResource;
using DevPlace = platform::CUDAPlace;
using AnyDeviceGuard = platform::CUDADeviceGuard;
#elif defined(PADDLE_WITH_XPU_KP)
using DevResource = XPUResource;
using DevPlace = platform::XPUPlace;
using AnyDeviceGuard = platform::XPUDeviceGuard;
#endif
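These aliases are what keep heter_comm_inl.h backend-agnostic: DevPlace and AnyDeviceGuard resolve to the CUDA or the XPU types at compile time. A small hypothetical helper showing the intended usage pattern inside this namespace:

// Hypothetical helper: the identical body compiles for the CUDA and the XPU KP build.
inline void activate_device(int dev_id) {
  DevPlace place = DevPlace(dev_id);  // CUDAPlace under CUDA, XPUPlace under XPU KP
  AnyDeviceGuard guard(dev_id);       // scoped device switch for either backend
  (void)place;                        // e.g. pass to memory::Alloc(place, bytes)
}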
class HeterPsResource {
public:
HeterPsResource(const std::vector<int>& dev_ids);
explicit HeterPsResource(const std::vector<int>& dev_ids);
HeterPsResource(const HeterPsResource&) = delete;
HeterPsResource& operator=(const HeterPsResource&) = delete;
virtual ~HeterPsResource() {}
void enable_p2p();
int total_gpu();
int total_device();
int get_index_by_devid(int devid);
int dev_id(int num);
void set_multi_mf(int multi_mf_dim, int max_mf_dim);
gpuStream_t local_stream(int gpu_num, int stream_num);
gpuStream_t remote_stream(int gpu_num, int stream_num);
gpuStream_t comm_stream(int gpu_num, int stream_num);
ppStream local_stream(int dev_num, int stream_num);
ppStream remote_stream(int dev_num, int stream_num);
ppStream comm_stream(int dev_num, int stream_num);
std::vector<std::shared_ptr<GPUResource>> resources_;
std::vector<std::shared_ptr<DevResource>> resources_;
std::vector<int> dev_ids_;
std::map<int, int> devid_2_index_;
int multi_mf_dim_{0};
......
......@@ -18,6 +18,7 @@ limitations under the License. */
// #include
// "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
#include <iostream>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/fleet/heter_ps/cudf/managed.cuh"
namespace paddle {
......@@ -111,3 +112,4 @@ class HBMMemoryPool : public managed {
} // end namespace framework
} // end namespace paddle
#endif
#endif
......@@ -13,16 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_HETERPS
#if defined(PADDLE_WITH_CUDA)
#include <curand_kernel.h>
#endif
#include <vector>
#include "optimizer_conf.h"
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
namespace paddle {
namespace framework {
#if defined(PADDLE_WITH_CUDA)
template <typename ValType, typename GradType>
class Optimizer {
public:
......@@ -32,7 +35,8 @@ class Optimizer {
void initialize() {}
__device__ void update_lr(float& w, float& g2sum, float g, float scale) {
__device__ void update_lr(float& w, float& g2sum, float g, // NOLINT
float scale) {
double add_g2sum = 0;
double ratio = optimizer_config::learning_rate *
sqrt(optimizer_config::initial_g2sum /
......@@ -49,8 +53,8 @@ class Optimizer {
g2sum += add_g2sum;
}
__device__ void update_mf(int n, float* w, float& g2sum, const float* g,
float scale) {
__device__ void update_mf(int n, float* w, float& g2sum, // NOLINT
const float* g, float scale) {
double add_g2sum = 0;
double ratio = optimizer_config::mf_learning_rate *
sqrt(optimizer_config::mf_initial_g2sum /
......@@ -69,7 +73,8 @@ class Optimizer {
g2sum += add_g2sum / n;
}
__device__ void update_value(ValType& val, const GradType& grad) {
__device__ void update_value(ValType& val, const GradType& grad) { // NOLINT
val.slot = grad.slot;
val.show += grad.show;
val.clk += grad.clk;
......@@ -132,6 +137,7 @@ class Optimizer {
}
};
#endif
} // end namespace framework
} // end namespace paddle
#endif
......@@ -14,8 +14,16 @@ limitations under the License. */
#pragma once
#if defined(PADDLE_WITH_XPU_KP)
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h"
#endif
namespace optimizer_config {
#if defined(PADDLE_WITH_CUDA)
__constant__ float nonclk_coeff = 0.1;
__constant__ float clk_coeff = 1;
......@@ -31,4 +39,24 @@ __constant__ float mf_initial_g2sum = 3.0;
__constant__ float mf_initial_range = 1e-4;
__constant__ float mf_min_bound = -10;
__constant__ float mf_max_bound = 10;
}
#elif defined(PADDLE_WITH_XPU_KP)
_global_ptr_ float* nonclk_coeff;
_global_ptr_ float* clk_coeff;
_global_ptr_ float* min_bound;
_global_ptr_ float* max_bound;
_global_ptr_ float* learning_rate;
_global_ptr_ float* initial_g2sum;
_global_ptr_ float* initial_range;
_global_ptr_ float* mf_create_thresholds;
_global_ptr_ float* mf_learning_rate;
_global_ptr_ float* mf_initial_g2sum;
_global_ptr_ float* mf_initial_range;
_global_ptr_ float* mf_min_bound;
_global_ptr_ float* mf_max_bound;
#endif
} // namespace optimizer_config
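In the CUDA branch the hyper-parameters above live in __constant__ device memory, so the host updates them through cudaMemcpyToSymbol; the XPU KP branch instead exposes _global_ptr_ pointers whose values are set elsewhere (not shown in this diff). A minimal CUDA-only sketch of the constant-memory update, using a stand-in symbol rather than one of the names above:

#include <cuda_runtime.h>

__constant__ float demo_coeff = 0.1f;  // plays the role of e.g. nonclk_coeff

void set_demo_coeff(float value) {
  // __constant__ memory is written from the host via cudaMemcpyToSymbol.
  cudaMemcpyToSymbol(demo_coeff, &value, sizeof(float));
}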
......@@ -121,7 +121,7 @@ class PSGPUWrapper {
is_initialized_ = true;
resource_ = std::make_shared<HeterPsResource>(dev_ids);
resource_->enable_p2p();
keys_tensor.resize(resource_->total_gpu());
keys_tensor.resize(resource_->total_device());
#ifdef PADDLE_WITH_GLOO
auto gloo = paddle::framework::GlooWrapper::GetInstance();
if (gloo->Size() > 1) {
......@@ -287,8 +287,8 @@ class PSGPUWrapper {
for (size_t i = 0; i < num_of_dim; i++) {
dim_index_map[index_dim_vec_[i]] = i;
}
hbm_pools_.resize(resource_->total_gpu() * num_of_dim);
mem_pools_.resize(resource_->total_gpu() * num_of_dim);
hbm_pools_.resize(resource_->total_device() * num_of_dim);
mem_pools_.resize(resource_->total_device() * num_of_dim);
max_mf_dim_ = index_dim_vec_.back();
multi_mf_dim_ = (dim_index_map.size() >= 1) ? dim_index_map.size() : 0;
resource_->set_multi_mf(multi_mf_dim_, max_mf_dim_);
......