Unverified commit 3f619290 authored by yaoxuefeng, committed by GitHub

merge dymf branch (#42714)

merge dymf branch
Parent e726960a
......@@ -129,11 +129,6 @@ class HeterContext {
for (size_t i = 0; i < feature_dim_keys_.size(); i++) {
feature_dim_keys_[i].resize(dim_num);
value_dim_ptr_[i].resize(dim_num);
if (i == 0) {
for (int j = 0; j < dim_num; j++) {
feature_dim_keys_[i][j].push_back(0);
}
}
}
device_values_.resize(device_num);
device_dim_values_.resize(device_num);
......
......@@ -32,17 +32,33 @@ struct FeatureValue {
float lr;
float lr_g2sum;
int mf_size;
float mf[MF_DIM + 1];
int mf_dim;
uint64_t cpu_ptr;
float mf[0];
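// mf[0] is a flexible array member: each FeatureValue actually occupies
// TYPEALIGN(8, sizeof(FeatureValue) + (mf_dim + 1) * sizeof(float)) bytes in its
// memory pool slot, so the per-slot embedding is stored inline after the header.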
friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) {
out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot
<< " lr: " << val.lr << " mf_size: " << val.mf_size << " mf:";
for (int i = 0; i < val.mf_size; ++i) {
<< " lr: " << val.lr << " mf_dim: " << val.mf_dim
<< "cpuptr: " << val.cpu_ptr << " mf_size: " << val.mf_size << " mf:";
for (int i = 0; i < val.mf_dim + 1; ++i) {
out << " " << val.mf[i];
}
return out;
}
__device__ __forceinline__ void operator=(const FeatureValue& in) {
delta_score = in.delta_score;
show = in.show;
clk = in.clk;
slot = in.slot;
lr = in.lr;
lr_g2sum = in.lr_g2sum;
mf_size = in.mf_size;
mf_dim = in.mf_dim;
cpu_ptr = in.cpu_ptr;
for (int i = 0; i < mf_dim + 1; i++) {
mf[i] = in.mf[i];
}
}
};
struct FeaturePushValue {
......@@ -50,20 +66,19 @@ struct FeaturePushValue {
float clk;
int slot;
float lr_g;
float mf_g[MF_DIM];
int mf_dim;
float mf_g[0];
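// mf_g[0] is likewise a flexible array member; each gradient occupies
// TYPEALIGN(8, sizeof(FeaturePushValue) + mf_dim * sizeof(float)) bytes.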
// __device__ __forceinline__ FeaturePushValue
// operator+(const FeaturePushValue& a) const {
// FeaturePushValue out;
// out.slot = a.slot;
// out.show = a.show + show;
// out.clk = a.clk + clk;
// out.lr_g = a.lr_g + lr_g;
// for (int i = 0; i < MF_DIM; ++i) {
// out.mf_g[i] = a.mf_g[i] + mf_g[i];
// }
// return out;
// }
__device__ __forceinline__ void operator=(const FeaturePushValue& in) {
show = in.show;
clk = in.clk;
slot = in.slot;
lr_g = in.lr_g;
mf_dim = in.mf_dim;
for (int i = 0; i < mf_dim; i++) {
mf_g[i] = in.mf_g[i];
}
}
};
} // end namespace framework
......
......@@ -118,8 +118,8 @@ class HashTable {
StreamType stream);
template <typename StreamType>
void insert(const KeyType* d_keys, size_t len, char* pool, size_t start_index,
StreamType stream);
void insert(const KeyType* d_keys, size_t len, char* pool,
size_t feature_value_size, size_t start_index, StreamType stream);
template <typename StreamType>
void get(const KeyType* d_keys, ValType* d_vals, size_t len,
......
......@@ -50,7 +50,8 @@ __global__ void insert_kernel(Table* table,
template <typename Table>
__global__ void insert_kernel(Table* table,
const typename Table::key_type* const keys,
size_t len, char* pool, int start_index) {
size_t len, char* pool, size_t feature_value_size,
int start_index) {
ReplaceOp<typename Table::mapped_type> op;
thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;
......@@ -58,7 +59,8 @@ __global__ void insert_kernel(Table* table,
if (i < len) {
kv.first = keys[i];
kv.second = (Table::mapped_type)(pool + (start_index + i) * 80);
uint64_t offset = uint64_t(start_index + i) * feature_value_size;
kv.second = (Table::mapped_type)(pool + offset);
auto it = table->insert(kv, op);
assert(it != table->end() && "error: insert fails: table is full");
}
......@@ -81,14 +83,16 @@ __global__ void search_kernel(Table* table,
template <typename Table>
__global__ void dy_mf_search_kernel(Table* table,
const typename Table::key_type* const keys,
char* const vals, size_t len,
char* vals, size_t len,
size_t pull_feature_value_size) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
auto it = table->find(keys[i]);
if (it != table->end()) {
*(FeatureValue*)(vals + i * pull_feature_value_size) = *(it->second);
uint64_t offset = i * pull_feature_value_size;
FeatureValue& cur = *(FeatureValue*)(vals + offset);
FeatureValue& input = *(FeatureValue*)(it->second);
cur = input;  // copy via FeatureValue::operator=, which also copies the mf_dim + 1 embedding floats
}
}
}
......@@ -121,7 +125,7 @@ __global__ void dy_mf_update_kernel(Table* table,
FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
} else {
printf("yxf::push miss key: %d", keys[i]);
printf("warning: push miss key: %d", keys[i]);
}
}
}
......@@ -201,7 +205,8 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
char* pool, size_t start_index,
char* pool, size_t feature_value_size,
size_t start_index,
StreamType stream) {
if (len == 0) {
return;
......@@ -210,8 +215,8 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
return;
}
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
pool, start_index);
insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
container_, d_keys, len, pool, feature_value_size, start_index);
}
template <typename KeyType, typename ValType>
......@@ -319,6 +324,7 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
}
template class HashTable<unsigned long, paddle::framework::FeatureValue>;
template class HashTable<unsigned long, paddle::framework::FeatureValue*>;
template class HashTable<long, int>;
template class HashTable<unsigned long, int>;
template class HashTable<unsigned long, unsigned long>;
......@@ -331,6 +337,10 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
paddle::framework::FeatureValue* d_vals, size_t len,
cudaStream_t stream);
template void
HashTable<unsigned long, paddle::framework::FeatureValue*>::get<cudaStream_t>(
const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
int* d_vals, size_t len,
cudaStream_t stream);
......@@ -354,6 +364,11 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
const paddle::framework::FeatureValue* d_vals, size_t len,
cudaStream_t stream);
template void HashTable<unsigned long, paddle::framework::FeatureValue*>::
insert<cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool,
size_t feature_value_size, size_t start_index,
cudaStream_t stream);
template void HashTable<long, int>::insert<cudaStream_t>(const long* d_keys,
const int* d_vals,
size_t len,
......@@ -393,6 +408,16 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
sgd,
cudaStream_t stream);
template void
HashTable<unsigned long, paddle::framework::FeatureValue*>::update<
Optimizer<paddle::framework::FeatureValue,
paddle::framework::FeaturePushValue>,
cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t len,
Optimizer<paddle::framework::FeatureValue,
paddle::framework::FeaturePushValue>
sgd,
cudaStream_t stream);
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
// Optimizer<paddle::framework::FeatureValue,
......
......@@ -15,10 +15,13 @@ limitations under the License. */
#pragma once
#include <thread>
#include <vector>
#include "cub/cub.cuh"
#include "cub/util_allocator.cuh"
#if defined(PADDLE_WITH_CUDA)
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/timer.h"
#include "thrust/pair.h"
#elif defined(PADDLE_WITH_XPU_KP)
// #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
......@@ -38,6 +41,9 @@ limitations under the License. */
namespace paddle {
namespace framework {
#define TYPEALIGN(ALIGNVAL, LEN) \
(((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))
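// Rounds LEN up to the next multiple of ALIGNVAL, e.g. TYPEALIGN(8, 41) == 48;
// used below to keep every pooled FeatureValue / FeaturePushValue slot 8-byte aligned.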
template <typename KeyType, typename ValType, typename GradType>
class HeterComm {
public:
......@@ -50,9 +56,13 @@ class HeterComm {
int* left, int* right, int gpu_num);
void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
int& uniq_len); // NOLINT
void dynamic_merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads,
size_t len, int& uniq_len);
void pull_sparse(int num, KeyType* d_keys, ValType* d_vals, size_t len);
void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len,
size_t chunk_size, int stream_num);
void build_ps(int num, KeyType* h_keys, char* pool, size_t len,
size_t feature_value_size, size_t chunk_size, int stream_num);
void dump();
void show_one_table(int gpu_num);
int get_index_by_devid(int devid);
......@@ -96,6 +106,11 @@ class HeterComm {
nccl_inter_comms_ = inter_comms;
node_size_ = comm_size;
}
void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
multi_mf_dim_ = multi_mf_dim;
max_mf_dim_ = max_mf_dim;
}
#endif
bool need_transfer(int send_id, int receive_id) {
......@@ -114,8 +129,8 @@ class HeterComm {
char* key_storage;
char* val_storage;
int sync;
int key_bytes_len;
int val_bytes_len;
size_t key_bytes_len;
size_t val_bytes_len;
int dev_num;
};
......@@ -206,12 +221,18 @@ class HeterComm {
void destroy_storage(int start_index, int end_index);
void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
KeyType* src_key, GradType* src_val);
void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
KeyType* src_key, char* src_val, size_t val_size);
void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right,
ValType* src_val);
void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right,
char* src_val, size_t val_size);
protected:
using Table = HashTable<KeyType, ValType>;
using PtrTable = HashTable<KeyType, ValType*>;
std::vector<Table*> tables_;
std::vector<PtrTable*> ptr_tables_;
std::shared_ptr<HeterPsResource> resource_;
std::vector<std::vector<Path>> path_;
float load_factor_{0.75};
......@@ -221,6 +242,7 @@ class HeterComm {
private:
int topo_aware_{0};
std::vector<LocalStorage> storage_;
DynamicGradMerger merger_;
int feanum_{1800 * 2048};
int multi_node_{0};
int node_size_;
......@@ -228,6 +250,8 @@ class HeterComm {
#if defined(PADDLE_WITH_CUDA)
std::vector<ncclComm_t> nccl_inner_comms_;
std::vector<ncclComm_t> nccl_inter_comms_;
int multi_mf_dim_{8};
int max_mf_dim_ = 8;
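// multi_mf_dim_ selects the dynamic-mf code path (ptr_tables_ over raw value pools);
// max_mf_dim_ bounds the per-value buffer sizes computed with TYPEALIGN above.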
std::vector<std::shared_ptr<cub::CachingDeviceAllocator>> allocators_;
#endif
};
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_HETERPS
#include <queue>
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_XPU_KP
......@@ -22,20 +23,31 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename KeyType, typename ValType, typename GradType>
HeterComm<KeyType, ValType, GradType>::HeterComm(
size_t capacity, std::shared_ptr<HeterPsResource> resource) {
resource_ = resource;
storage_.resize(resource_->total_device());
multi_mf_dim_ = resource->multi_mf();
for (int i = 0; i < resource_->total_device(); ++i) {
#if defined(PADDLE_WITH_CUDA)
platform::CUDADeviceGuard guard(resource_->dev_id(i));
allocators_.push_back(std::make_shared<cub::CachingDeviceAllocator>(
8, 1, (unsigned int)-1, (size_t)-1, false, false)); // NOLINT
#endif
auto table = new Table(capacity / load_factor_);
tables_.push_back(table);
if (!multi_mf_dim_) {
auto table = new Table(capacity / load_factor_);
tables_.push_back(table);
} else {
max_mf_dim_ = resource_->max_mf_dim();
size_t val_type_size = TYPEALIGN(
8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1));
size_t grad_type_size = TYPEALIGN(
8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
auto ptr_table = new PtrTable(capacity / load_factor_);
ptr_table->set_feature_value_size(val_type_size, grad_type_size);
ptr_tables_.push_back(ptr_table);
}
if (multi_node_) {
storage_[i].init(feanum_, resource_->dev_id(i));
}
......@@ -238,95 +250,128 @@ void HeterComm<KeyType, ValType, GradType>::walk_to_dest(int start_index,
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::walk_to_src(int start_index,
int num, int* h_left,
int* h_right,
ValType* src_val) {
void HeterComm<KeyType, ValType, GradType>::walk_to_dest(
int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key,
char* src_val, size_t val_size) {
int need_copy_val = 0;
if (src_val) {
need_copy_val = 1;
}
std::queue<CopyTask> que;
for (int i = 0; i < gpu_num; i++) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
int size = path_[start_index][i].nodes_.size();
auto& node = path_[start_index][i].nodes_[0];
CopyTask t(&path_[start_index][i], 0);
que.push(t);
cudaMemcpyAsync(node.key_storage,
reinterpret_cast<char*>(src_key + h_left[i]),
node.key_bytes_len, cudaMemcpyDefault, node.in_stream);
if (need_copy_val) {
cudaMemcpyAsync(node.val_storage,
src_val + uint64_t(h_left[i]) * uint64_t(val_size),
node.val_bytes_len, cudaMemcpyDefault, node.in_stream);
}
}
while (!que.empty()) {
CopyTask& cur_task = que.front();
que.pop();
if (cur_task.path->nodes_[cur_task.step].sync) {
cudaStreamSynchronize(cur_task.path->nodes_[cur_task.step].in_stream);
}
if (cur_task.step != cur_task.path->nodes_.size() - 1) {
int cur_step = cur_task.step;
CopyTask c(cur_task.path, cur_step + 1);
que.push(c);
cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].key_storage,
cur_task.path->nodes_[cur_step].key_storage,
cur_task.path->nodes_[cur_step + 1].key_bytes_len,
cudaMemcpyDefault,
cur_task.path->nodes_[cur_step + 1].in_stream);
if (need_copy_val) {
cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].val_storage,
cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step + 1].val_bytes_len,
cudaMemcpyDefault,
cur_task.path->nodes_[cur_step + 1].in_stream);
}
}
}
}
for (int i = 0; i < num; i++) {
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::walk_to_src(
int start_index, int gpu_num, int* h_left, int* h_right, char* src_val,
size_t val_size) {
std::queue<CopyTask> que;
for (int i = 0; i < gpu_num; i++) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
int cur_step = path_[start_index][i].nodes_.size() - 1;
auto& node = path_[start_index][i].nodes_[cur_step];
auto src_dev_id = resource_->dev_id(i);
auto src_place = DevPlace(src_dev_id);
if (cur_step == 0) {
auto dst_dev_id = resource_->dev_id(start_index);
auto dst_place = DevPlace(dst_dev_id);
memory_copy(dst_place, reinterpret_cast<char*>(src_val + h_left[i]),
src_place, node.val_storage, node.val_bytes_len,
node.out_stream);
cudaMemcpyAsync(src_val + uint64_t(h_left[i]) * val_size,
node.val_storage, node.val_bytes_len, cudaMemcpyDefault,
node.out_stream);
} else {
CopyTask t(&path_[start_index][i], cur_step - 1);
que.push(t);
auto dst_dev_id =
resource_->dev_id(path_[start_index][i].nodes_[cur_step - 1].dev_num);
auto dst_place = DevPlace(dst_dev_id);
memory_copy(dst_place,
path_[start_index][i].nodes_[cur_step - 1].val_storage,
src_place, node.val_storage,
path_[start_index][i].nodes_[cur_step - 1].val_bytes_len,
path_[start_index][i].nodes_[cur_step - 1].out_stream);
cudaMemcpyAsync(path_[start_index][i].nodes_[cur_step - 1].val_storage,
node.val_storage,
path_[start_index][i].nodes_[cur_step - 1].val_bytes_len,
cudaMemcpyDefault,
path_[start_index][i].nodes_[cur_step - 1].out_stream);
}
}
while (!que.empty()) {
CopyTask& cur_task = que.front();
que.pop();
int cur_step = cur_task.step;
if (cur_task.path->nodes_[cur_step].sync) {
sync_stream(cur_task.path->nodes_[cur_step].out_stream);
cudaStreamSynchronize(cur_task.path->nodes_[cur_step].out_stream);
}
auto src_dev_id =
resource_->dev_id(cur_task.path->nodes_[cur_step].dev_num);
auto src_place = DevPlace(src_dev_id);
if (cur_step > 0) {
CopyTask c(cur_task.path, cur_step - 1);
que.push(c);
auto dst_dev_id =
resource_->dev_id(cur_task.path->nodes_[cur_step - 1].dev_num);
auto dst_place = DevPlace(dst_dev_id);
memory_copy(dst_place, cur_task.path->nodes_[cur_step - 1].val_storage,
src_place, cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step - 1].val_bytes_len,
cur_task.path->nodes_[cur_step - 1].out_stream);
cudaMemcpyAsync(cur_task.path->nodes_[cur_step - 1].val_storage,
cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step - 1].val_bytes_len,
cudaMemcpyDefault,
cur_task.path->nodes_[cur_step - 1].out_stream);
} else if (cur_step == 0) {
int end_index = cur_task.path->nodes_.back().dev_num;
auto dst_dev_id = resource_->dev_id(end_index);
auto dst_place = DevPlace(dst_dev_id);
memory_copy(dst_place,
reinterpret_cast<char*>(src_val + h_left[end_index]),
src_place, cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step].val_bytes_len,
cur_task.path->nodes_[cur_step].out_stream);
cudaMemcpyAsync(src_val + uint64_t(h_left[end_index]) * val_size,
cur_task.path->nodes_[cur_step].val_storage,
cur_task.path->nodes_[cur_step].val_bytes_len,
cudaMemcpyDefault,
cur_task.path->nodes_[cur_step].out_stream);
}
}
}
template <typename KeyType, typename ValType, typename GradType>
HeterComm<KeyType, ValType, GradType>::~HeterComm() {
for (auto& table : tables_) {
delete table;
table = nullptr;
if (!multi_mf_dim_) {
for (auto& table : tables_) {
delete table;
table = nullptr;
}
} else {
for (auto& table : ptr_tables_) {
delete table;
table = nullptr;
}
}
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::show_one_table(int num) {
tables_[num]->show();
void HeterComm<KeyType, ValType, GradType>::show_one_table(int gpu_num) {
if (!multi_mf_dim_) {
tables_[gpu_num]->show();
}
}
template <typename KeyType, typename ValType, typename GradType>
......@@ -418,59 +463,165 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(
}
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
char* pool, size_t len,
size_t feature_value_size,
size_t chunk_size,
int stream_num) {
if (len <= 0) {
return;
}
int dev_id = resource_->dev_id(num);
DevPlace place = DevPlace(dev_id);
AnyDeviceGuard guard(dev_id);
std::vector<memory::allocation::AllocationPtr> d_key_bufs;
ppStream streams[stream_num]; // NOLINT
for (int i = 0; i < stream_num; ++i) {
create_stream(&(streams[i]));
auto d_k_buf = memory::Alloc(place, chunk_size * sizeof(KeyType));
d_key_bufs.push_back(std::move(d_k_buf));
}
int cur_len = 0;
int cur_stream = 0;
while (cur_len < len) {
cur_stream = cur_stream % stream_num;
auto cur_use_stream = streams[cur_stream];
#if defined(PADDLE_WITH_XPU_KP)
cur_use_stream = 0;
#endif
int tmp_len = cur_len + chunk_size > len ? len - cur_len : chunk_size;
auto dst_place = place;
auto src_place = platform::CPUPlace();
memory_copy(
dst_place, reinterpret_cast<char*>(d_key_bufs[cur_stream]->ptr()),
src_place, h_keys + cur_len, sizeof(KeyType) * tmp_len, cur_use_stream);
ptr_tables_[num]->insert(
reinterpret_cast<KeyType*>(d_key_bufs[cur_stream]->ptr()), tmp_len,
pool, feature_value_size, cur_len, cur_use_stream);
cur_stream += 1;
cur_len += tmp_len;
}
for (int i = 0; i < stream_num; ++i) {
sync_stream(streams[i]);
destroy_stream(streams[i]);
}
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::merge_grad(
int dev_num, KeyType* d_keys, GradType* d_grads, size_t len,
int& uniq_len) { // NOLINT
int dev_id = resource_->dev_id(dev_num);
DevPlace place = DevPlace(dev_id);
AnyDeviceGuard guard(dev_id);
auto stream = resource_->local_stream(dev_num, 0);
size_t temp_storage_bytes;
auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType));
KeyType* d_merge_keys_ptr = reinterpret_cast<KeyType*>(d_merge_keys->ptr());
auto d_merge_grads = memory::Alloc(place, len * sizeof(GradType));
GradType* d_merge_grads_ptr =
reinterpret_cast<GradType*>(d_merge_grads->ptr());
heter_comm_kernel_->sort_pairs(NULL, temp_storage_bytes, d_keys,
d_merge_keys_ptr, d_grads, d_merge_grads_ptr,
len, 0, 8 * sizeof(KeyType), stream, false);
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
heter_comm_kernel_->sort_pairs(
d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr,
d_grads, d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false);
temp_storage_bytes = 0;
auto d_num_runs_out_mem = memory::Alloc(place, sizeof(int));
int* d_num_runs_out = reinterpret_cast<int*>(d_num_runs_out_mem->ptr());
heter_comm_kernel_->reduce_by_key(NULL, temp_storage_bytes, d_merge_keys_ptr,
d_keys, d_merge_grads_ptr, d_grads,
d_num_runs_out, len, stream, false);
if (d_temp_storage->size() < temp_storage_bytes) {
d_temp_storage = NULL;
d_temp_storage = memory::Alloc(place, temp_storage_bytes);
}
heter_comm_kernel_->reduce_by_key(
d_temp_storage->ptr(), temp_storage_bytes, d_merge_keys_ptr, d_keys,
d_merge_grads_ptr, d_grads, d_num_runs_out, len, stream, false);
auto dst_place = platform::CPUPlace();
auto src_place = place;
memory_copy(dst_place, &uniq_len, src_place, d_num_runs_out, sizeof(int),
stream);
sync_stream(stream);
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::dynamic_merge_grad(
int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
int& uniq_len) {
int dev_id = resource_->dev_id(gpu_num);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_num, 0);
size_t temp_storage_bytes;
size_t grad_value_size =
TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType));
KeyType* d_merge_keys_ptr = reinterpret_cast<KeyType*>(d_merge_keys->ptr());
auto d_merge_grads = memory::Alloc(place, len * grad_value_size);
GradType* d_merge_grads_ptr =
reinterpret_cast<GradType*>(d_merge_grads->ptr());
auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1));
uint32_t* d_fea_num_info_ptr =
reinterpret_cast<uint32_t*>(d_fea_num_info->ptr());
uint32_t* d_index = (uint32_t*)&d_fea_num_info_ptr[len];
uint32_t* d_idx = (uint32_t*)&d_index[len];
int* d_merged_size = (int*)&d_idx[len];
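// Layout of the d_fea_num_info scratch buffer (len * 3 + 1 uint32_t values):
// [0, len) d_fea_num_info_ptr: run length of each unique key after sorting,
// [len, 2*len) d_index: original position of each sorted key,
// [2*len, 3*len) d_idx: identity fill used as the sort values (reused later as d_offset),
// [3*len] d_merged_size: number of unique keys.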
int grid_size = (len - 1) / block_size_ + 1;
heter_comm_kernel_->fill_idx(d_idx, len, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
NULL, temp_storage_bytes, d_keys, d_merge_keys_ptr, d_idx, d_index, len,
0, 8 * sizeof(KeyType), stream));
void* d_buff = NULL;
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr,
d_idx, d_index, len, 0, 8 * sizeof(KeyType), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
temp_storage_bytes = 0;
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode(
NULL, temp_storage_bytes, d_merge_keys_ptr, d_keys, d_fea_num_info_ptr,
d_merged_size, len, stream));
if (d_temp_storage->size() < temp_storage_bytes) {
d_temp_storage = NULL;
d_temp_storage = memory::Alloc(place, temp_storage_bytes);
}
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode(
d_temp_storage->ptr(), temp_storage_bytes, d_merge_keys_ptr, d_keys,
d_fea_num_info_ptr, d_merged_size, len, stream));
cudaMemcpyAsync((void*)&uniq_len, d_merged_size, sizeof(int),
cudaMemcpyDeviceToHost, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
assert(uniq_len > 0);
uint32_t* d_offset = (uint32_t*)&d_index[len];
temp_storage_bytes = 0;
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum(
NULL, temp_storage_bytes, d_fea_num_info_ptr, d_offset, uniq_len,
stream));
if (d_temp_storage->size() < temp_storage_bytes) {
d_temp_storage = NULL;
d_temp_storage = memory::Alloc(place, temp_storage_bytes);
}
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum(
d_temp_storage->ptr(), temp_storage_bytes, d_fea_num_info_ptr, d_offset,
uniq_len, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
heter_comm_kernel_->merge_gradient(
d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads,
(char*)d_merge_grads_ptr, uniq_len, grad_value_size, merger_, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr,
grad_value_size * uniq_len,
cudaMemcpyDeviceToDevice, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
KeyType* d_keys, int* d_idx_ptr, size_t len, int* left, int* right,
......@@ -529,8 +680,6 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
AnyDeviceGuard guard(dev_id);
auto stream = resource_->local_stream(num, 0);
// int grid_size = (len - 1) / block_size_ + 1;
int h_left[total_device]; // NOLINT
int h_right[total_device]; // NOLINT
......@@ -562,10 +711,11 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
auto d_idx = memory::Alloc(place, len * sizeof(int));
int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
size_t val_type_size =
TYPEALIGN(8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1));
auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
auto d_shard_vals = memory::Alloc(place, len * sizeof(ValType));
auto d_shard_vals = memory::Alloc(place, len * val_type_size);
ValType* d_shard_vals_ptr = reinterpret_cast<ValType*>(d_shard_vals->ptr());
split_input_to_shard(d_keys, d_idx_ptr, len, d_left_ptr, d_right_ptr, num);
......@@ -589,9 +739,8 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
continue;
}
create_storage(num, i, shard_len * sizeof(KeyType),
shard_len * sizeof(ValType));
shard_len * val_type_size);
}
walk_to_dest(num, total_device, h_left, h_right, d_shard_keys_ptr, NULL);
for (int i = 0; i < total_device; ++i) {
......@@ -600,14 +749,11 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
}
auto& node = path_[num][i].nodes_.back();
sync_stream(node.in_stream);
AnyDeviceGuard guard(resource_->dev_id(i));
tables_[i]->rwlock_->RDLock();
tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
reinterpret_cast<ValType*>(node.val_storage),
h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, num));
ptr_tables_[i]->rwlock_->RDLock();
ptr_tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
node.val_storage, h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, num));
}
for (int i = 0; i < total_device; ++i) {
......@@ -615,21 +761,18 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
if (h_left[i] == -1) {
continue;
}
tables_[i]->rwlock_->UNLock();
ptr_tables_[i]->rwlock_->UNLock();
}
walk_to_src(num, total_device, h_left, h_right, d_shard_vals_ptr);
walk_to_src(num, total_device, h_left, h_right,
reinterpret_cast<char*>(d_shard_vals_ptr), val_type_size);
for (int i = 0; i < total_device; ++i) {
auto& node = path_[num][i].nodes_.front();
sync_stream(node.out_stream);
}
heter_comm_kernel_->fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len,
stream);
heter_comm_kernel_->dy_mf_fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len,
val_type_size, stream);
sync_stream(stream);
for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
......@@ -653,6 +796,8 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
int total_device = resource_->total_device();
int dev_id = resource_->dev_id(dev_num);
size_t grad_value_size =
TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
DevPlace place = DevPlace(dev_id);
AnyDeviceGuard guard(dev_id);
auto stream = resource_->local_stream(dev_num, 0);
......@@ -691,21 +836,19 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
auto d_shard_grads = memory::Alloc(place, len * sizeof(GradType));
GradType* d_shard_grads_ptr =
reinterpret_cast<GradType*>(d_shard_grads->ptr());
GradType* d_shard_grads_ptr;
auto d_shard_grads = memory::Alloc(place, len * grad_value_size);
d_shard_grads_ptr = reinterpret_cast<GradType*>(d_shard_grads->ptr());
int uniq_len = len;
merge_grad(dev_num, d_keys, d_grads, len, uniq_len);
dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len);
// int grid_size = (uniq_len - 1) / block_size_ + 1;
int grid_size = (uniq_len - 1) / block_size_ + 1;
split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr,
dev_num);
heter_comm_kernel_->fill_shard_grads(d_shard_keys_ptr, d_keys,
d_shard_grads_ptr, d_grads, d_idx_ptr,
uniq_len, stream);
heter_comm_kernel_->dy_mf_fill_shard_grads(
d_shard_keys_ptr, d_keys, d_shard_grads_ptr, d_grads, d_idx_ptr, uniq_len,
grad_value_size, stream);
sync_stream(stream);
......@@ -721,12 +864,22 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
create_storage(dev_num, i, shard_len * sizeof(KeyType),
shard_len * sizeof(GradType));
if (!multi_mf_dim_) {
create_storage(dev_num, i, shard_len * sizeof(KeyType),
shard_len * sizeof(GradType));
} else {
create_storage(dev_num, i, shard_len * sizeof(KeyType),
shard_len * grad_value_size);
}
}
walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr,
d_shard_grads_ptr);
if (!multi_mf_dim_) {
walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr,
d_shard_grads_ptr);
} else {
walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr,
reinterpret_cast<char*>(d_shard_grads_ptr), grad_value_size);
}
for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
......@@ -736,17 +889,28 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
sync_stream(node.in_stream);
AnyDeviceGuard guard(resource_->dev_id(i));
tables_[i]->rwlock_->WRLock();
tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
reinterpret_cast<GradType*>(node.val_storage),
h_right[i] - h_left[i] + 1, sgd,
resource_->remote_stream(i, dev_num));
if (!multi_mf_dim_) {
tables_[i]->rwlock_->WRLock();
tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
reinterpret_cast<GradType*>(node.val_storage),
h_right[i] - h_left[i] + 1, sgd,
resource_->remote_stream(i, dev_num));
} else {
ptr_tables_[i]->rwlock_->WRLock();
ptr_tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
node.val_storage, h_right[i] - h_left[i] + 1, sgd,
resource_->remote_stream(i, dev_num));
}
}
for (int i = 0; i < total_device; ++i) {
sync_stream(resource_->remote_stream(i, dev_num));
if (h_left[i] != -1) {
tables_[i]->rwlock_->UNLock();
if (!multi_mf_dim_) {
tables_[i]->rwlock_->UNLock();
} else {
ptr_tables_[i]->rwlock_->UNLock();
}
}
}
......@@ -1078,11 +1242,13 @@ void HeterComm<KeyType, ValType, GradType>::end_pass() {
tables_[index]->dump_to_cpu(dev_id, stream);
};
for (int i = 0; i < total_device; ++i) {
threads.push_back(std::thread(dump_to_cpu_func, i));
}
for (auto& t : threads) {
t.join();
if (!multi_mf_dim_) {
for (int i = 0; i < total_device; ++i) {
threads.push_back(std::thread(dump_to_cpu_func, i));
}
for (auto& t : threads) {
t.join();
}
}
}
......
......@@ -117,6 +117,52 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals,
}
}
template <typename KeyType, typename GradType, typename T>
__global__ void dy_mf_fill_shard_grads_kernel(
KeyType* d_shard_keys, KeyType* d_keys, GradType* d_shard_grads,
GradType* d_grads, T* idx, size_t len, size_t grad_value_size) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
d_shard_keys[i] = d_keys[idx[i]];
*(GradType*)((char*)d_shard_grads + i * grad_value_size) =
*(GradType*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size);
}
}
__global__ void merge_gradients_kernel(const uint32_t* offset,
const uint32_t* fea_num,
const uint32_t* index, const char* input,
char* output, int n,
size_t grad_value_size,
DynamicGradMerger& merger_) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
uint32_t start = offset[i];
uint32_t num = fea_num[i];
int ori_index = index[start];
FeaturePushValue& out = *(FeaturePushValue*)(output + i * grad_value_size);
FeaturePushValue& in =
*(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
merger_.update_one(out, in);
for (int j = 1; j < num; ++j) {
ori_index = index[start + j];
in = *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
merger_.merge_one(out, in);
}
}
}
template <typename ValType, typename T>
__global__ void dy_mf_fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals,
T* idx, size_t len, size_t val_size) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
uint64_t new_offset = uint64_t(idx[i]) * val_size;
*(ValType*)((char*)d_vals + new_offset) =
*(ValType*)((char*)d_shard_vals + i * val_size);
}
}
// cuda implemention of heter_comm_kernel.h
template <typename T, typename StreamType>
void HeterCommKernel::fill_idx(T* idx, long long len,
......@@ -207,8 +253,42 @@ void HeterCommKernel::reduce_by_key(void* d_temp_storage,
debug_synchronous));
}
template <typename KeyType, typename GradType, typename T, typename StreamType>
void HeterCommKernel::dy_mf_fill_shard_grads(
KeyType* d_shard_keys, KeyType* d_keys, GradType* d_shard_grads,
GradType* d_grads, T* idx, long long len, size_t grad_value_size,
const StreamType& stream) {
int grid_size = (len - 1) / block_size_ + 1;
size_t c_len = (size_t)len;
dy_mf_fill_shard_grads_kernel<<<grid_size, block_size_, 0, stream>>>(
d_shard_keys, d_keys, d_shard_grads, d_grads, idx, c_len,
grad_value_size);
}
template <typename StreamType>
void HeterCommKernel::merge_gradient(
const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
const char* input, char* output, int n, size_t grad_value_size,
DynamicGradMerger& merger_, const StreamType& stream) {
int grid_size = (n - 1) / block_size_ + 1;
merge_gradients_kernel<<<grid_size, block_size_, 0, stream>>>(
offset, fea_num, index, input, output, n, grad_value_size, merger_);
}
template <typename ValType, typename T, typename StreamType>
void HeterCommKernel::dy_mf_fill_dvals(ValType* d_shard_vals, ValType* d_vals,
T* idx, long long len, size_t val_size,
const StreamType& stream) {
int grid_size = (len - 1) / block_size_ + 1;
size_t c_len = (size_t)len;
dy_mf_fill_dvals_kernel<<<grid_size, block_size_, 0, stream>>>(
d_shard_vals, d_vals, idx, c_len, val_size);
}
template void HeterCommKernel::fill_idx<int, cudaStream_t>(
int* idx, long long len, const cudaStream_t& stream);
template void HeterCommKernel::fill_idx<uint32_t, cudaStream_t>(
uint32_t* idx, long long len, const cudaStream_t& stream);
template void HeterCommKernel::calc_shard_offset<int, cudaStream_t>(
int* idx, int* left, int* right, long long len, int total_devs,
......@@ -270,6 +350,23 @@ template void HeterCommKernel::reduce_by_key<
paddle::framework::FeaturePushValue* d_aggregates_out, int* d_num_runs_out,
int num_items, cudaStream_t stream, bool debug_synchronous);
template void HeterCommKernel::dy_mf_fill_shard_grads<
unsigned long, paddle::framework::FeaturePushValue, int, cudaStream_t>(
unsigned long* d_shard_keys, unsigned long* d_keys,
paddle::framework::FeaturePushValue* d_shard_grads,
paddle::framework::FeaturePushValue* d_grads, int* idx, long long len,
size_t grad_value_size, const cudaStream_t& stream);
template void HeterCommKernel::merge_gradient<cudaStream_t>(
const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
const char* input, char* output, int n, size_t grad_value_size,
DynamicGradMerger& merger_, const cudaStream_t& stream);
template void HeterCommKernel::dy_mf_fill_dvals<paddle::framework::FeatureValue,
int, cudaStream_t>(
paddle::framework::FeatureValue* d_shard_vals,
paddle::framework::FeatureValue* d_vals, int* idx, long long len,
size_t val_size, const cudaStream_t& stream);
#endif
} // namespace framework
......
......@@ -27,6 +27,42 @@ limitations under the License. */
namespace paddle {
namespace framework {
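// DynamicGradMerger combines gradients that share a key when the embedding width
// (mf_dim) varies per slot: merge_gradients_kernel seeds each run's output with
// update_one() and folds the remaining entries in with merge_one(); the binary
// operator() form only aggregates the dense part (show/clk/lr_g), not mf_g.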
struct DynamicGradMerger {
template <typename T>
CUB_RUNTIME_FUNCTION __forceinline__ __device__ T
operator()(const T& a, const T& b) const {
T out;
out.slot = a.slot;
out.mf_dim = a.mf_dim;
out.show = a.show + b.show;
out.clk = a.clk + b.clk;
out.lr_g = a.lr_g + b.lr_g;
return out;
}
template <typename T>
__device__ __forceinline__ void update_one(T& output, const T& input) {
output.slot = input.slot;
output.show = input.show;
output.clk = input.clk;
output.mf_dim = input.mf_dim;
output.lr_g = input.lr_g;
for (int i = 0; i < output.mf_dim; ++i) {
output.mf_g[i] = input.mf_g[i];
}
}
template <typename T>
__device__ __forceinline__ void merge_one(T& output, const T& input) {
output.show += input.show;
output.clk += input.clk;
output.lr_g += input.lr_g;
for (int i = 0; i < input.mf_dim; ++i) {
output.mf_g[i] += input.mf_g[i];
}
}
};
class HeterCommKernel {
public:
HeterCommKernel() {}
......@@ -80,6 +116,24 @@ class HeterCommKernel {
StreamType stream = NULL, bool debug_synchronous = false);
template <typename KeyType, typename GradType, typename T,
typename StreamType>
void dy_mf_fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys,
GradType* d_shard_grads, GradType* d_grads,
T* idx, long long len, size_t grad_value_size,
const StreamType& stream);
template <typename StreamType>
void merge_gradient(const uint32_t* offset, const uint32_t* fea_num,
const uint32_t* index, const char* input, char* output,
int n, size_t grad_value_size, DynamicGradMerger& merger_,
const StreamType& stream);
template <typename ValType, typename T, typename StreamType>
void dy_mf_fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx,
long long len, size_t val_size,
const StreamType& stream);
private:
int block_size_{256};
};
......
......@@ -44,6 +44,13 @@ void HeterPs::build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
comm_->build_ps(num, h_keys, h_vals, len, chunk_size, stream_num);
}
void HeterPs::build_ps(int num, FeatureKey* h_keys, char* pool, size_t len,
size_t feature_value_size, size_t chunk_size,
int stream_num) {
comm_->build_ps(num, h_keys, pool, len, feature_value_size, chunk_size,
stream_num);
}
int HeterPs::get_index_by_devid(int devid) {
return comm_->get_index_by_devid(devid);
}
......@@ -72,6 +79,10 @@ void HeterPs::set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
comm_->set_nccl_comm_and_size(inner_comms, inter_comms, comm_size);
}
void HeterPs::set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
comm_->set_multi_mf_dim(multi_mf_dim, max_mf_dim);
}
} // end namespace framework
} // end namespace paddle
#endif
......@@ -37,11 +37,14 @@ class HeterPs : public HeterPsBase {
size_t len) override;
void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, size_t len,
size_t chunk_size, int stream_num) override;
void build_ps(int num, FeatureKey* h_keys, char* pool, size_t len,
size_t feature_value_size, size_t chunk_size,
int stream_num) override;
#if defined(PADDLE_WITH_CUDA)
void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms,
int comm_size) override;
void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) override;
#endif
void set_sparse_sgd(const OptimizerConfig& optimizer_config) override;
......
......@@ -35,11 +35,15 @@ class HeterPsBase {
size_t len) = 0;
virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
size_t len, size_t chunk_size, int stream_num) = 0;
virtual void build_ps(int num, FeatureKey* h_keys, char* pool, size_t len,
size_t feature_value_size, size_t chunk_size,
int stream_num) = 0;
virtual int get_index_by_devid(int devid) = 0;
#if defined(PADDLE_WITH_CUDA)
virtual void set_nccl_comm_and_size(
const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms, int comm_size) = 0;
virtual void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) = 0;
#endif
virtual void end_pass() = 0;
virtual void show_one_table(int gpu_num) = 0;
......
......@@ -107,6 +107,8 @@ class HeterPsResource {
int get_index_by_devid(int devid);
int dev_id(int num);
void set_multi_mf(int multi_mf_dim, int max_mf_dim);
int multi_mf() { return multi_mf_dim_; }
int max_mf_dim() { return max_mf_dim_; }
ppStream local_stream(int dev_num, int stream_num);
ppStream remote_stream(int dev_num, int stream_num);
......
......@@ -125,20 +125,21 @@ class Optimizer {
if (optimizer_config.mf_create_thresholds <=
optimizer_config.nonclk_coeff * (ptr->show - ptr->clk) +
optimizer_config.clk_coeff * ptr->clk) {
// ptr->mf_size = ptr->mf_dim + 1;
ptr->mf_size = ptr->mf_dim + 1;
ptr->mf_size = MF_DIM + 1;
// ptr->mf_size = MF_DIM + 1;
ptr->mf[0] = 0;
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
curandState state;
curand_init(clock64(), tid_x, 0, &state);
for (int i = 0; i < MF_DIM; ++i) {
for (int i = 0; i < ptr->mf_dim; ++i) {
ptr->mf[i + 1] =
(curand_uniform(&state)) * optimizer_config.mf_initial_range;
}
}
} else {
update_mf(optimizer_config, MF_DIM, &(ptr->mf[1]), ptr->mf[0], grad.mf_g,
update_mf(optimizer_config, ptr->mf_dim, &(ptr->mf[1]), ptr->mf[0],
grad.mf_g,
grad.show); // for local test
}
}
......
......@@ -31,7 +31,6 @@ limitations under the License. */
#include <algorithm>
#include <deque>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/platform/timer.h"
......@@ -112,12 +111,8 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
} else {
gpu_task->init(thread_keys_shard_num_, device_num, multi_mf_dim_);
}
auto& local_keys = gpu_task->feature_keys_;
auto& local_ptr = gpu_task->value_ptr_;
std::vector<std::thread> threads;
// data should be in input channel
if (!multi_mf_dim_) {
thread_keys_.resize(thread_keys_thread_num_);
for (int i = 0; i < thread_keys_thread_num_; i++) {
......@@ -141,11 +136,9 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
std::string data_set_name = std::string(typeid(*dataset_).name());
if (data_set_name.find("SlotRecordDataset") != std::string::npos) {
VLOG(0) << "ps_gpu_wrapper use SlotRecordDataset";
SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(dataset_);
auto input_channel = dataset->GetInputChannel();
VLOG(0) << "yxf::buildtask::inputslotchannle size: "
<< input_channel->Size();
VLOG(0) << "psgpu wrapperinputslotchannle size: " << input_channel->Size();
const std::deque<SlotRecord>& vec_data = input_channel->GetData();
total_len = vec_data.size();
len_per_thread = total_len / thread_keys_thread_num_;
......@@ -176,21 +169,12 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
j < slot_offset[slot_offset_vector_[slot_idx] + 1]; j++) {
int shard_id = feasign_v[j] % thread_keys_shard_num_;
int dim_id = slot_index_vec_[slot_idx];
this->thread_dim_keys_[i][shard_id][dim_id].insert(feasign_v[j]);
if (feasign_v[j] != 0) {
this->thread_dim_keys_[i][shard_id][dim_id].insert(feasign_v[j]);
}
}
}
}
/*
for (auto iter = total_data.begin() + begin_index;
iter != total_data.begin() + end_index; iter++) {
const auto& ins = *iter;
const auto& feasign_v = ins->slot_uint64_feasigns_.slot_values;
for (const auto feasign : feasign_v) {
int shard_id = feasign % thread_keys_shard_num_;
this->thread_dim_keys_[i][shard_id][0].insert(feasign);
}
}
*/
};
for (int i = 0; i < thread_keys_thread_num_; i++) {
if (!multi_mf_dim_) {
......@@ -264,12 +248,6 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
thread_dim_keys_[i][shard_num][dim_id].clear();
}
};
// for (size_t i = 0; i < thread_keys_.size(); i++) {
// gpu_task->batch_add_keys(thread_keys_[i]);
// for (int j = 0; j < thread_keys_thread_num_; j++) {
// thread_keys_[i][j].clear();
// }
//}
for (int i = 0; i < thread_keys_shard_num_; ++i) {
if (!multi_mf_dim_) {
threads.push_back(std::thread(merge_ins_func, i));
......@@ -291,20 +269,15 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
timeline.Pause();
VLOG(0) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds.";
if (!multi_mf_dim_) {
for (int i = 0; i < thread_keys_shard_num_; i++) {
VLOG(0) << "GpuPs shard: " << i << " key len: " << local_keys[i].size();
local_ptr[i].resize(local_keys[i].size());
}
} else {
for (int i = 0; i < thread_keys_shard_num_; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
VLOG(0) << "GpuPs shard: " << i << "mf dim: " << index_dim_vec_[j]
<< " key len: " << gpu_task->feature_dim_keys_[i][j].size();
gpu_task->value_dim_ptr_[i][j].resize(
gpu_task->feature_dim_keys_[i][j].size());
for (int i = 0; i < thread_keys_shard_num_; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
if (i == 0 && j == multi_mf_dim_ - 1) {
gpu_task->feature_dim_keys_[i][j].push_back(0);
}
VLOG(0) << "GpuPs shard: " << i << "mf dim: " << index_dim_vec_[j]
<< " key len: " << gpu_task->feature_dim_keys_[i][j].size();
gpu_task->value_dim_ptr_[i][j].resize(
gpu_task->feature_dim_keys_[i][j].size());
}
}
}
......@@ -353,85 +326,6 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
#endif
timeline.Start();
auto ptl_func = [this, &local_keys, &local_ptr, &fleet_ptr](int i) {
size_t key_size = local_keys[i].size();
int32_t status = -1;
#ifdef PADDLE_WITH_PSLIB
// auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr(
// reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
// local_keys[i].data(), key_size);
int32_t cnt = 0;
while (true) {
auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr(
i, reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
local_keys[i].data(), key_size);
bool flag = true;
tt.wait();
try {
status = tt.get();
} catch (const std::future_error& e) {
VLOG(0) << "Caught a future_error with code" << e.code()
<< ", Message:" << e.what();
}
if (status != 0) {
VLOG(0) << "fleet pull sparse failed, status[" << status << "]";
sleep(sleep_seconds_before_fail_exit_);
flag = false;
cnt++;
}
if (cnt > 3) {
VLOG(0) << "fleet pull sparse failed, retry 3 times";
exit(-1);
}
if (flag) {
break;
}
}
#endif
#ifdef PADDLE_WITH_PSCORE
int32_t cnt = 0;
while (true) {
auto tt = fleet_ptr->worker_ptr_->PullSparsePtr(
reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
local_keys[i].data(), key_size);
bool flag = true;
tt.wait();
try {
status = tt.get();
} catch (const std::future_error& e) {
VLOG(0) << "Caught a future_error with code" << e.code()
<< ", Message:" << e.what();
}
if (status != 0) {
VLOG(0) << "fleet pull sparse failed, status[" << status << "]";
sleep(sleep_seconds_before_fail_exit_);
flag = false;
cnt++;
}
if (cnt > 3) {
VLOG(0) << "fleet pull sparse failed, retry 3 times";
exit(-1);
}
if (flag) {
break;
}
}
#endif
if (status != 0) {
LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
sleep(300);
exit(-1);
} else {
VLOG(3) << "FleetWrapper Pull sparse to local done with table size: "
<< local_keys[i].size();
}
};
auto ptl_dynamic_mf_func = [this, &local_dim_keys, &local_dim_ptr,
&fleet_ptr](int i, int j) {
......@@ -478,21 +372,18 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
}
#endif
};
if (!multi_mf_dim_) {
for (size_t i = 0; i < threads.size(); i++) {
threads[i] = std::thread(ptl_func, i);
}
} else {
threads.resize(thread_keys_shard_num_ * multi_mf_dim_);
for (int i = 0; i < thread_keys_shard_num_; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
threads[i * multi_mf_dim_ + j] = std::thread(ptl_dynamic_mf_func, i, j);
}
threads.resize(thread_keys_shard_num_ * multi_mf_dim_);
for (int i = 0; i < thread_keys_shard_num_; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
task_futures.emplace_back(
pull_thread_pool_[i]->enqueue(ptl_dynamic_mf_func, i, j));
}
}
for (std::thread& t : threads) {
t.join();
for (auto& f : task_futures) {
f.wait();
}
task_futures.clear();
timeline.Pause();
VLOG(0) << "pull sparse from CpuPS into GpuPS cost " << timeline.ElapsedSec()
<< " seconds.";
......@@ -509,19 +400,12 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
std::vector<std::vector<std::pair<uint64_t, char*>>> pass_values;
bool record_status = false;
#ifdef PADDLE_WITH_PSLIB
uint16_t pass_id = 0;
if (multi_node_) {
record_status = fleet_ptr->pslib_ptr_->_worker_ptr->take_sparse_record(
table_id_, pass_id, pass_values);
}
#endif
auto& device_task_keys = gpu_task->device_task_keys_;
auto& device_task_ptrs = gpu_task->device_task_ptr_;
auto build_dynamic_mf_func = [this, device_num, &local_dim_keys,
&local_dim_ptr, &device_dim_keys,
&device_dim_ptr,
&device_dim_mutex](int i, int j) {
auto build_pull_dynamic_mf_func = [this, device_num, &local_dim_keys,
&local_dim_ptr, &device_dim_keys,
&device_dim_ptr,
&device_dim_mutex](int i, int j) {
#ifdef PADDLE_WITH_PSLIB
std::vector<std::vector<FeatureKey>> task_keys(device_num);
std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>> task_ptrs(
......@@ -532,20 +416,16 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
task_ptrs[shard].push_back(local_dim_ptr[i][j][k]);
}
for (int dev = 0; dev < device_num; dev++) {
for (int dim = 0; dim < multi_mf_dim_; dim++) {
device_dim_mutex[dev][dim]->lock();
int len = task_keys[dev].size();
int cur = device_dim_keys[dev][dim].size();
device_dim_keys[dev][dim].resize(device_dim_keys[dev][dim].size() +
len);
device_dim_ptr[dev][dim].resize(device_dim_ptr[dev][dim].size() + len);
for (int k = 0; k < len; ++k) {
device_dim_keys[dev][dim][cur + k] = task_keys[dev][k];
device_dim_ptr[dev][dim][cur + k] = task_ptrs[dev][k];
}
device_dim_mutex[dev][dim]->unlock();
device_dim_mutex[dev][j]->lock();
int len = task_keys[dev].size();
int cur = device_dim_keys[dev][j].size();
device_dim_keys[dev][j].resize(device_dim_keys[dev][j].size() + len);
device_dim_ptr[dev][j].resize(device_dim_ptr[dev][j].size() + len);
for (int k = 0; k < len; ++k) {
device_dim_keys[dev][j][cur + k] = task_keys[dev][k];
device_dim_ptr[dev][j][cur + k] = task_ptrs[dev][k];
}
device_dim_mutex[dev][j]->unlock();
}
#endif
};
......@@ -697,7 +577,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
for (int i = 0; i < thread_keys_shard_num_; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
threads[i * multi_mf_dim_ + j] =
std::thread(build_dynamic_mf_func, i, j);
std::thread(build_pull_dynamic_mf_func, i, j);
}
}
for (std::thread& t : threads) {
......@@ -727,21 +607,17 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
std::vector<size_t> feature_keys_count(device_num);
size_t size_max = 0;
if (!multi_mf_dim_) {
for (int i = 0; i < device_num; i++) {
feature_keys_count[i] = gpu_task->device_keys_[i].size();
VLOG(0) << i << " card contains feasign nums: " << feature_keys_count[i];
size_max = std::max(size_max, feature_keys_count[i]);
}
} else {
for (int i = 0; i < device_num; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
feature_keys_count[i] += gpu_task->device_dim_ptr_[i][j].size();
}
VLOG(0) << i << " card with dynamic mf contains feasign nums: "
<< feature_keys_count[i];
size_max = std::max(size_max, feature_keys_count[i]);
for (int i = 0; i < device_num; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
feature_keys_count[i] += gpu_task->device_dim_ptr_[i][j].size();
VLOG(1) << i << " card with dynamic mf dim: " << index_dim_vec_[j]
<< " dim index: " << j << " contains feasign nums: "
<< gpu_task->device_dim_ptr_[i][j].size();
}
VLOG(1) << i << " card with dynamic mf contains feasign nums total: "
<< feature_keys_count[i];
size_max = std::max(size_max, feature_keys_count[i]);
}
if (HeterPs_) {
delete HeterPs_;
......@@ -756,17 +632,73 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
#ifdef PADDLE_WITH_CUDA
HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_);
#endif
auto build_func = [this, &gpu_task, &feature_keys_count](int i) {
VLOG(3) << "building table: " << i;
this->HeterPs_->build_ps(i, gpu_task->device_keys_[i].data(),
gpu_task->device_values_[i].data(),
feature_keys_count[i], 500000, 2);
// if (feature_keys_count[i] > 0) {
// HeterPs_->show_one_table(i);
// }
auto build_dynamic_mf_func = [this, &gpu_task](int i, int j) {
this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_);
int mf_dim = this->index_dim_vec_[j];
VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim;
size_t feature_value_size =
TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float)));
auto& device_dim_keys = gpu_task->device_dim_keys_[i][j];
auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j];
size_t len = device_dim_keys.size();
CHECK(len == device_dim_ptrs.size());
this->mem_pools_[i * this->multi_mf_dim_ + j] =
new MemoryPool(len, feature_value_size);
auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j];
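// MemoryPool presumably provides len host-side slots of feature_value_size bytes
// (its definition is not part of this diff); HBMMemoryPool below copies the filled
// pool into device HBM before build_ps inserts the keys.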
for (size_t k = 0; k < len; k++) {
FeatureValue* val = (FeatureValue*)(mem_pool->mem_address(k));
float* ptr_val = device_dim_ptrs[k]->data();
size_t dim = device_dim_ptrs[k]->size();
#ifdef PADDLE_WITH_PSLIB
val->delta_score =
ptr_val[paddle::ps::DownpourCtrDymfAccessor::
DownpourCtrDymfFeatureValue::delta_score_index()];
val->show = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
DownpourCtrDymfFeatureValue::show_index()];
val->clk = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
DownpourCtrDymfFeatureValue::click_index()];
val->slot = int(ptr_val[paddle::ps::DownpourCtrDymfAccessor::
DownpourCtrDymfFeatureValue::slot_index()]);
val->lr = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
DownpourCtrDymfFeatureValue::embed_w_index()];
val->lr_g2sum =
ptr_val[paddle::ps::DownpourCtrDymfAccessor::
DownpourCtrDymfFeatureValue::embed_g2sum_index()];
val->cpu_ptr = (uint64_t)(device_dim_ptrs[k]);
ptr_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
mf_dim_index()] = float(mf_dim);
val->mf_dim = mf_dim;
#endif
if (dim > 8) { // CpuPS has already expanded the value to mf_dim
val->mf_size = mf_dim + 1;
for (int x = 0; x < val->mf_dim + 1; x++) {
val->mf[x] = ptr_val[x + 8];
}
} else {
val->mf_size = 0;
for (int x = 0; x < val->mf_dim + 1; x++) {
val->mf[x] = 0;
}
}
}
platform::CUDADeviceGuard guard(resource_->dev_id(i));
this->hbm_pools_[i * this->multi_mf_dim_ + j] = new HBMMemoryPool(mem_pool);
auto& cur_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j];
this->HeterPs_->build_ps(i, device_dim_keys.data(), cur_pool->mem(), len,
feature_value_size, 500000, 2);
if (device_dim_keys.size() > 0) {
VLOG(0) << "show ptr table: " << i
<< " table kv size: " << device_dim_keys.size()
<< "dim: " << mf_dim << " len: " << len;
this->HeterPs_->show_one_table(i);
}
delete mem_pool;
};
for (size_t i = 0; i < threads.size(); i++) {
threads[i] = std::thread(build_func, i);
threads.resize(device_num * multi_mf_dim_);
for (int i = 0; i < device_num; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
threads[i + j * device_num] = std::thread(build_dynamic_mf_func, i, j);
}
}
for (std::thread& t : threads) {
t.join();
......@@ -788,7 +720,6 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
if (is_shuffle) {
dataset_->LocalShuffle();
}
std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
gpu_task->Reset();
data_ready_channel_->Put(gpu_task);
......@@ -874,17 +805,86 @@ void PSGPUWrapper::EndPass() {
size_t keysize_max = 0;
// in case of feasign_num = 0, skip dump_to_cpu
for (size_t i = 0; i < heter_devices_.size(); i++) {
keysize_max = std::max(keysize_max, current_task_->device_keys_[i].size());
for (int j = 0; j < multi_mf_dim_; j++) {
keysize_max =
std::max(keysize_max, current_task_->device_dim_keys_[i][j].size());
}
}
auto dump_pool_to_cpu_func = [this](int i, int j) {
PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(this->resource_->dev_id(i)));
auto& hbm_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j];
auto& device_keys = this->current_task_->device_dim_keys_[i][j];
size_t len = device_keys.size();
int mf_dim = this->index_dim_vec_[j];
VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim;
size_t feature_value_size =
TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float)));
char* test_build_values = (char*)malloc(feature_value_size * len);
cudaMemcpy(test_build_values, hbm_pool->mem(), feature_value_size * len,
cudaMemcpyDeviceToHost);
CHECK(len == hbm_pool->capacity());
#ifdef PADDLE_WITH_PSLIB
uint64_t unuse_key = std::numeric_limits<uint64_t>::max();
for (size_t i = 0; i < len; ++i) {
if (device_keys[i] == unuse_key) {
continue;
}
size_t offset = i * feature_value_size;
FeatureValue* gpu_val = (FeatureValue*)(test_build_values + offset);
auto* downpour_value =
(paddle::ps::DownpourFixedFeatureValue*)(gpu_val->cpu_ptr);
int downpour_value_size = downpour_value->size();
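// If the GPU side created an embedding (mf_size > 0) but the CPU value still holds
// only the 8 dense floats, grow it so the mf_dim + 1 embedding floats fit below.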
if (gpu_val->mf_size > 0 && downpour_value_size == 8) {
downpour_value->resize(gpu_val->mf_dim + 1 + downpour_value_size);
}
float* cpu_val = downpour_value->data();
cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
delta_score_index()] = gpu_val->delta_score;
cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
show_index()] = gpu_val->show;
cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
click_index()] = gpu_val->clk;
cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
embed_w_index()] = gpu_val->lr;
cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
embed_g2sum_index()] = gpu_val->lr_g2sum;
cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
slot_index()] = gpu_val->slot;
if (gpu_val->mf_size > 0) {
for (int x = 0; x < gpu_val->mf_dim + 1; x++) {
cpu_val[x + 8] = gpu_val->mf[x];
}
}
}
#endif
free(test_build_values);
};
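For reference, dump_pool_to_cpu_func above copies the whole device pool back to the host and, whenever a GPU value has learned an embedding (mf_size > 0) while the CPU-side record still holds only the 8-float fixed part, grows the CPU record before writing back the mf_dim + 1 trailing floats. The host-only sketch below shows that resize-then-copy step with std::vector standing in for the PSLib value; all names are hypothetical, not code from this PR.

// Minimal illustrative sketch (not code from this PR): growing a CPU-side value from
// its 8-float fixed part to 8 + mf_dim + 1 floats before copying the embedding back.
#include <cstdio>
#include <vector>

int main() {
  const int mf_dim = 8;
  std::vector<float> cpu_val(8, 0.f);           // stand-in for the fixed-size CPU value
  std::vector<float> gpu_mf(mf_dim + 1, 0.5f);  // what gpu_val->mf would hold
  const bool gpu_has_mf = true;                 // i.e. gpu_val->mf_size > 0

  if (gpu_has_mf && cpu_val.size() == 8) {
    cpu_val.resize(8 + mf_dim + 1);             // make room for the trailing embedding
  }
  if (gpu_has_mf) {
    for (int x = 0; x < mf_dim + 1; ++x) cpu_val[8 + x] = gpu_mf[x];
  }
  std::printf("cpu value now holds %zu floats\n", cpu_val.size());
  return 0;
}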
if (multi_mf_dim_) {
VLOG(0) << "psgpu wrapper dump pool: multi_mf_dim_: " << multi_mf_dim_;
size_t device_num = heter_devices_.size();
std::vector<std::thread> threads(device_num * multi_mf_dim_);
for (size_t i = 0; i < device_num; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
threads[i + j * device_num] = std::thread(dump_pool_to_cpu_func, i, j);
}
}
for (std::thread& t : threads) {
t.join();
}
}
if (keysize_max != 0) {
HeterPs_->end_pass();
}
for (size_t i = 0; i < hbm_pools_.size(); i++) {
delete hbm_pools_[i];
}
gpu_task_pool_.Push(current_task_);
current_task_ = nullptr;
gpu_free_channel_->Put(current_task_);
timer.Pause();
VLOG(0) << "EndPass end, cost time: " << timer.ElapsedSec() << "s";
VLOG(1) << "EndPass end, cost time: " << timer.ElapsedSec() << "s";
}
void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
......@@ -936,8 +936,6 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
pull_gpups_timer.Start();
HeterPs_->pull_sparse(devid_2_index, total_keys, total_values_gpu,
static_cast<int>(total_length));
// PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
// "PullSparseGPU failed in GPUPS."));
pull_gpups_timer.Pause();
VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
......@@ -945,6 +943,98 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len,
static_cast<int>(slot_lengths.size()), hidden_size,
total_length);
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GpuPs: PullSparse Only Support CUDAPlace Now."));
}
all_timer.Pause();
VLOG(3) << "GpuPs PullSparse total costs: " << all_timer.ElapsedSec()
<< " s, of which GPUPS costs: " << pull_gpups_timer.ElapsedSec()
<< " s";
VLOG(3) << "End PullSparse";
}
void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
const int table_id,
const std::vector<const uint64_t*>& keys,
const std::vector<float*>& values,
const std::vector<int64_t>& slot_lengths,
const std::vector<int>& slot_dim,
const int hidden_size) {
VLOG(3) << "Begine Gpu Ps PullSparse";
platform::Timer all_timer;
platform::Timer pull_gpups_timer;
all_timer.Start();
size_t total_length =
std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
size_t feature_value_size = TYPEALIGN(
8, sizeof(FeatureValue) + sizeof(float) * (index_dim_vec_.back() + 1));
VLOG(0) << "pull sparse feature_value_size: " << feature_value_size;
#ifdef PADDLE_WITH_CUDA
VLOG(3) << "Begine Gpu Ps PullSparse";
auto buf = memory::Alloc(place, total_length * feature_value_size);
FeatureValue* total_values_gpu = reinterpret_cast<FeatureValue*>(buf->ptr());
#endif
#ifdef PADDLE_WITH_XPU_KP
VLOG(3) << "Begine Xpu Ps PullSparse";
FeatureValue* total_values_gpu = nullptr;
xpu_malloc(reinterpret_cast<void**>(&total_values_gpu),
total_length * feature_value_size);
#endif
if (platform::is_cpu_place(place)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Warning:: CPUPlace is not supported in GpuPs now."));
} else if (platform::is_gpu_place(place)) {
VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
int device_id = place.GetDeviceId();
int devid_2_index = HeterPs_->get_index_by_devid(device_id);
LoDTensor& total_keys_tensor = keys_tensor[devid_2_index];
uint64_t* total_keys =
reinterpret_cast<uint64_t*>(total_keys_tensor.mutable_data<int64_t>(
{int64_t(total_length), 1}, place));
// construct slot_level lod info
auto slot_lengths_lod = slot_lengths;
for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
slot_lengths_lod[i] += slot_lengths_lod[i - 1];
}
auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*));
auto buf_length =
memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
uint64_t** gpu_keys = reinterpret_cast<uint64_t**>(buf_key->ptr());
int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*),
cudaMemcpyHostToDevice);
cudaMemcpy(gpu_len, slot_lengths_lod.data(),
slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
auto buf_dim = memory::Alloc(place, slot_dim.size() * sizeof(int));
int* gpu_dim = reinterpret_cast<int*>(buf_dim->ptr());
cudaMemcpy(gpu_dim, slot_dim.data(), slot_dim.size() * sizeof(int),
cudaMemcpyHostToDevice);
this->CopyKeys(place, gpu_keys, total_keys, gpu_len,
static_cast<int>(slot_lengths.size()),
static_cast<int>(total_length));
VLOG(3) << "Begin call PullSparseGPU in GPUPS, dev: " << devid_2_index
<< " len: " << total_length;
pull_gpups_timer.Start();
HeterPs_->pull_sparse(devid_2_index, total_keys, total_values_gpu,
total_length);
VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
<< "]";
this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len,
static_cast<int>(slot_lengths.size()), hidden_size,
total_length, gpu_dim);
pull_gpups_timer.Pause();
#endif
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU_KP
......@@ -1013,7 +1103,10 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
// #ifdef PADDLE_WITH_CUDA
VLOG(3) << "Begin GPUPS PushSparseGrad";
auto buf = memory::Alloc(place, total_length * sizeof(FeaturePushValue));
size_t grad_value_size =
TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
auto buf = memory::Alloc(place, total_length * grad_value_size);
VLOG(3) << "Push Sparse Max mf dimention: " << max_mf_dim_;
FeaturePushValue* total_grad_values_gpu =
reinterpret_cast<FeaturePushValue*>(buf->ptr());
if (platform::is_cpu_place(place)) {
......@@ -1027,8 +1120,13 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
uint64_t* total_keys =
reinterpret_cast<uint64_t*>(cached_total_keys_tensor.data<int64_t>());
VLOG(3) << "Begin copy grad tensor to gpups struct";
this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths,
hidden_size, total_length, batch_size);
if (!multi_mf_dim_) {
this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths,
hidden_size, total_length, batch_size);
} else {
this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths,
total_length, batch_size, grad_value_size);
}
VLOG(3) << "Begin call PushSparseGPU in GPUPS, dev: " << devid_2_index
<< " len: " << total_length;
......
......@@ -61,6 +61,45 @@ __global__ void PullCopy(float** dest, const FeatureValue* src,
}
}
__global__ void PullCopy(float** dest, const FeatureValue* src,
const int64_t* len, int slot_num, int total_len,
uint64_t** keys, uint64_t max_val_size, int* gpu_dim) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[x - 1] : 0);
FeatureValue* feature_value_ptr =
(FeatureValue*)((char*)src + uint64_t(i) * uint64_t(max_val_size));
int mf_dim = gpu_dim[x] - 3;
if (*(keys[x] + y) == 0) {
*(dest[x] + y * (mf_dim + 3)) = 0;
*(dest[x] + y * (mf_dim + 3) + 1) = 0;
*(dest[x] + y * (mf_dim + 3) + 2) = 0;
} else {
*(dest[x] + y * (mf_dim + 3)) = feature_value_ptr->show;
*(dest[x] + y * (mf_dim + 3) + 1) = feature_value_ptr->clk;
*(dest[x] + y * (mf_dim + 3) + 2) = feature_value_ptr->lr;
}
if ((feature_value_ptr)->mf_size == 0 || *(keys[x] + y) == 0) {
for (int j = 0; j < mf_dim; j++) {
*(dest[x] + y * (mf_dim + 3) + 3 + j) = 0;
}
} else {
for (int j = 0; j < mf_dim; j++) {
*(dest[x] + y * (mf_dim + 3) + 3 + j) = feature_value_ptr->mf[1 + j];
}
}
}
}
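Both this PullCopy overload and PushCopyWithPool below map a flattened key index to a (slot, offset-in-slot) pair by binary-searching the prefix-summed slot lengths in len. The host-side sketch below reproduces just that lookup with made-up lengths; it is illustrative only, not code from this PR.

// Minimal illustrative sketch (not code from this PR): the binary search that PullCopy
// and PushCopyWithPool use to map a flattened key index to (slot, offset-in-slot).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Prefix-summed slot lengths (the "lod"): slots of size 3, 0, 5 and 2.
  std::vector<int64_t> len = {3, 3, 8, 10};
  const int slot_num = static_cast<int>(len.size());
  for (int i = 0; i < len.back(); ++i) {
    int low = 0, high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid]) {
        high = mid;
      } else {
        low = mid + 1;
      }
    }
    int x = low;                                         // slot index
    int y = i - (x ? static_cast<int>(len[x - 1]) : 0);  // offset inside slot x
    std::printf("flattened %d -> slot %d, offset %d\n", i, x, y);
  }
  return 0;
}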
__global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys,
const int64_t* len, int slot_num,
int total_len) {
......@@ -105,6 +144,35 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, int64_t* len,
}
}
__global__ void PushCopyWithPool(FeaturePushValue* dest, float** src,
int64_t* len, int slot_num, uint64_t total_len,
int bs, int* slot_vector, int* mf_dim_vector,
size_t grad_value_size) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[x - 1] : 0);
FeaturePushValue* cur =
(FeaturePushValue*)((char*)dest + i * grad_value_size);
cur->slot = slot_vector[x];
int mf_dim = mf_dim_vector[x];
cur->mf_dim = mf_dim;
cur->show = *(src[x] + y * (mf_dim + 3));
cur->clk = *(src[x] + y * (mf_dim + 3) + 1);
cur->lr_g = *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs;
for (int j = 0; j < cur->mf_dim; j++) {
cur->mf_g[j] = *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs;
}
}
}
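PushCopyWithPool above reads each key's dense gradient row, laid out as mf_dim + 3 floats [show, clk, lr_g, mf_g...], and packs it into a variable-length push record while scaling the gradients by -1 * batch_size. The sketch below mirrors that packing on the host; ToyPush and the sample row are hypothetical stand-ins, not code from this PR.

// Minimal illustrative sketch (not code from this PR): packing one (mf_dim + 3)-float
// gradient row into a variable-length push record, with the -1 * batch_size scaling.
#include <cstdio>
#include <vector>

struct ToyPush {      // hypothetical stand-in for FeaturePushValue
  float show, clk, lr_g;
  int slot, mf_dim;
  float mf_g[0];      // trailing mf_dim gradients (zero-length array, a GNU extension)
};

int main() {
  const int mf_dim = 8, bs = 32, slot_id = 5;
  std::vector<float> row(mf_dim + 3, 0.25f);  // dense layout: [show, clk, lr_g, mf_g...]
  std::vector<char> buf(sizeof(ToyPush) + mf_dim * sizeof(float), 0);
  ToyPush* cur = reinterpret_cast<ToyPush*>(buf.data());
  cur->slot = slot_id;
  cur->mf_dim = mf_dim;
  cur->show = row[0];
  cur->clk = row[1];
  cur->lr_g = row[2] * -1.f * bs;             // same sign/scale convention as the kernel
  for (int j = 0; j < mf_dim; ++j) cur->mf_g[j] = row[3 + j] * -1.f * bs;
  std::printf("lr_g=%.2f mf_g[0]=%.2f\n", cur->lr_g, cur->mf_g[0]);
  return 0;
}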
PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; }
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
......@@ -128,6 +196,26 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
cudaStreamSynchronize(stream);
}
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
uint64_t** gpu_keys,
const std::vector<float*>& values,
const FeatureValue* total_values_gpu,
const int64_t* gpu_len, const int slot_num,
const int hidden_size,
const int64_t total_length, int* gpu_dim) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
cudaMemcpyHostToDevice);
PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
gpu_values, total_values_gpu, gpu_len, slot_num, total_length, gpu_keys,
val_type_size_, gpu_dim);
cudaStreamSynchronize(stream);
}
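This CopyForPull overload first copies the host array of per-slot output pointers into device memory and then launches PullCopy with (total_length + 1024 - 1) / 1024 blocks of 1024 threads. The self-contained CUDA sketch below reproduces that pointer-table marshalling and launch shape; FillRows and the buffer sizes are hypothetical, error checks are omitted, and it uses the default stream rather than the device context's stream.

// Illustrative CUDA sketch (not code from this PR): copy an array of per-row device
// pointers to the GPU, then launch a grid-stride kernel over total elements.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void FillRows(float** rows, int row_len, int total) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total;
       i += blockDim.x * gridDim.x) {
    int r = i / row_len;
    int c = i % row_len;
    rows[r][c] = static_cast<float>(i);
  }
}

int main() {
  const int num_rows = 4, row_len = 8, total = num_rows * row_len;
  float* h_row_ptrs[num_rows];  // host array holding device pointers
  for (int r = 0; r < num_rows; ++r) {
    cudaMalloc(reinterpret_cast<void**>(&h_row_ptrs[r]), row_len * sizeof(float));
  }
  float** d_row_ptrs = nullptr;  // device copy of the pointer table
  cudaMalloc(reinterpret_cast<void**>(&d_row_ptrs), num_rows * sizeof(float*));
  cudaMemcpy(d_row_ptrs, h_row_ptrs, num_rows * sizeof(float*),
             cudaMemcpyHostToDevice);
  FillRows<<<(total + 1024 - 1) / 1024, 1024>>>(d_row_ptrs, row_len, total);
  cudaDeviceSynchronize();
  float back[row_len];
  cudaMemcpy(back, h_row_ptrs[2], sizeof(back), cudaMemcpyDeviceToHost);
  std::printf("row 2, col 3 = %.0f\n", back[3]);  // expect 19
  return 0;
}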
void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
uint64_t** origin_keys, uint64_t* total_keys,
const int64_t* gpu_len, int slot_num,
......@@ -177,6 +265,45 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
cudaStreamSynchronize(stream);
}
void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
FeaturePushValue* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const uint64_t total_length,
const int batch_size, size_t grad_value_size) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
auto slot_lengths_lod = slot_lengths;
for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
slot_lengths_lod[i] += slot_lengths_lod[i - 1];
}
auto buf_grad_value =
memory::Alloc(place, grad_values.size() * sizeof(float*));
auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
auto buf_slot_vector =
memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
auto buf_mf_dim_vector =
memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());
int* d_mf_dim_vector = reinterpret_cast<int*>(buf_mf_dim_vector->ptr());
cudaMemcpy(gpu_values, grad_values.data(),
grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice);
cudaMemcpy(gpu_len, slot_lengths_lod.data(),
slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
cudaMemcpy(d_slot_vector, slot_vector_.data(),
slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
cudaMemcpy(d_mf_dim_vector, slot_mf_dim_vector_.data(),
slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
PushCopyWithPool<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
total_grad_values_gpu, gpu_values, gpu_len, slot_lengths.size(),
total_length, batch_size, d_slot_vector, d_mf_dim_vector,
grad_value_size);
cudaStreamSynchronize(stream);
}
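Both PullSparse and this CopyForPush overload first turn slot_lengths into a running total (the slot-level lod) that the copy kernels binary-search over. An equivalent in-place prefix sum, shown with made-up lengths and not taken from this PR:

// Minimal illustrative sketch (not code from this PR): the in-place prefix sum that
// produces the slot-level lod consumed by PullCopy / PushCopyWithPool.
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> slot_lengths = {3, 0, 5, 2};  // per-slot key counts
  std::vector<int64_t> slot_lengths_lod = slot_lengths;
  std::partial_sum(slot_lengths_lod.begin(), slot_lengths_lod.end(),
                   slot_lengths_lod.begin());        // {3, 3, 8, 10}
  for (size_t i = 0; i < slot_lengths_lod.size(); ++i) {
    std::printf("slot %zu ends at flattened index %lld\n", i,
                static_cast<long long>(slot_lengths_lod[i]));
  }
  return 0;
}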
void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff,
float min_bound, float max_bound,
float learning_rate, float initial_g2sum,
......
......@@ -27,6 +27,7 @@ limitations under the License. */
#include <vector>
#ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#include "paddle/fluid/distributed/ps/thirdparty/round_robin.h"
......@@ -54,6 +55,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_PSLIB
#include "afs_api.h"
#endif
#ifdef PADDLE_WITH_PSLIB
#include "downpour_accessor.h" // NOLINT
#endif
namespace paddle {
namespace framework {
......@@ -95,12 +99,21 @@ class PSGPUWrapper {
PSGPUWrapper() {
HeterPs_ = NULL;
sleep_seconds_before_fail_exit_ = 300;
pull_thread_pool_.resize(thread_keys_shard_num_);
for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
pull_thread_pool_[i].reset(new ::ThreadPool(1));
}
hbm_thread_pool_.resize(thread_keys_shard_num_);
for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
hbm_thread_pool_[i].reset(new ::ThreadPool(1));
}
}
void PullSparse(const paddle::platform::Place& place, const int table_id,
const std::vector<const uint64_t*>& keys,
const std::vector<float*>& values,
const std::vector<int64_t>& slot_lengths,
const std::vector<int>& slot_dim, const int hidden_size);
void PullSparse(const paddle::platform::Place& place, const int table_id,
const std::vector<const uint64_t*>& keys,
const std::vector<float*>& values,
......@@ -119,13 +132,23 @@ class PSGPUWrapper {
const FeatureValue* total_values_gpu, const int64_t* gpu_len,
const int slot_num, const int hidden_size,
const int64_t total_length);
void CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys,
const std::vector<float*>& values,
const FeatureValue* total_values_gpu, const int64_t* gpu_len,
const int slot_num, const int hidden_size,
const int64_t total_length, int* gpu_dim);
void CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
FeaturePushValue* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const int hidden_size, const int64_t total_length,
const int batch_size);
void CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
FeaturePushValue* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const uint64_t total_length, const int batch_size,
size_t grad_value_size);
void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
void PreBuildTask(std::shared_ptr<HeterContext> gpu_task);
......@@ -428,6 +451,7 @@ class PSGPUWrapper {
std::shared_ptr<HeterContext> current_task_ = nullptr;
std::thread pre_build_threads_;
bool running_ = false;
std::vector<std::shared_ptr<ThreadPool>> pull_thread_pool_;
std::vector<std::shared_ptr<ThreadPool>> hbm_thread_pool_;
protected:
......
......@@ -26,6 +26,7 @@ template <typename T>
static void PullGpuPSSparseFunctor(const framework::ExecutionContext &ctx) {
auto inputs = ctx.MultiInput<framework::Tensor>("Ids");
auto outputs = ctx.MultiOutput<framework::Tensor>("Out");
auto embedding_size_vec = ctx.Attr<std::vector<int>>("size");
const auto slot_size = inputs.size();
std::vector<const uint64_t *> all_keys(slot_size);
// GpuPS only supports float now
......@@ -44,7 +45,7 @@ static void PullGpuPSSparseFunctor(const framework::ExecutionContext &ctx) {
#ifdef PADDLE_WITH_HETERPS
auto gpu_ps_ptr = paddle::framework::PSGPUWrapper::GetInstance();
gpu_ps_ptr->PullSparse(ctx.GetPlace(), 0, all_keys, all_values, slot_lengths,
0);
embedding_size_vec, 0);
#endif
}
......
......@@ -737,7 +737,7 @@ def _pull_gpups_sparse(input,
for i in range(len(inputs))
]
w = helper.create_parameter(
attr=helper.param_attr, shape=[11], dtype=dtype, is_bias=False)
attr=helper.param_attr, shape=[size[0]], dtype=dtype, is_bias=False)
helper.append_op(
type='pull_gpups_sparse',
inputs={'Ids': inputs,
......