Unverified commit ff4612a3, authored by Thunderbrook, committed by GitHub

[Cherry pick] cherry-pick #31102 #30750 #30626 (#31336)

* solve build gpu task core (#30626)

* build gpu task core

* format

* dump to cpu (#30750)

* dump to cpu

* format

* format

* format

* support multi node in heterps (#31102)

* push multi node

* multi node

* MultiThread

* remove log

* solve bug in 30829

* optimizer
Parent a891032f
@@ -27,14 +27,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include <algorithm>
#include <utility>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/io/fs.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace framework {
......
@@ -29,10 +29,20 @@ namespace framework {
class HeterContext {
public:
~HeterContext() {
for (size_t i = 0; i < mutex_.size(); ++i) {
delete mutex_[i];
}
mutex_.clear();
}
Scope* scope_{nullptr};
std::vector<std::vector<FeatureKey>> feature_keys_;
std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>> value_ptr_;
std::vector<std::vector<FeatureValue>> feature_values_;
std::vector<std::vector<FeatureValue>> device_values_;
std::vector<std::vector<FeatureKey>> device_keys_;
std::vector<std::mutex*> mutex_;
uint32_t shard_num_ = 37;
uint64_t size() {
uint64_t total_size = 0;
for (auto& keys : feature_keys_) {
@@ -40,6 +50,48 @@ class HeterContext {
}
return total_size;
}
void SetShardNum(uint32_t shard_num) { shard_num_ = shard_num; }
uint32_t ShardNum() { return shard_num_; }
void init(int shard_num, int device_num) {
shard_num_ = shard_num;
feature_keys_.resize(shard_num_);
value_ptr_.resize(shard_num_);
device_values_.resize(device_num);
device_keys_.resize(device_num);
mutex_.resize(device_num);
for (size_t i = 0; i < mutex_.size(); ++i) {
mutex_[i] = new std::mutex();
}
}
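// Append one producer thread's per-shard keys onto the shard-wide buffers;
// duplicates across threads are removed later by UniqueKeys().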
void batch_add_keys(const std::vector<std::vector<uint64_t>>& thread_keys) {
assert(thread_keys.size() == feature_keys_.size());
for (uint32_t i = 0; i < shard_num_; i++) {
int idx = 0;
idx = feature_keys_[i].size();
feature_keys_[i].resize(feature_keys_[i].size() + thread_keys[i].size());
for (uint64_t j = 0; j < thread_keys[i].size(); j++) {
feature_keys_[i][idx + j] = thread_keys[i][j];
}
}
}
void UniqueKeys() {
std::vector<std::thread> threads;
auto unique_func = [this](int i) {
auto& cur_keys = feature_keys_[i];
std::sort(cur_keys.begin(), cur_keys.end());
std::vector<FeatureKey>::iterator it;
it = std::unique(cur_keys.begin(), cur_keys.end());
cur_keys.resize(std::distance(cur_keys.begin(), it));
};
for (uint32_t i = 0; i < shard_num_; i++) {
threads.push_back(std::thread(unique_func, i));
}
for (std::thread& t : threads) {
t.join();
}
}
};
} // end namespace framework
......
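For orientation before the BuildTask changes further below, a minimal sketch of the intended HeterContext lifecycle (the thread_keys buffer here is hypothetical; ps_gpu_wrapper.cc is the real driver):

HeterContext ctx;
ctx.init(/*shard_num=*/37, /*device_num=*/8);  // also news one mutex per device
for (auto& one_thread_keys : thread_keys) {    // one entry per producer thread
  ctx.batch_add_keys(one_thread_keys);         // append keys shard by shard
}
ctx.UniqueKeys();                              // sort + dedup, one thread per shard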
@@ -33,6 +33,7 @@ struct FeatureValue {
float lr_g2sum;
int mf_size;
float mf[MF_DIM + 1];
uint64_t cpu_ptr;
friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) {
out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot
......
@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <limits>
#include <memory>
#include <vector>
#include "common_value.h" // NOLINT
#include "thrust/pair.h"
//#include "cudf/concurrent_unordered_map.cuh.h"
#include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
@@ -47,6 +49,7 @@ class HashTable {
void get(const KeyType* d_keys, ValType* d_vals, size_t len,
cudaStream_t stream);
void show();
void dump_to_cpu(int devid, cudaStream_t stream);
template <typename GradType, typename Sgd>
void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
@@ -60,5 +63,5 @@ class HashTable {
};
} // end namespace framework
} // end namespace paddle
#include "hashtable.tpp"
#include "hashtable_inl.h"
#endif
@@ -108,6 +108,41 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
d_vals, len);
}
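// Write the GPU table back into the CPU-side pslib values. The container
// lives in CUDA managed memory: prefetch it to the host, update each value
// in place through its stored cpu_ptr, then prefetch the table back.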
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, cudaStream_t stream) {
container_->prefetch(cudaCpuDeviceId, stream);
size_t num = container_->size();
KeyType unuse_key = std::numeric_limits<KeyType>::max();
thrust::pair<KeyType, ValType>* kv = container_->data();
for (size_t i = 0; i < num; ++i) {
if (kv[i].first == unuse_key) {
continue;
}
ValType& gpu_val = kv[i].second;
auto* downpour_value =
(paddle::ps::DownpourFixedFeatureValue*)(gpu_val.cpu_ptr);
int downpour_value_size = downpour_value->size();
if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
downpour_value->resize(gpu_val.mf_size + downpour_value_size);
}
float* cpu_val = downpour_value->data();
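// CPU value layout written below: [1]=delta_score, [2]=show, [3]=clk,
// [4]=lr, [5]=lr_g2sum, [6]=slot, [7..]=mf factors; index 0 is cleared.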
cpu_val[0] = 0;
cpu_val[1] = gpu_val.delta_score;
cpu_val[2] = gpu_val.show;
cpu_val[3] = gpu_val.clk;
cpu_val[4] = gpu_val.lr;
cpu_val[5] = gpu_val.lr_g2sum;
cpu_val[6] = gpu_val.slot;
if (gpu_val.mf_size > 0) {
for (int x = 0; x < gpu_val.mf_size; x++) {
cpu_val[x + 7] = gpu_val.mf[x];
}
}
}
container_->prefetch(devid, stream);
}
template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
......
@@ -13,13 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <thread>
#include <vector>
#include "cub/cub.cuh"
#include "hashtable.h"
#include "heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/place.h"
#include "thrust/pair.h"
@@ -67,11 +69,38 @@ class HeterComm {
void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len,
Sgd& sgd);
template <typename Sgd>
void push_sparse_multi_node(int num, KeyType* d_keys, GradType* d_grads,
size_t len, Sgd& sgd);
template <typename Sgd>
void update_one_table(int num, KeyType* d_keys, GradType* d_grads, size_t len,
Sgd& sgd);
int gather_one_node_grad(int num, KeyType* d_keys, GradType* d_grads,
int len);
int gather_multi_node_grad(int num, KeyType* d_keys, GradType* d_grads,
int len);
int log2i(int x);
void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms,
int comm_size) {
nccl_inner_comms_ = inner_comms;
nccl_inter_comms_ = inter_comms;
node_size_ = comm_size;
}
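// Topology helper: assumes 8 GPUs in two groups of 4. A transfer that
// crosses the group boundary is staged through the partner device
// (send_id + 4) % 8; see get_transfer_devid below.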
bool need_transfer(int send_id, int receive_id) {
return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id);
}
// void dump_to_cpu(int index);
void end_pass();
int get_transfer_devid(int send_id) { return (send_id + 4) % 8; }
struct Node {
@@ -89,6 +118,44 @@ class HeterComm {
std::vector<Node> nodes_;
};
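// Per-GPU scratch space for multi-node push: all_* receives NCCL allgather
// output, local_* holds this GPU's merged shard; alloc() grows both lazily.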
struct LocalStorage {
LocalStorage() {}
void init(int size, int dev_id) {
place_ = platform::CUDAPlace(dev_id);
alloc(size, true);
}
void alloc(int size, bool force = false) {
if (force || size > all_keys_mem->size()) {
all_keys_mem.reset();
all_grads_mem.reset();
all_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType));
all_grads_mem = memory::AllocShared(place_, size * sizeof(GradType));
all_keys = reinterpret_cast<KeyType*>(all_keys_mem->ptr());
all_grads = reinterpret_cast<GradType*>(all_grads_mem->ptr());
}
if (force || size > local_keys_mem->size()) {
local_keys_mem.reset();
local_grads_mem.reset();
local_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType));
local_grads_mem = memory::AllocShared(place_, size * sizeof(GradType));
local_keys = reinterpret_cast<KeyType*>(local_keys_mem->ptr());
local_grads = reinterpret_cast<GradType*>(local_grads_mem->ptr());
}
}
platform::CUDAPlace place_;
std::shared_ptr<memory::Allocation> all_keys_mem;
std::shared_ptr<memory::Allocation> all_grads_mem;
KeyType* all_keys;
GradType* all_grads;
std::shared_ptr<memory::Allocation> local_keys_mem;
std::shared_ptr<memory::Allocation> local_grads_mem;
KeyType* local_keys;
GradType* local_grads;
};
void init_path();
void create_storage(
int start_index, int end_index, int keylen, int vallen,
@@ -106,9 +173,15 @@ class HeterComm {
CustomGradMerger merger_;
int topo_aware_{1};
std::vector<std::vector<Path>> path_;
std::vector<LocalStorage> storage_;
int feanum_{1800 * 2048};
int multi_node_{1};
std::vector<ncclComm_t> nccl_inner_comms_;
std::vector<ncclComm_t> nccl_inter_comms_;
int node_size_;
};
} // end namespace framework
} // end namespace paddle
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h"
#endif
@@ -95,10 +95,14 @@ template <typename KeyType, typename ValType, typename GradType>
HeterComm<KeyType, ValType, GradType>::HeterComm(
size_t capacity, std::shared_ptr<HeterPsResource> resource) {
resource_ = resource;
storage_.resize(resource_->total_gpu());
for (int i = 0; i < resource_->total_gpu(); ++i) {
platform::CUDADeviceGuard guard(resource_->dev_id(i));
auto table = new Table(capacity / load_factor_);
tables_.push_back(table);
if (multi_node_) {
storage_[i].init(feanum_, resource_->dev_id(i));
}
}
init_path();
}
@@ -595,6 +599,214 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
}
}
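// Apply already-merged gradients to gpu_num's local table on its remote
// stream, blocking until the update completes.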
template <typename KeyType, typename ValType, typename GradType>
template <typename Sgd>
void HeterComm<KeyType, ValType, GradType>::update_one_table(
int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd) {
if (len == 0) {
return;
}
int dev_id = resource_->dev_id(gpu_num);
platform::CUDADeviceGuard guard(dev_id);
tables_[gpu_num]->update(d_keys, d_grads, len, sgd,
resource_->remote_stream(gpu_num));
cudaStreamSynchronize(resource_->remote_stream(gpu_num));
}
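// Multi-node push pipeline: dedup local grads, reduce within the node,
// reduce across nodes, then update the local shard.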
template <typename KeyType, typename ValType, typename GradType>
template <typename Sgd>
void HeterComm<KeyType, ValType, GradType>::push_sparse_multi_node(
int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd) {
if (len == 0) {
return;
}
int uniq_len = len;
merge_grad(gpu_num, d_keys, d_grads, len, uniq_len);
uniq_len = gather_one_node_grad(gpu_num, d_keys, d_grads, uniq_len);
uniq_len = gather_multi_node_grad(gpu_num, storage_[gpu_num].local_keys,
storage_[gpu_num].local_grads, uniq_len);
update_one_table(gpu_num, storage_[gpu_num].local_keys,
storage_[gpu_num].local_grads, uniq_len, sgd);
}
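// Stage 1: allgather the grad lengths and buffers across this node's GPUs
// (nccl_inner_comms_), keep only the keys owned by gpu_num's shard, and
// merge duplicates. Returns the merged length.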
template <typename KeyType, typename ValType, typename GradType>
int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
int gpu_num, KeyType* d_keys, GradType* d_grads, int len) {
int total_gpu = resource_->total_gpu();
int dev_id = resource_->dev_id(gpu_num);
auto& storage = storage_[gpu_num];
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_num, 0);
int max_size = 0;
ncclComm_t nccl_inner_comm = nccl_inner_comms_[gpu_num];
// alloc for size
int h_node_len[total_gpu];
auto d_node_len_mem = memory::AllocShared(place, total_gpu * sizeof(int));
int* d_node_len = reinterpret_cast<int*>(d_node_len_mem->ptr());
h_node_len[gpu_num] = len;
cudaMemcpy(d_node_len + gpu_num, h_node_len + gpu_num, sizeof(int),
cudaMemcpyHostToDevice);
// allgather grad len
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
(const void*)(d_node_len + gpu_num), (void*)d_node_len, 1, ncclInt,
nccl_inner_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
cudaMemcpy(h_node_len, d_node_len, sizeof(int) * total_gpu,
cudaMemcpyDeviceToHost);
for (int i = 0; i < total_gpu; ++i) {
if (h_node_len[i] > max_size) {
max_size = h_node_len[i];
}
}
storage.alloc(max_size * total_gpu);
// allgather keys and grads
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_keys, storage.all_keys, max_size, ncclUint64, nccl_inner_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_grads, storage.all_grads, max_size * sizeof(GradType), ncclUint8,
nccl_inner_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
int h_left[total_gpu];
int h_right[total_gpu];
auto d_left = memory::AllocShared(place, total_gpu * sizeof(int));
auto d_right = memory::AllocShared(place, total_gpu * sizeof(int));
int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
int merge_num = 0;
for (int i = 0; i < total_gpu; ++i) {
int index = i * max_size;
auto d_idx = memory::AllocShared(place, h_node_len[i] * sizeof(int));
int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
cudaMemset(d_left_ptr, -1, total_gpu * sizeof(int));
cudaMemset(d_right_ptr, -1, total_gpu * sizeof(int));
split_input_to_shard(storage.all_keys + index, d_idx_ptr, h_node_len[i],
d_left_ptr, d_right_ptr, gpu_num);
cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
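// Compact the keys (and grads) that fall into gpu_num's shard range
// [h_left, h_right] into local_keys/local_grads at offset merge_num.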
int grid_size = (h_node_len[i] - 1) / block_size_ + 1;
fill_shard_grads<<<grid_size, block_size_, 0, stream>>>(
storage.local_keys + merge_num, storage.all_keys + index,
storage.local_grads + merge_num, storage.all_grads + index,
d_idx_ptr + h_left[gpu_num], h_right[gpu_num] - h_left[gpu_num] + 1);
merge_num = merge_num + h_right[gpu_num] - h_left[gpu_num] + 1;
}
int ret = merge_num;
merge_grad(gpu_num, storage.local_keys, storage.local_grads, merge_num, ret);
return ret;
}
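// Stage 2: allgather this shard's grads across nodes (nccl_inter_comms_),
// concatenate every rank's contribution, and merge duplicates once more.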
template <typename KeyType, typename ValType, typename GradType>
int HeterComm<KeyType, ValType, GradType>::gather_multi_node_grad(
int gpu_num, KeyType* d_keys, GradType* d_grads, int len) {
int dev_id = resource_->dev_id(gpu_num);
auto& storage = storage_[gpu_num];
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_num, 0);
int max_size = 0;
ncclComm_t nccl_inter_comm = nccl_inter_comms_[gpu_num];
// alloc for size
int h_node_len[node_size_];
auto d_node_len_mem = memory::AllocShared(place, node_size_ * sizeof(int));
int* d_node_len = reinterpret_cast<int*>(d_node_len_mem->ptr());
h_node_len[0] = len;
cudaMemcpy(d_node_len, h_node_len, sizeof(int), cudaMemcpyHostToDevice);
// allgather grad len
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_node_len, d_node_len, 1, ncclInt, nccl_inter_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
cudaMemcpy(h_node_len, d_node_len, sizeof(int) * node_size_,
cudaMemcpyDeviceToHost);
for (int i = 0; i < node_size_; ++i) {
if (h_node_len[i] > max_size) {
max_size = h_node_len[i];
}
}
storage.alloc(max_size * node_size_);
// allgather keys and grads
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_keys, storage.all_keys, max_size, ncclUint64, nccl_inter_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
d_grads, storage.all_grads, max_size * sizeof(GradType), ncclUint8,
nccl_inter_comm, stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
int merge_num = 0;
for (int i = 0; i < node_size_; ++i) {
int index = i * max_size;
cudaMemcpyAsync(storage.local_keys + merge_num, storage.all_keys + index,
h_node_len[i] * sizeof(KeyType), cudaMemcpyDefault, stream);
cudaMemcpyAsync(storage.local_grads + merge_num, storage.all_grads + index,
h_node_len[i] * sizeof(GradType), cudaMemcpyDefault, stream);
merge_num += h_node_len[i];
}
int ret = merge_num;
merge_grad(gpu_num, storage.local_keys, storage.local_grads, merge_num, ret);
return ret;
}
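// End of pass: one thread per GPU flushes its table shard back to the
// CPU-side values via dump_to_cpu.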
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::end_pass() {
int total_gpu = resource_->total_gpu();
std::vector<std::thread> threads;
auto dump_to_cpu_func = [this](int index) {
auto stream = resource_->local_stream(index, 0);
int dev_id = resource_->dev_id(index);
platform::CUDADeviceGuard guard(dev_id);
tables_[index]->dump_to_cpu(dev_id, stream);
};
for (int i = 0; i < total_gpu; ++i) {
threads.push_back(std::thread(dump_to_cpu_func, i));
}
for (auto& t : threads) {
t.join();
}
}
// template <typename KeyType, typename ValType, typename GradType>
// void HeterComm<KeyType, ValType, GradType>::dump_to_cpu(int index) {
// auto stream = resource_->local_stream(index, 0);
// int dev_id = resource_->dev_id(index);
// platform::CUDADeviceGuard guard(dev_id);
// tables_[index]->dump_to_cpu(dev_id, stream);
//}
} // end namespace framework
} // end namespace paddle
#endif
@@ -48,13 +48,20 @@ int HeterPs::get_index_by_devid(int devid) {
return comm_->get_index_by_devid(devid);
}
void HeterPs::dump() {}
void HeterPs::end_pass() { comm_->end_pass(); }
void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); }
void HeterPs::push_sparse(int num, FeatureKey* d_keys,
FeaturePushValue* d_grads, size_t len) {
comm_->push_sparse(num, d_keys, d_grads, len, opt_);
// comm_->push_sparse(num, d_keys, d_grads, len, opt_);
comm_->push_sparse_multi_node(num, d_keys, d_grads, len, opt_);
}
void HeterPs::set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms,
int comm_size) {
comm_->set_nccl_comm_and_size(inner_comms, inter_comms, comm_size);
}
} // end namespace framework
......
@@ -16,7 +16,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#ifdef PADDLE_WITH_PSLIB
@@ -35,7 +35,10 @@ class HeterPs : public HeterPsBase {
size_t len) override;
virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
size_t len, size_t chunk_size, int stream_num) override;
virtual void dump() override;
virtual void set_nccl_comm_and_size(
const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms, int comm_size) override;
virtual void end_pass() override;
virtual int get_index_by_devid(int devid) override;
virtual void show_one_table(int gpu_num) override;
virtual void push_sparse(int num, FeatureKey* d_keys,
......
@@ -35,7 +35,10 @@ class HeterPsBase {
virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
size_t len, size_t chunk_size, int stream_num) = 0;
virtual int get_index_by_devid(int devid) = 0;
virtual void dump() = 0;
virtual void set_nccl_comm_and_size(
const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms, int comm_size) = 0;
virtual void end_pass() = 0;
virtual void show_one_table(int gpu_num) = 0;
virtual void push_sparse(int num, FeatureKey* d_keys,
FeaturePushValue* d_grads, size_t len) = 0;
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <curand_kernel.h>
#include <vector>
#include "optimizer_conf.h"
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
@@ -106,9 +107,12 @@ class Optimizer {
optimizer_config::clk_coeff * val.clk) {
val.mf_size = MF_DIM + 1;
val.mf[0] = 0;
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
curandState state;
curand_init(clock64(), tid_x, 0, &state);
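// Seed a private curand state per key from clock64() and draw uniform
// noise scaled by mf_initial_range for the MF weights.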
for (int i = 0; i < MF_DIM; ++i) {
val.mf[i + 1] = (cuda_normal_random((int)grad.show) * 2 - 1) *
optimizer_config::mf_initial_range;
val.mf[i + 1] =
(curand_uniform(&state)) * optimizer_config::mf_initial_range;
}
}
} else {
......
@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
using namespace paddle::framework;
......
@@ -37,6 +37,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/timer.h"
#ifdef PADDLE_WITH_PSLIB
#include "paddle/fluid/framework/device_worker.h"
namespace paddle {
namespace framework {
......
@@ -43,35 +43,214 @@ namespace framework {
std::shared_ptr<PSGPUWrapper> PSGPUWrapper::s_instance_ = NULL;
bool PSGPUWrapper::is_initialized_ = false;
void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim,
std::shared_ptr<HeterContext> gpu_task) {
void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
uint64_t table_id, int feature_dim) {
VLOG(3) << "PSGPUWrapper::BuildGPUPSTask begin";
platform::Timer timeline;
timeline.Start();
int shard_num = gpu_task->feature_keys_.size();
if (shard_num == 0) {
return;
int device_num = heter_devices_.size();
MultiSlotDataset* dataset = dynamic_cast<MultiSlotDataset*>(dataset_);
gpu_task->init(thread_keys_shard_num_, device_num);
auto input_channel = dataset->GetInputChannel();
auto& local_keys = gpu_task->feature_keys_;
auto& local_ptr = gpu_task->value_ptr_;
auto& device_keys = gpu_task->device_keys_;
auto& device_vals = gpu_task->device_values_;
auto& device_mutex = gpu_task->mutex_;
std::vector<std::thread> threads;
auto fleet_ptr = FleetWrapper::GetInstance();
// data should be in input channel
thread_keys_.resize(thread_keys_thread_num_);
for (int i = 0; i < thread_keys_thread_num_; i++) {
thread_keys_[i].resize(thread_keys_shard_num_);
for (int j = 0; j < thread_keys_shard_num_; j++) {
thread_keys_[i][j].reserve(2 * max_fea_num_per_pass_ /
thread_keys_shard_num_ /
thread_keys_thread_num_);
}
}
const std::deque<Record>& vec_data = input_channel->GetData();
size_t total_len = vec_data.size();
size_t len_per_thread = total_len / thread_keys_thread_num_;
int remain = total_len % thread_keys_thread_num_;
size_t begin = 0;
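// Stage 1: scan the input channel with thread_keys_thread_num_ workers and
// bucket every feasign into thread_keys_[thread][key % shard_num].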
auto gen_func = [this](const std::deque<Record>& total_data, int begin_index,
int end_index, int i) {
for (auto iter = total_data.begin() + begin_index;
iter != total_data.begin() + end_index; iter++) {
const auto& ins = *iter;
const auto& feasign_v = ins.uint64_feasigns_;
for (const auto feasign : feasign_v) {
uint64_t cur_key = feasign.sign().uint64_feasign_;
int shard_id = cur_key % thread_keys_shard_num_;
this->thread_keys_[i][shard_id].push_back(cur_key);
}
}
};
for (int i = 0; i < thread_keys_thread_num_; i++) {
threads.push_back(std::thread(gen_func, std::ref(vec_data), begin,
begin + len_per_thread + (i < remain ? 1 : 0),
i));
begin += len_per_thread + (i < remain ? 1 : 0);
}
for (std::thread& t : threads) {
t.join();
}
timeline.Pause();
VLOG(1) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds.";
timeline.Start();
// merge thread_keys to shard_keys
for (size_t i = 0; i < thread_keys_.size(); i++) {
gpu_task->batch_add_keys(thread_keys_[i]);
for (int j = 0; j < thread_keys_thread_num_; j++) {
thread_keys_[i][j].clear();
}
}
timeline.Pause();
VLOG(1) << "GpuPs task unique11111 cost " << timeline.ElapsedSec()
<< " seconds.";
timeline.Start();
gpu_task->UniqueKeys();
timeline.Pause();
VLOG(1) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds.";
for (int i = 0; i < thread_keys_shard_num_; i++) {
VLOG(3) << "GpuPs shard: " << i << " key len: " << local_keys[i].size();
local_ptr[i].resize(local_keys[i].size());
}
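// Stage 2: pull the CPU-side DownpourFixedFeatureValue pointers for each
// shard from the pslib worker, one thread per shard.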
auto ptl_func = [this, &local_keys, &local_ptr, &table_id,
&fleet_ptr](int i) {
size_t key_size = local_keys[i].size();
auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr(
reinterpret_cast<char**>(local_ptr[i].data()), table_id,
local_keys[i].data(), key_size);
tt.wait();
auto status = tt.get();
// auto status = 0;
if (status != 0) {
LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
sleep(300);
exit(-1);
} else {
VLOG(3) << "FleetWrapper Pull sparse to local done with table size: "
<< local_keys[i].size();
}
};
for (size_t i = 0; i < threads.size(); i++) {
threads[i] = std::thread(ptl_func, i);
}
for (std::thread& t : threads) {
t.join();
}
timeline.Pause();
VLOG(1) << "GpuPs pull sparse cost " << timeline.ElapsedSec() << " seconds.";
timeline.Start();
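// Stage 3: re-shard by device (key % device_num) and convert each CPU value
// into a GPU FeatureValue, keeping cpu_ptr so dump_to_cpu can write back.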
auto build_func = [device_num, &local_keys, &local_ptr, &device_keys,
&device_vals, &device_mutex](int i) {
std::vector<std::vector<FeatureKey>> task_keys(device_num);
std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>> task_ptrs(
device_num);
for (size_t j = 0; j < local_keys[i].size(); j++) {
int shard = local_keys[i][j] % device_num;
task_keys[shard].push_back(local_keys[i][j]);
task_ptrs[shard].push_back(local_ptr[i][j]);
}
for (int dev = 0; dev < device_num; dev++) {
device_mutex[dev]->lock();
int len = task_keys[dev].size();
int cur = device_keys[dev].size();
device_keys[dev].resize(device_keys[dev].size() + len);
device_vals[dev].resize(device_vals[dev].size() + len);
for (int j = 0; j < len; ++j) {
device_keys[dev][cur + j] = task_keys[dev][j];
float* ptr_val = task_ptrs[dev][j]->data();
FeatureValue& val = device_vals[dev][cur + j];
size_t dim = task_ptrs[dev][j]->size();
val.delta_score = ptr_val[1];
val.show = ptr_val[2];
val.clk = ptr_val[3];
val.slot = ptr_val[6];
val.lr = ptr_val[4];
val.lr_g2sum = ptr_val[5];
val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]);
if (dim > 7) {
val.mf_size = MF_DIM + 1;
for (int x = 0; x < val.mf_size; x++) {
val.mf[x] = ptr_val[x + 7];
}
} else {
val.mf_size = 0;
for (int x = 0; x < MF_DIM + 1; x++) {
val.mf[x] = 0;
}
}
}
std::vector<size_t> feature_keys_count(shard_num);
device_mutex[dev]->unlock();
}
};
for (size_t i = 0; i < threads.size(); i++) {
threads[i] = std::thread(build_func, i);
}
for (std::thread& t : threads) {
t.join();
}
timeline.Pause();
VLOG(1) << "GpuPs prepare for build hbm cost " << timeline.ElapsedSec()
<< " seconds.";
}
void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
int device_num = heter_devices_.size();
std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
BuildTask(gpu_task, table_id, feature_dim);
platform::Timer timeline;
timeline.Start();
std::vector<size_t> feature_keys_count(device_num);
size_t size_max = 0;
for (int i = 0; i < shard_num; i++) {
feature_keys_count[i] = gpu_task->feature_keys_[i].size();
for (int i = 0; i < device_num; i++) {
feature_keys_count[i] = gpu_task->device_keys_[i].size();
size_max = std::max(size_max, feature_keys_count[i]);
}
if (HeterPs_) {
HeterPs_->show_one_table(0);
return;
}
std::vector<std::thread> threads(device_num);
HeterPs_ = HeterPsBase::get_instance(size_max, resource_);
for (int i = 0; i < shard_num; ++i) {
HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_);
auto build_func = [this, &gpu_task, &feature_keys_count](int i) {
std::cout << "building table: " << i << std::endl;
HeterPs_->build_ps(i, gpu_task->feature_keys_[i].data(),
gpu_task->feature_values_[i].data(),
feature_keys_count[i], 10000, 2);
this->HeterPs_->build_ps(i, gpu_task->device_keys_[i].data(),
gpu_task->device_values_[i].data(),
feature_keys_count[i], 500000, 2);
HeterPs_->show_one_table(i);
};
for (size_t i = 0; i < threads.size(); i++) {
threads[i] = std::thread(build_func, i);
}
for (std::thread& t : threads) {
t.join();
}
timeline.Pause();
VLOG(0) << "GpuPs build table total costs: " << timeline.ElapsedSec()
VLOG(1) << "GpuPs build table total costs: " << timeline.ElapsedSec()
<< " s.";
}
......
@@ -25,12 +25,18 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/fleet/heter_context.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/place.h"
@@ -73,14 +79,53 @@ class PSGPUWrapper {
const int hidden_size, const int64_t total_length,
const int batch_size);
void BuildGPUPS(const uint64_t table_id, int feature_dim,
std::shared_ptr<HeterContext> context);
void BuildGPUPS(const uint64_t table_id, int feature_dim);
void BuildTask(std::shared_ptr<HeterContext> gpu_task, uint64_t table_id,
int feature_dim);
void InitializeGPU(const std::vector<int>& dev_ids) {
if (s_instance_ != NULL) {
if (s_instance_ != NULL && is_initialized_ == false) {
VLOG(3) << "PSGPUWrapper Begin InitializeGPU";
is_initialized_ = true;
resource_ = std::make_shared<HeterPsResource>(dev_ids);
resource_->enable_p2p();
keys_tensor.resize(resource_->total_gpu());
if (multi_node_) {
int dev_size = dev_ids.size();
// init inner comm
inner_comms_.resize(dev_size);
inter_ncclids_.resize(dev_size);
platform::dynload::ncclCommInitAll(&(inner_comms_[0]), dev_size,
&dev_ids[0]);
// init inter comm
#ifdef PADDLE_WITH_GLOO
inter_comms_.resize(dev_size);
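// Rank 0 creates one NCCL unique id per local device and broadcasts them
// through Gloo, so device i on every node joins inter-node communicator i.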
auto gloo = paddle::framework::GlooWrapper::GetInstance();
if (gloo->Rank() == 0) {
for (int i = 0; i < dev_size; ++i) {
platform::dynload::ncclGetUniqueId(&inter_ncclids_[i]);
}
}
PADDLE_ENFORCE_EQ(
gloo->IsInitialized(), true,
platform::errors::PreconditionNotMet(
"You must initialize the gloo environment first to use it."));
gloo::BroadcastOptions opts(gloo->GetContext());
opts.setOutput(&inter_ncclids_[0], dev_size);
opts.setRoot(0);
gloo::broadcast(opts);
for (int i = 0; i < dev_size; ++i) {
platform::dynload::ncclCommInitRank(&inter_comms_[i], gloo->Size(),
inter_ncclids_[i], gloo->Rank());
}
node_size_ = gloo->Size();
#else
PADDLE_THROW(
platform::errors::Unavailable("heter ps need compile with GLOO"));
#endif
}
heter_devices_ = dev_ids;
}
}
// PSGPUWrapper singleton
@@ -98,6 +143,9 @@ class PSGPUWrapper {
slot_vector_ = slot_vector;
}
void EndPass() { HeterPs_->end_pass(); }
void ShowOneTable(int index) { HeterPs_->show_one_table(index); }
private:
static std::shared_ptr<PSGPUWrapper> s_instance_;
std::unordered_map<
@@ -108,6 +156,18 @@ class PSGPUWrapper {
std::shared_ptr<HeterPsResource> resource_;
int32_t sleep_seconds_before_fail_exit_;
std::vector<int> slot_vector_;
int multi_node_{1};
int node_size_;
std::vector<ncclComm_t> inner_comms_;
std::vector<ncclComm_t> inter_comms_;
std::vector<ncclUniqueId> inter_ncclids_;
std::vector<int> heter_devices_;
std::unordered_set<std::string> gpu_ps_config_keys_;
HeterObjectPool<HeterContext> gpu_task_pool_;
std::vector<std::vector<std::vector<uint64_t>>> thread_keys_;
int thread_keys_thread_num_ = 37;
int thread_keys_shard_num_ = 37;
uint64_t max_fea_num_per_pass_ = 5000000000;
protected:
static bool is_initialized_;
......
@@ -74,8 +74,6 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
workers_[i]->Initialize(trainer_desc);
workers_[i]->SetWorkerNum(place_num);
}
auto gpu_ps_wrapper = PSGPUWrapper::GetInstance();
gpu_ps_wrapper->InitializeGPU(dev_ids);
return;
}
......
@@ -37,6 +37,16 @@ void BindPSGPUWrapper(py::module* m) {
*m, "PSGPU")
.def(py::init([]() { return framework::PSGPUWrapper::GetInstance(); }))
.def("set_slot_vector", &framework::PSGPUWrapper::SetSlotVector,
py::call_guard<py::gil_scoped_release>())
.def("init_GPU_server", &framework::PSGPUWrapper::InitializeGPUServer,
py::call_guard<py::gil_scoped_release>())
.def("set_dataset", &framework::PSGPUWrapper::SetDataset,
py::call_guard<py::gil_scoped_release>())
.def("init_gpu_ps", &framework::PSGPUWrapper::InitializeGPU,
py::call_guard<py::gil_scoped_release>())
.def("end_pass", &framework::PSGPUWrapper::EndPass,
py::call_guard<py::gil_scoped_release>())
.def("build_gpu_ps", &framework::PSGPUWrapper::BuildGPUPS,
py::call_guard<py::gil_scoped_release>());
} // end PSGPUWrapper
#endif
......
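Read together, the new bindings imply this per-pass call order on the wrapper (a sketch only; slots, dataset, dev_ids, table_id, and feature_dim are placeholders, and the InitializeGPUServer configuration step is elided):

auto ps = paddle::framework::PSGPUWrapper::GetInstance();
ps->SetSlotVector(slots);               // feasign slots used by the table
ps->SetDataset(dataset);                // must hold a MultiSlotDataset channel
ps->InitializeGPU(dev_ids);             // builds inner/inter NCCL comms
ps->BuildGPUPS(table_id, feature_dim);  // BuildTask + build_ps per device
// ... train the pass ...
ps->EndPass();                          // dump_to_cpu on every device table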
@@ -599,6 +599,7 @@ class GeneralRoleMaker(RoleMakerBase):
self._init_timeout_seconds = kwargs.get("init_timeout_seconds", 3600)
self._run_timeout_seconds = kwargs.get("run_timeout_seconds", 9999999)
ip_port = kwargs.get("http_ip_port", "")
self._use_ps_gpu = kwargs.get("use_ps_gpu", False)
self._http_ip_port = []
self._http_server = None
# if ip_port is not empty, it will use http instead of hdfs
@@ -666,6 +667,18 @@ class GeneralRoleMaker(RoleMakerBase):
self._hdfs_name, self._hdfs_ugi)
gloo.init()
self._node_type_comm = gloo
if self._use_ps_gpu:
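# PS-GPU mode additionally brings up a Gloo parallel context so the C++
# side (PSGPUWrapper) can broadcast NCCL unique ids across nodes.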
Gloo_strategy = fluid.core.GlooParallelStrategy()
Gloo_strategy.rank = current_id
Gloo_strategy.rank_num = len(worker_endpoints)
Gloo_strategy.ip_address = self._http_ip_port[0]
Gloo_strategy.ip_port = int(self._http_ip_port[1])
Default_init_timeout_seconds = 3600
Default_run_timeout_seconds = 9999999
Gloo_strategy.init_seconds = Default_init_timeout_seconds
Gloo_strategy.run_seconds = Default_run_timeout_seconds
Gloo = fluid.core.GlooParallelContext(Gloo_strategy)
Gloo.init()
else:
self._all_comm = MockBarrier()
elif training_role == "PSERVER":
......
@@ -386,3 +386,27 @@ class SingleProcessMultiThread(GradAllReduce):
def _transpile_startup_program(self):
block = self.startup_program.global_block()
block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})
class MultiThread(GradAllReduce):
'''
Multi-node variant of GradAllReduce ("box" mode): initializes one NCCL
communicator per ring across all trainer endpoints, falling back to
c_comm_init_all when only one endpoint is present.
'''
def __init__(self, nrings=1):
GradAllReduce.__init__(self, nrings)
self.mode = "box"
def _transpile_startup_program(self):
if len(self.endpoints) > 1:
print("begin to _transpile_startup_program for multi-node")
print("current_endpoint: ", self.current_endpoint)
print("total endpoints: ", self.endpoints)
print("rank: %d, ring_id: %d" % (self.rank, self.nrings))
for ring_id in range(self.nrings):
self._init_communicator(
self.startup_program, self.current_endpoint, self.endpoints,
self.rank, ring_id, self.wait_port, True)
else:
print("begin to _transpile_startup_program for single-node")
block = self.startup_program.global_block()
block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})