Unverified    Commit 3f619290    Authored by: yaoxuefeng    Committed by: GitHub

merge dymf branch (#42714)

merge dymf branch
Parent: e726960a
@@ -129,11 +129,6 @@ class HeterContext {
    for (size_t i = 0; i < feature_dim_keys_.size(); i++) {
      feature_dim_keys_[i].resize(dim_num);
      value_dim_ptr_[i].resize(dim_num);
-      if (i == 0) {
-        for (int j = 0; j < dim_num; j++) {
-          feature_dim_keys_[i][j].push_back(0);
-        }
-      }
    }
    device_values_.resize(device_num);
    device_dim_values_.resize(device_num);
......
@@ -32,17 +32,33 @@ struct FeatureValue {
  float lr;
  float lr_g2sum;
  int mf_size;
-  float mf[MF_DIM + 1];
+  int mf_dim;
  uint64_t cpu_ptr;
+  float mf[0];

  friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) {
    out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot
-        << " lr: " << val.lr << " mf_size: " << val.mf_size << " mf:";
-    for (int i = 0; i < val.mf_size; ++i) {
+        << " lr: " << val.lr << " mf_dim: " << val.mf_dim
+        << "cpuptr: " << val.cpu_ptr << " mf_size: " << val.mf_size << " mf:";
+    for (int i = 0; i < val.mf_dim + 1; ++i) {
      out << " " << val.mf[i];
    }
    return out;
  }
+  __device__ __forceinline__ void operator=(const FeatureValue& in) {
+    delta_score = in.delta_score;
+    show = in.show;
+    clk = in.clk;
+    slot = in.slot;
+    lr = in.lr;
+    lr_g2sum = in.lr_g2sum;
+    mf_size = in.mf_size;
+    mf_dim = in.mf_dim;
+    cpu_ptr = in.cpu_ptr;
+    for (int i = 0; i < mf_dim + 1; i++) {
+      mf[i] = in.mf[i];
+    }
+  }
};

struct FeaturePushValue {
@@ -50,20 +66,19 @@ struct FeaturePushValue {
  float clk;
  int slot;
  float lr_g;
-  float mf_g[MF_DIM];
+  int mf_dim;
+  float mf_g[0];

-  // __device__ __forceinline__ FeaturePushValue
-  // operator+(const FeaturePushValue& a) const {
-  //   FeaturePushValue out;
-  //   out.slot = a.slot;
-  //   out.show = a.show + show;
-  //   out.clk = a.clk + clk;
-  //   out.lr_g = a.lr_g + lr_g;
-  //   for (int i = 0; i < MF_DIM; ++i) {
-  //     out.mf_g[i] = a.mf_g[i] + mf_g[i];
-  //   }
-  //   return out;
-  // }
+  __device__ __forceinline__ void operator=(const FeaturePushValue& in) {
+    show = in.show;
+    clk = in.clk;
+    slot = in.slot;
+    lr_g = in.lr_g;
+    mf_dim = in.mf_dim;
+    for (int i = 0; i < mf_dim; i++) {
+      mf_g[i] = in.mf_g[i];
+    }
+  }
};

}  // end namespace framework
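Note: with `mf` and `mf_g` turned into zero-length trailing arrays, a record no longer has a fixed MF_DIM footprint. Each pulled value occupies the struct header plus `mf_dim + 1` trailing floats (one slot for the mf gradient accumulator in `mf[0]`, then the `mf_dim` weights), and records sit back-to-back in a raw char pool addressed by a per-record byte size. A minimal host-side sketch of that addressing, assuming the FeatureValue struct above; the helper names are illustrative and not part of the patch:

    #include <cstddef>
    #include <cstdint>

    // 8-byte alignment helper; the equivalent TYPEALIGN macro is added to
    // heter_comm.h later in this diff.
    inline uint64_t align8(uint64_t len) { return (len + 7) & ~uint64_t(7); }

    // Bytes occupied by one pulled value: header + (mf_dim + 1) trailing floats.
    inline size_t pull_value_bytes(int mf_dim) {
      return align8(sizeof(paddle::framework::FeatureValue) +
                    sizeof(float) * (mf_dim + 1));
    }

    // Address of the i-th record inside a flat pool of such records.
    inline paddle::framework::FeatureValue* nth_value(char* pool, size_t i,
                                                      size_t value_bytes) {
      return reinterpret_cast<paddle::framework::FeatureValue*>(
          pool + i * value_bytes);
    }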
......
@@ -118,8 +118,8 @@ class HashTable {
              StreamType stream);
  template <typename StreamType>
-  void insert(const KeyType* d_keys, size_t len, char* pool, size_t start_index,
-              StreamType stream);
+  void insert(const KeyType* d_keys, size_t len, char* pool,
+              size_t feature_value_size, size_t start_index, StreamType stream);
  template <typename StreamType>
  void get(const KeyType* d_keys, ValType* d_vals, size_t len,
......
@@ -50,7 +50,8 @@ __global__ void insert_kernel(Table* table,
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
-                              size_t len, char* pool, int start_index) {
+                              size_t len, char* pool, size_t feature_value_size,
+                              int start_index) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;
@@ -58,7 +59,8 @@ __global__ void insert_kernel(Table* table,
  if (i < len) {
    kv.first = keys[i];
-    kv.second = (Table::mapped_type)(pool + (start_index + i) * 80);
+    uint64_t offset = uint64_t(start_index + i) * feature_value_size;
+    kv.second = (Table::mapped_type)(pool + offset);
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
@@ -81,14 +83,16 @@ __global__ void search_kernel(Table* table,
template <typename Table>
__global__ void dy_mf_search_kernel(Table* table,
                                    const typename Table::key_type* const keys,
-                                    char* const vals, size_t len,
+                                    char* vals, size_t len,
                                    size_t pull_feature_value_size) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
-      *(FeatureValue*)(vals + i * pull_feature_value_size) = *(it->second);
+      uint64_t offset = i * pull_feature_value_size;
+      FeatureValue& cur = *(FeatureValue*)(vals + offset);
+      FeatureValue& input = *(FeatureValue*)(it->second);
    }
  }
}
@@ -121,7 +125,7 @@ __global__ void dy_mf_update_kernel(Table* table,
      FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
      sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
    } else {
-      printf("yxf::push miss key: %d", keys[i]);
+      printf("warning: push miss key: %d", keys[i]);
    }
  }
}
@@ -201,7 +205,8 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
-                                         char* pool, size_t start_index,
+                                         char* pool, size_t feature_value_size,
+                                         size_t start_index,
                                         StreamType stream) {
  if (len == 0) {
    return;
@@ -210,8 +215,8 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
-  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
-                                                       pool, start_index);
+  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
+      container_, d_keys, len, pool, feature_value_size, start_index);
}

template <typename KeyType, typename ValType>
@@ -319,6 +324,7 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
}

template class HashTable<unsigned long, paddle::framework::FeatureValue>;
+template class HashTable<unsigned long, paddle::framework::FeatureValue*>;
template class HashTable<long, int>;
template class HashTable<unsigned long, int>;
template class HashTable<unsigned long, unsigned long>;
@@ -331,6 +337,10 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
    paddle::framework::FeatureValue* d_vals, size_t len,
    cudaStream_t stream);
+template void
+HashTable<unsigned long, paddle::framework::FeatureValue*>::get<cudaStream_t>(
+    const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
                                                      int* d_vals, size_t len,
                                                      cudaStream_t stream);
@@ -354,6 +364,11 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
    const paddle::framework::FeatureValue* d_vals, size_t len,
    cudaStream_t stream);
+template void HashTable<unsigned long, paddle::framework::FeatureValue*>::
+    insert<cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool,
+                         size_t feature_value_size, size_t start_index,
+                         cudaStream_t stream);
template void HashTable<long, int>::insert<cudaStream_t>(const long* d_keys,
                                                         const int* d_vals,
                                                         size_t len,
@@ -393,6 +408,16 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
        sgd,
    cudaStream_t stream);
+template void
+HashTable<unsigned long, paddle::framework::FeatureValue*>::update<
+    Optimizer<paddle::framework::FeatureValue,
+              paddle::framework::FeaturePushValue>,
+    cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t len,
+                  Optimizer<paddle::framework::FeatureValue,
+                            paddle::framework::FeaturePushValue>
+                      sgd,
+                  cudaStream_t stream);
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
//     Optimizer<paddle::framework::FeatureValue,
......
@@ -15,10 +15,13 @@ limitations under the License. */
#pragma once
#include <thread>
#include <vector>
+#include "cub/cub.cuh"
+#include "cub/util_allocator.cuh"
#if defined(PADDLE_WITH_CUDA)
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/dynload/nccl.h"
+#include "paddle/fluid/platform/timer.h"
#include "thrust/pair.h"
#elif defined(PADDLE_WITH_XPU_KP)
// #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
@@ -38,6 +41,9 @@ limitations under the License. */
namespace paddle {
namespace framework {
+#define TYPEALIGN(ALIGNVAL, LEN) \
+  (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))
template <typename KeyType, typename ValType, typename GradType>
class HeterComm {
 public:
@@ -50,9 +56,13 @@ class HeterComm {
                            int* left, int* right, int gpu_num);
  void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
                  int& uniq_len);  // NOLINT
+  void dynamic_merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads,
+                          size_t len, int& uniq_len);
  void pull_sparse(int num, KeyType* d_keys, ValType* d_vals, size_t len);
  void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len,
                size_t chunk_size, int stream_num);
+  void build_ps(int num, KeyType* h_keys, char* pool, size_t len,
+                size_t feature_value_size, size_t chunk_size, int stream_num);
  void dump();
  void show_one_table(int gpu_num);
  int get_index_by_devid(int devid);
@@ -96,6 +106,11 @@ class HeterComm {
    nccl_inter_comms_ = inter_comms;
    node_size_ = comm_size;
  }
+  void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
+    multi_mf_dim_ = multi_mf_dim;
+    max_mf_dim_ = max_mf_dim;
+  }
#endif
  bool need_transfer(int send_id, int receive_id) {
@@ -114,8 +129,8 @@ class HeterComm {
    char* key_storage;
    char* val_storage;
    int sync;
-    int key_bytes_len;
-    int val_bytes_len;
+    size_t key_bytes_len;
+    size_t val_bytes_len;
    int dev_num;
  };
@@ -206,12 +221,18 @@ class HeterComm {
  void destroy_storage(int start_index, int end_index);
  void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
                    KeyType* src_key, GradType* src_val);
+  void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
+                    KeyType* src_key, char* src_val, size_t val_size);
  void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right,
                   ValType* src_val);
+  void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right,
+                   char* src_val, size_t val_size);

 protected:
  using Table = HashTable<KeyType, ValType>;
+  using PtrTable = HashTable<KeyType, ValType*>;
  std::vector<Table*> tables_;
+  std::vector<PtrTable*> ptr_tables_;
  std::shared_ptr<HeterPsResource> resource_;
  std::vector<std::vector<Path>> path_;
  float load_factor_{0.75};
@@ -221,6 +242,7 @@ class HeterComm {
 private:
  int topo_aware_{0};
  std::vector<LocalStorage> storage_;
+  DynamicGradMerger merger_;
  int feanum_{1800 * 2048};
  int multi_node_{0};
  int node_size_;
@@ -228,6 +250,8 @@ class HeterComm {
#if defined(PADDLE_WITH_CUDA)
  std::vector<ncclComm_t> nccl_inner_comms_;
  std::vector<ncclComm_t> nccl_inter_comms_;
+  int multi_mf_dim_{8};
+  int max_mf_dim_ = 8;
  std::vector<std::shared_ptr<cub::CachingDeviceAllocator>> allocators_;
#endif
};
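Note: TYPEALIGN(ALIGNVAL, LEN) rounds a byte length up to the next multiple of ALIGNVAL; it keeps the variable-length value and gradient records 8-byte aligned inside the pools. A quick sanity check with illustrative values:

    static_assert(TYPEALIGN(8, 41) == 48, "41 rounds up to the next multiple of 8");
    static_assert(TYPEALIGN(8, 48) == 48, "already-aligned lengths are unchanged");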
......
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_HETERPS
#include <queue>
+#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_XPU_KP
@@ -22,20 +23,31 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename KeyType, typename ValType, typename GradType>
HeterComm<KeyType, ValType, GradType>::HeterComm(
    size_t capacity, std::shared_ptr<HeterPsResource> resource) {
  resource_ = resource;
  storage_.resize(resource_->total_device());
+  multi_mf_dim_ = resource->multi_mf();
  for (int i = 0; i < resource_->total_device(); ++i) {
#if defined(PADDLE_WITH_CUDA)
    platform::CUDADeviceGuard guard(resource_->dev_id(i));
    allocators_.push_back(std::make_shared<cub::CachingDeviceAllocator>(
        8, 1, (unsigned int)-1, (size_t)-1, false, false));  // NOLINT
#endif
+    if (!multi_mf_dim_) {
      auto table = new Table(capacity / load_factor_);
      tables_.push_back(table);
+    } else {
+      max_mf_dim_ = resource_->max_mf_dim();
+      size_t val_type_size = TYPEALIGN(
+          8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1));
+      size_t grad_type_size = TYPEALIGN(
+          8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
+      auto ptr_table = new PtrTable(capacity / load_factor_);
+      ptr_table->set_feature_value_size(val_type_size, grad_type_size);
+      ptr_tables_.push_back(ptr_table);
+    }
    if (multi_node_) {
      storage_[i].init(feanum_, resource_->dev_id(i));
    }
@@ -238,79 +250,103 @@ void HeterComm<KeyType, ValType, GradType>::walk_to_dest(int start_index,
}

template <typename KeyType, typename ValType, typename GradType>
-void HeterComm<KeyType, ValType, GradType>::walk_to_src(int start_index,
-                                                        int num, int* h_left,
-                                                        int* h_right,
-                                                        ValType* src_val) {
-  std::queue<CopyTask> que;
-  for (int i = 0; i < num; i++) {
-    if (h_left[i] == -1 || h_right[i] == -1) {
-      continue;
-    }
-    int cur_step = path_[start_index][i].nodes_.size() - 1;
-    auto& node = path_[start_index][i].nodes_[cur_step];
-    auto src_dev_id = resource_->dev_id(i);
-    auto src_place = DevPlace(src_dev_id);
-    if (cur_step == 0) {
-      auto dst_dev_id = resource_->dev_id(start_index);
-      auto dst_place = DevPlace(dst_dev_id);
-      memory_copy(dst_place, reinterpret_cast<char*>(src_val + h_left[i]),
-                  src_place, node.val_storage, node.val_bytes_len,
-                  node.out_stream);
-    } else {
-      CopyTask t(&path_[start_index][i], cur_step - 1);
-      que.push(t);
-      auto dst_dev_id =
-          resource_->dev_id(path_[start_index][i].nodes_[cur_step - 1].dev_num);
-      auto dst_place = DevPlace(dst_dev_id);
-      memory_copy(dst_place,
-                  path_[start_index][i].nodes_[cur_step - 1].val_storage,
-                  src_place, node.val_storage,
-                  path_[start_index][i].nodes_[cur_step - 1].val_bytes_len,
-                  path_[start_index][i].nodes_[cur_step - 1].out_stream);
-    }
-  }
-  while (!que.empty()) {
-    CopyTask& cur_task = que.front();
-    que.pop();
-    int cur_step = cur_task.step;
-    if (cur_task.path->nodes_[cur_step].sync) {
-      sync_stream(cur_task.path->nodes_[cur_step].out_stream);
-    }
-    auto src_dev_id =
-        resource_->dev_id(cur_task.path->nodes_[cur_step].dev_num);
-    auto src_place = DevPlace(src_dev_id);
-    if (cur_step > 0) {
-      CopyTask c(cur_task.path, cur_step - 1);
-      que.push(c);
-      auto dst_dev_id =
-          resource_->dev_id(cur_task.path->nodes_[cur_step - 1].dev_num);
-      auto dst_place = DevPlace(dst_dev_id);
-      memory_copy(dst_place, cur_task.path->nodes_[cur_step - 1].val_storage,
-                  src_place, cur_task.path->nodes_[cur_step].val_storage,
-                  cur_task.path->nodes_[cur_step - 1].val_bytes_len,
-                  cur_task.path->nodes_[cur_step - 1].out_stream);
-    } else if (cur_step == 0) {
-      int end_index = cur_task.path->nodes_.back().dev_num;
-      auto dst_dev_id = resource_->dev_id(end_index);
-      auto dst_place = DevPlace(dst_dev_id);
-      memory_copy(dst_place,
-                  reinterpret_cast<char*>(src_val + h_left[end_index]),
-                  src_place, cur_task.path->nodes_[cur_step].val_storage,
-                  cur_task.path->nodes_[cur_step].val_bytes_len,
-                  cur_task.path->nodes_[cur_step].out_stream);
-    }
-  }
-}
+void HeterComm<KeyType, ValType, GradType>::walk_to_dest(
+    int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key,
+    char* src_val, size_t val_size) {
+  int need_copy_val = 0;
+  if (src_val) {
+    need_copy_val = 1;
+  }
+  std::queue<CopyTask> que;
+  for (int i = 0; i < gpu_num; i++) {
+    if (h_left[i] == -1 || h_right[i] == -1) {
+      continue;
+    }
+    int size = path_[start_index][i].nodes_.size();
+    auto& node = path_[start_index][i].nodes_[0];
+    CopyTask t(&path_[start_index][i], 0);
+    que.push(t);
+    cudaMemcpyAsync(node.key_storage,
+                    reinterpret_cast<char*>(src_key + h_left[i]),
+                    node.key_bytes_len, cudaMemcpyDefault, node.in_stream);
+    if (need_copy_val) {
+      cudaMemcpyAsync(node.val_storage,
+                      src_val + uint64_t(h_left[i]) * uint64_t(val_size),
+                      node.val_bytes_len, cudaMemcpyDefault, node.in_stream);
+    }
+  }
+  while (!que.empty()) {
+    CopyTask& cur_task = que.front();
+    que.pop();
+    if (cur_task.path->nodes_[cur_task.step].sync) {
+      cudaStreamSynchronize(cur_task.path->nodes_[cur_task.step].in_stream);
+    }
+    if (cur_task.step != cur_task.path->nodes_.size() - 1) {
+      int cur_step = cur_task.step;
+      CopyTask c(cur_task.path, cur_step + 1);
+      que.push(c);
+      cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].key_storage,
+                      cur_task.path->nodes_[cur_step].key_storage,
+                      cur_task.path->nodes_[cur_step + 1].key_bytes_len,
+                      cudaMemcpyDefault,
+                      cur_task.path->nodes_[cur_step + 1].in_stream);
+      if (need_copy_val) {
+        cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].val_storage,
+                        cur_task.path->nodes_[cur_step].val_storage,
+                        cur_task.path->nodes_[cur_step + 1].val_bytes_len,
+                        cudaMemcpyDefault,
+                        cur_task.path->nodes_[cur_step + 1].in_stream);
+      }
+    }
+  }
+}
+
+template <typename KeyType, typename ValType, typename GradType>
+void HeterComm<KeyType, ValType, GradType>::walk_to_src(
+    int start_index, int gpu_num, int* h_left, int* h_right, char* src_val,
+    size_t val_size) {
+  std::queue<CopyTask> que;
+  for (int i = 0; i < gpu_num; i++) {
+    if (h_left[i] == -1 || h_right[i] == -1) {
+      continue;
+    }
+    int cur_step = path_[start_index][i].nodes_.size() - 1;
+    auto& node = path_[start_index][i].nodes_[cur_step];
+    if (cur_step == 0) {
+      cudaMemcpyAsync(src_val + uint64_t(h_left[i]) * val_size,
+                      node.val_storage, node.val_bytes_len, cudaMemcpyDefault,
+                      node.out_stream);
+    } else {
+      CopyTask t(&path_[start_index][i], cur_step - 1);
+      que.push(t);
+      cudaMemcpyAsync(path_[start_index][i].nodes_[cur_step - 1].val_storage,
+                      node.val_storage,
+                      path_[start_index][i].nodes_[cur_step - 1].val_bytes_len,
+                      cudaMemcpyDefault,
+                      path_[start_index][i].nodes_[cur_step - 1].out_stream);
+    }
+  }
+  while (!que.empty()) {
+    CopyTask& cur_task = que.front();
+    que.pop();
+    int cur_step = cur_task.step;
+    if (cur_task.path->nodes_[cur_step].sync) {
+      cudaStreamSynchronize(cur_task.path->nodes_[cur_step].out_stream);
+    }
+    if (cur_step > 0) {
+      CopyTask c(cur_task.path, cur_step - 1);
+      que.push(c);
+      cudaMemcpyAsync(cur_task.path->nodes_[cur_step - 1].val_storage,
+                      cur_task.path->nodes_[cur_step].val_storage,
+                      cur_task.path->nodes_[cur_step - 1].val_bytes_len,
+                      cudaMemcpyDefault,
+                      cur_task.path->nodes_[cur_step - 1].out_stream);
+    } else if (cur_step == 0) {
+      int end_index = cur_task.path->nodes_.back().dev_num;
+      cudaMemcpyAsync(src_val + uint64_t(h_left[end_index]) * val_size,
+                      cur_task.path->nodes_[cur_step].val_storage,
+                      cur_task.path->nodes_[cur_step].val_bytes_len,
+                      cudaMemcpyDefault,
+                      cur_task.path->nodes_[cur_step].out_stream);
+    }
+  }
@@ -318,15 +354,24 @@ void HeterComm<KeyType, ValType, GradType>::walk_to_src(int start_index,
template <typename KeyType, typename ValType, typename GradType>
HeterComm<KeyType, ValType, GradType>::~HeterComm() {
+  if (!multi_mf_dim_) {
    for (auto& table : tables_) {
      delete table;
      table = nullptr;
    }
+  } else {
+    for (auto& table : ptr_tables_) {
+      delete table;
+      table = nullptr;
+    }
+  }
}

template <typename KeyType, typename ValType, typename GradType>
-void HeterComm<KeyType, ValType, GradType>::show_one_table(int num) {
-  tables_[num]->show();
+void HeterComm<KeyType, ValType, GradType>::show_one_table(int gpu_num) {
+  if (!multi_mf_dim_) {
+    tables_[gpu_num]->show();
+  }
}

template <typename KeyType, typename ValType, typename GradType>
@@ -418,59 +463,165 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(
  }
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
char* pool, size_t len,
size_t feature_value_size,
size_t chunk_size,
int stream_num) {
if (len <= 0) {
return;
}
int dev_id = resource_->dev_id(num);
DevPlace place = DevPlace(dev_id);
AnyDeviceGuard guard(dev_id);
std::vector<memory::allocation::AllocationPtr> d_key_bufs;
ppStream streams[stream_num]; // NOLINT
for (int i = 0; i < stream_num; ++i) {
create_stream(&(streams[i]));
auto d_k_buf = memory::Alloc(place, chunk_size * sizeof(KeyType));
d_key_bufs.push_back(std::move(d_k_buf));
}
int cur_len = 0;
int cur_stream = 0;
while (cur_len < len) {
cur_stream = cur_stream % stream_num;
auto cur_use_stream = streams[cur_stream];
#if defined(PADDLE_WITH_XPU_KP)
cur_use_stream = 0;
#endif
int tmp_len = cur_len + chunk_size > len ? len - cur_len : chunk_size;
auto dst_place = place;
auto src_place = platform::CPUPlace();
memory_copy(
dst_place, reinterpret_cast<char*>(d_key_bufs[cur_stream]->ptr()),
src_place, h_keys + cur_len, sizeof(KeyType) * tmp_len, cur_use_stream);
ptr_tables_[num]->insert(
reinterpret_cast<KeyType*>(d_key_bufs[cur_stream]->ptr()), tmp_len,
pool, feature_value_size, cur_len, cur_use_stream);
cur_stream += 1;
cur_len += tmp_len;
}
for (int i = 0; i < stream_num; ++i) {
sync_stream(streams[i]);
destroy_stream(streams[i]);
}
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::merge_grad(
    int dev_num, KeyType* d_keys, GradType* d_grads, size_t len,
    int& uniq_len) {  // NOLINT
  int dev_id = resource_->dev_id(dev_num);
  DevPlace place = DevPlace(dev_id);
  AnyDeviceGuard guard(dev_id);
  auto stream = resource_->local_stream(dev_num, 0);
  size_t temp_storage_bytes;
  auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType));
  KeyType* d_merge_keys_ptr = reinterpret_cast<KeyType*>(d_merge_keys->ptr());
  auto d_merge_grads = memory::Alloc(place, len * sizeof(GradType));
  GradType* d_merge_grads_ptr =
      reinterpret_cast<GradType*>(d_merge_grads->ptr());
  heter_comm_kernel_->sort_pairs(NULL, temp_storage_bytes, d_keys,
                                 d_merge_keys_ptr, d_grads, d_merge_grads_ptr,
                                 len, 0, 8 * sizeof(KeyType), stream, false);
  auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
  heter_comm_kernel_->sort_pairs(
      d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr,
      d_grads, d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false);
  temp_storage_bytes = 0;
  auto d_num_runs_out_mem = memory::Alloc(place, sizeof(int));
  int* d_num_runs_out = reinterpret_cast<int*>(d_num_runs_out_mem->ptr());
  heter_comm_kernel_->reduce_by_key(NULL, temp_storage_bytes, d_merge_keys_ptr,
                                    d_keys, d_merge_grads_ptr, d_grads,
                                    d_num_runs_out, len, stream, false);
  if (d_temp_storage->size() < temp_storage_bytes) {
    d_temp_storage = NULL;
    d_temp_storage = memory::Alloc(place, temp_storage_bytes);
  }
  heter_comm_kernel_->reduce_by_key(
      d_temp_storage->ptr(), temp_storage_bytes, d_merge_keys_ptr, d_keys,
      d_merge_grads_ptr, d_grads, d_num_runs_out, len, stream, false);
  auto dst_place = platform::CPUPlace();
  auto src_place = place;
  memory_copy(dst_place, &uniq_len, src_place, d_num_runs_out, sizeof(int),
              stream);
  sync_stream(stream);
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::dynamic_merge_grad(
int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
int& uniq_len) {
int dev_id = resource_->dev_id(gpu_num);
platform::CUDAPlace place = platform::CUDAPlace(dev_id);
platform::CUDADeviceGuard guard(dev_id);
auto stream = resource_->local_stream(gpu_num, 0);
size_t temp_storage_bytes;
size_t grad_value_size =
TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType));
KeyType* d_merge_keys_ptr = reinterpret_cast<KeyType*>(d_merge_keys->ptr());
auto d_merge_grads = memory::Alloc(place, len * grad_value_size);
GradType* d_merge_grads_ptr =
reinterpret_cast<GradType*>(d_merge_grads->ptr());
auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1));
uint32_t* d_fea_num_info_ptr =
reinterpret_cast<uint32_t*>(d_fea_num_info->ptr());
uint32_t* d_index = (uint32_t*)&d_fea_num_info_ptr[len];
uint32_t* d_idx = (uint32_t*)&d_index[len];
int* d_merged_size = (int*)&d_idx[len];
int grid_size = (len - 1) / block_size_ + 1;
heter_comm_kernel_->fill_idx(d_idx, len, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
NULL, temp_storage_bytes, d_keys, d_merge_keys_ptr, d_idx, d_index, len,
0, 8 * sizeof(KeyType), stream));
void* d_buff = NULL;
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr,
d_idx, d_index, len, 0, 8 * sizeof(KeyType), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
temp_storage_bytes = 0;
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode(
NULL, temp_storage_bytes, d_merge_keys_ptr, d_keys, d_fea_num_info_ptr,
d_merged_size, len, stream));
if (d_temp_storage->size() < temp_storage_bytes) {
d_temp_storage = NULL;
d_temp_storage = memory::Alloc(place, temp_storage_bytes);
}
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode(
d_temp_storage->ptr(), temp_storage_bytes, d_merge_keys_ptr, d_keys,
d_fea_num_info_ptr, d_merged_size, len, stream));
cudaMemcpyAsync((void*)&uniq_len, d_merged_size, sizeof(int),
cudaMemcpyDeviceToHost, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
assert(d_merged_size > 0);
uint32_t* d_offset = (uint32_t*)&d_index[len];
temp_storage_bytes = 0;
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum(
NULL, temp_storage_bytes, d_fea_num_info_ptr, d_offset, uniq_len,
stream));
if (d_temp_storage->size() < temp_storage_bytes) {
d_temp_storage = NULL;
d_temp_storage = memory::Alloc(place, temp_storage_bytes);
}
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum(
d_temp_storage->ptr(), temp_storage_bytes, d_fea_num_info_ptr, d_offset,
uniq_len, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
heter_comm_kernel_->merge_gradient(
d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads,
(char*)d_merge_grads_ptr, uniq_len, grad_value_size, merger_, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr,
grad_value_size * uniq_len,
cudaMemcpyDeviceToDevice, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
}
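Note: dynamic_merge_grad deduplicates pushed gradients in four device-side stages: radix-sort the keys together with an identity index array, run-length-encode the sorted keys to get each unique key's count, exclusive-scan those counts into start offsets, then let merge_gradients_kernel fold every record of one key into a single variable-length slot. A host-side analogue, for illustration only (a fixed-layout Grad stands in for the variable-length FeaturePushValue; the real code performs the same steps with cub device primitives on a CUDA stream):

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    struct Grad {  // stand-in for FeaturePushValue
      float show = 0, clk = 0, lr_g = 0;
      std::vector<float> mf_g;
    };

    void merge_grad_host(std::vector<uint64_t>& keys, std::vector<Grad>& grads,
                         int& uniq_len) {
      const size_t len = keys.size();
      // 1. Sort an index array by key (cub::DeviceRadixSort::SortPairs).
      std::vector<uint32_t> idx(len);
      std::iota(idx.begin(), idx.end(), 0);
      std::stable_sort(idx.begin(), idx.end(),
                       [&](uint32_t a, uint32_t b) { return keys[a] < keys[b]; });
      // 2. Run-length encode the sorted keys (cub::DeviceRunLengthEncode::Encode).
      std::vector<uint64_t> uniq_keys;
      std::vector<uint32_t> fea_num;
      for (size_t i = 0; i < len; ++i) {
        uint64_t k = keys[idx[i]];
        if (uniq_keys.empty() || uniq_keys.back() != k) {
          uniq_keys.push_back(k);
          fea_num.push_back(1);
        } else {
          ++fea_num.back();
        }
      }
      uniq_len = static_cast<int>(uniq_keys.size());
      // 3. Exclusive prefix sum of the counts gives each unique key's start
      //    offset (cub::DeviceScan::ExclusiveSum).
      std::vector<uint32_t> offset(uniq_len, 0);
      for (int i = 1; i < uniq_len; ++i) {
        offset[i] = offset[i - 1] + fea_num[i - 1];
      }
      // 4. Fold all records of one key into one slot (merge_gradients_kernel,
      //    via DynamicGradMerger::update_one / merge_one).
      std::vector<Grad> merged(uniq_len);
      for (int i = 0; i < uniq_len; ++i) {
        merged[i] = grads[idx[offset[i]]];  // update_one
        for (uint32_t j = 1; j < fea_num[i]; ++j) {
          const Grad& in = grads[idx[offset[i] + j]];  // merge_one
          merged[i].show += in.show;
          merged[i].clk += in.clk;
          merged[i].lr_g += in.lr_g;
          for (size_t d = 0; d < merged[i].mf_g.size(); ++d) {
            merged[i].mf_g[d] += in.mf_g[d];
          }
        }
      }
      keys = uniq_keys;
      grads = merged;
    }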
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
    KeyType* d_keys, int* d_idx_ptr, size_t len, int* left, int* right,
@@ -529,8 +680,6 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
  AnyDeviceGuard guard(dev_id);
  auto stream = resource_->local_stream(num, 0);

-  // int grid_size = (len - 1) / block_size_ + 1;
  int h_left[total_device];   // NOLINT
  int h_right[total_device];  // NOLINT
@@ -562,10 +711,11 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
  auto d_idx = memory::Alloc(place, len * sizeof(int));
  int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());

+  size_t val_type_size =
+      TYPEALIGN(8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1));
  auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
  KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
-  auto d_shard_vals = memory::Alloc(place, len * sizeof(ValType));
+  auto d_shard_vals = memory::Alloc(place, len * val_type_size);
  ValType* d_shard_vals_ptr = reinterpret_cast<ValType*>(d_shard_vals->ptr());

  split_input_to_shard(d_keys, d_idx_ptr, len, d_left_ptr, d_right_ptr, num);
@@ -589,9 +739,8 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
      continue;
    }
    create_storage(num, i, shard_len * sizeof(KeyType),
-                   shard_len * sizeof(ValType));
+                   shard_len * val_type_size);
  }
  walk_to_dest(num, total_device, h_left, h_right, d_shard_keys_ptr, NULL);

  for (int i = 0; i < total_device; ++i) {
@@ -600,13 +749,10 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
    }
    auto& node = path_[num][i].nodes_.back();
    sync_stream(node.in_stream);
    AnyDeviceGuard guard(resource_->dev_id(i));
-    tables_[i]->rwlock_->RDLock();
-    tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
-                    reinterpret_cast<ValType*>(node.val_storage),
-                    h_right[i] - h_left[i] + 1,
-                    resource_->remote_stream(i, num));
+    ptr_tables_[i]->rwlock_->RDLock();
+    ptr_tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
+                        node.val_storage, h_right[i] - h_left[i] + 1,
+                        resource_->remote_stream(i, num));
  }
@@ -615,21 +761,18 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
    if (h_left[i] == -1) {
      continue;
    }
-    tables_[i]->rwlock_->UNLock();
+    ptr_tables_[i]->rwlock_->UNLock();
  }
-  walk_to_src(num, total_device, h_left, h_right, d_shard_vals_ptr);
+  walk_to_src(num, total_device, h_left, h_right,
+              reinterpret_cast<char*>(d_shard_vals_ptr), val_type_size);

  for (int i = 0; i < total_device; ++i) {
    auto& node = path_[num][i].nodes_.front();
    sync_stream(node.out_stream);
  }
-  heter_comm_kernel_->fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len,
-                                 stream);
+  heter_comm_kernel_->dy_mf_fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len,
+                                       val_type_size, stream);

  sync_stream(stream);

  for (int i = 0; i < total_device; ++i) {
    if (h_left[i] == -1 || h_right[i] == -1) {
      continue;
@@ -653,6 +796,8 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
  int total_device = resource_->total_device();
  int dev_id = resource_->dev_id(dev_num);
+  size_t grad_value_size =
+      TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
  DevPlace place = DevPlace(dev_id);
  AnyDeviceGuard guard(dev_id);
  auto stream = resource_->local_stream(dev_num, 0);
@@ -691,21 +836,19 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
  auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
  KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
-  auto d_shard_grads = memory::Alloc(place, len * sizeof(GradType));
-  GradType* d_shard_grads_ptr =
-      reinterpret_cast<GradType*>(d_shard_grads->ptr());
+  GradType* d_shard_grads_ptr;
+  auto d_shard_grads = memory::Alloc(place, len * grad_value_size);
+  d_shard_grads_ptr = reinterpret_cast<GradType*>(d_shard_grads->ptr());

  int uniq_len = len;
-  merge_grad(dev_num, d_keys, d_grads, len, uniq_len);
+  dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len);

-  // int grid_size = (uniq_len - 1) / block_size_ + 1;
+  int grid_size = (uniq_len - 1) / block_size_ + 1;

  split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr,
                       dev_num);

-  heter_comm_kernel_->fill_shard_grads(d_shard_keys_ptr, d_keys,
-                                       d_shard_grads_ptr, d_grads, d_idx_ptr,
-                                       uniq_len, stream);
+  heter_comm_kernel_->dy_mf_fill_shard_grads(
+      d_shard_keys_ptr, d_keys, d_shard_grads_ptr, d_grads, d_idx_ptr, uniq_len,
+      grad_value_size, stream);

  sync_stream(stream);
@@ -721,12 +864,22 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
    if (h_left[i] == -1 || h_right[i] == -1) {
      continue;
    }
+    if (!multi_mf_dim_) {
      create_storage(dev_num, i, shard_len * sizeof(KeyType),
                     shard_len * sizeof(GradType));
+    } else {
+      create_storage(dev_num, i, shard_len * sizeof(KeyType),
+                     shard_len * grad_value_size);
+    }
  }

+  if (!multi_mf_dim_) {
    walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr,
                 d_shard_grads_ptr);
+  } else {
+    walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr,
+                 reinterpret_cast<char*>(d_shard_grads_ptr), grad_value_size);
+  }

  for (int i = 0; i < total_device; ++i) {
    if (h_left[i] == -1 || h_right[i] == -1) {
@@ -736,17 +889,28 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
    sync_stream(node.in_stream);

    AnyDeviceGuard guard(resource_->dev_id(i));
+    if (!multi_mf_dim_) {
      tables_[i]->rwlock_->WRLock();
      tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
                         reinterpret_cast<GradType*>(node.val_storage),
                         h_right[i] - h_left[i] + 1, sgd,
                         resource_->remote_stream(i, dev_num));
+    } else {
+      ptr_tables_[i]->rwlock_->WRLock();
+      ptr_tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
+                             node.val_storage, h_right[i] - h_left[i] + 1, sgd,
+                             resource_->remote_stream(i, dev_num));
+    }
  }

  for (int i = 0; i < total_device; ++i) {
    sync_stream(resource_->remote_stream(i, dev_num));
    if (h_left[i] != -1) {
+      if (!multi_mf_dim_) {
        tables_[i]->rwlock_->UNLock();
+      } else {
+        ptr_tables_[i]->rwlock_->UNLock();
+      }
    }
  }
@@ -1078,12 +1242,14 @@ void HeterComm<KeyType, ValType, GradType>::end_pass() {
    tables_[index]->dump_to_cpu(dev_id, stream);
  };

+  if (!multi_mf_dim_) {
    for (int i = 0; i < total_device; ++i) {
      threads.push_back(std::thread(dump_to_cpu_func, i));
    }
    for (auto& t : threads) {
      t.join();
    }
+  }
}

// template <typename KeyType, typename ValType, typename GradType>
......
@@ -117,6 +117,52 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals,
  }
}
template <typename KeyType, typename GradType, typename T>
__global__ void dy_mf_fill_shard_grads_kernel(
KeyType* d_shard_keys, KeyType* d_keys, GradType* d_shard_grads,
GradType* d_grads, T* idx, size_t len, size_t grad_value_size) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
d_shard_keys[i] = d_keys[idx[i]];
*(GradType*)((char*)d_shard_grads + i * grad_value_size) =
*(GradType*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size);
}
}
__global__ void merge_gradients_kernel(const uint32_t* offset,
const uint32_t* fea_num,
const uint32_t* index, const char* input,
char* output, int n,
size_t grad_value_size,
DynamicGradMerger& merger_) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
uint32_t start = offset[i];
uint32_t num = fea_num[i];
int ori_index = index[start];
FeaturePushValue& out = *(FeaturePushValue*)(output + i * grad_value_size);
FeaturePushValue& in =
*(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
merger_.update_one(out, in);
for (int j = 1; j < num; ++j) {
ori_index = index[start + j];
in = *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
merger_.merge_one(out, in);
}
}
}
template <typename ValType, typename T>
__global__ void dy_mf_fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals,
T* idx, size_t len, size_t val_size) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
uint64_t new_offset = uint64_t(idx[i]) * val_size;
*(ValType*)((char*)d_vals + new_offset) =
*(ValType*)((char*)d_shard_vals + i * val_size);
}
}
// cuda implemention of heter_comm_kernel.h
template <typename T, typename StreamType>
void HeterCommKernel::fill_idx(T* idx, long long len,
@@ -207,8 +253,42 @@ void HeterCommKernel::reduce_by_key(void* d_temp_storage,
      debug_synchronous));
}
template <typename KeyType, typename GradType, typename T, typename StreamType>
void HeterCommKernel::dy_mf_fill_shard_grads(
KeyType* d_shard_keys, KeyType* d_keys, GradType* d_shard_grads,
GradType* d_grads, T* idx, long long len, size_t grad_value_size,
const StreamType& stream) {
int grid_size = (len - 1) / block_size_ + 1;
size_t c_len = (size_t)len;
dy_mf_fill_shard_grads_kernel<<<grid_size, block_size_, 0, stream>>>(
d_shard_keys, d_keys, d_shard_grads, d_grads, idx, c_len,
grad_value_size);
}
template <typename StreamType>
void HeterCommKernel::merge_gradient(
const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
const char* input, char* output, int n, size_t grad_value_size,
DynamicGradMerger& merger_, const StreamType& stream) {
int grid_size = (n - 1) / block_size_ + 1;
merge_gradients_kernel<<<grid_size, block_size_, 0, stream>>>(
offset, fea_num, index, input, output, n, grad_value_size, merger_);
}
template <typename ValType, typename T, typename StreamType>
void HeterCommKernel::dy_mf_fill_dvals(ValType* d_shard_vals, ValType* d_vals,
T* idx, long long len, size_t val_size,
const StreamType& stream) {
int grid_size = (len - 1) / block_size_ + 1;
size_t c_len = (size_t)len;
dy_mf_fill_dvals_kernel<<<grid_size, block_size_, 0, stream>>>(
d_shard_vals, d_vals, idx, c_len, val_size);
}
template void HeterCommKernel::fill_idx<int, cudaStream_t>(
    int* idx, long long len, const cudaStream_t& stream);
+template void HeterCommKernel::fill_idx<uint32_t, cudaStream_t>(
+    uint32_t* idx, long long len, const cudaStream_t& stream);

template void HeterCommKernel::calc_shard_offset<int, cudaStream_t>(
    int* idx, int* left, int* right, long long len, int total_devs,
@@ -270,6 +350,23 @@ template void HeterCommKernel::reduce_by_key<
    paddle::framework::FeaturePushValue* d_aggregates_out, int* d_num_runs_out,
    int num_items, cudaStream_t stream, bool debug_synchronous);
template void HeterCommKernel::dy_mf_fill_shard_grads<
unsigned long, paddle::framework::FeaturePushValue, int, cudaStream_t>(
unsigned long* d_shard_keys, unsigned long* d_keys,
paddle::framework::FeaturePushValue* d_shard_grads,
paddle::framework::FeaturePushValue* d_grads, int* idx, long long len,
size_t grad_value_size, const cudaStream_t& stream);
template void HeterCommKernel::merge_gradient<cudaStream_t>(
const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
const char* input, char* output, int n, size_t grad_value_size,
DynamicGradMerger& merger_, const cudaStream_t& stream);
template void HeterCommKernel::dy_mf_fill_dvals<paddle::framework::FeatureValue,
int, cudaStream_t>(
paddle::framework::FeatureValue* d_shard_vals,
paddle::framework::FeatureValue* d_vals, int* idx, long long len,
size_t val_size, const cudaStream_t& stream);
#endif

}  // namespace framework
......
@@ -27,6 +27,42 @@ limitations under the License. */
namespace paddle {
namespace framework {
struct DynamicGradMerger {
template <typename T>
CUB_RUNTIME_FUNCTION __forceinline__ __device__ T
operator()(const T& a, const T& b) const {
T out;
out.slot = a.slot;
out.mf_dim = a.mf_dim;
out.show = a.show + b.show;
out.clk = a.clk + b.clk;
out.lr_g = a.lr_g + b.lr_g;
return out;
}
template <typename T>
__device__ __forceinline__ void update_one(T& output, const T& input) {
output.slot = input.slot;
output.show = input.show;
output.clk = input.clk;
output.mf_dim = input.mf_dim;
output.lr_g = input.lr_g;
for (int i = 0; i < output.mf_dim; ++i) {
output.mf_g[i] = input.mf_g[i];
}
}
template <typename T>
__device__ __forceinline__ void merge_one(T& output, const T& input) {
output.show += input.show;
output.clk += input.clk;
output.lr_g += input.lr_g;
for (int i = 0; i < input.mf_dim; ++i) {
output.mf_g[i] += input.mf_g[i];
}
}
};
class HeterCommKernel {
 public:
  HeterCommKernel() {}
@@ -80,6 +116,24 @@ class HeterCommKernel {
                     StreamType stream = NULL, bool debug_synchronous = false);
template <typename KeyType, typename GradType, typename T,
typename StreamType>
void dy_mf_fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys,
GradType* d_shard_grads, GradType* d_grads,
T* idx, long long len, size_t grad_value_size,
const StreamType& stream);
template <typename StreamType>
void merge_gradient(const uint32_t* offset, const uint32_t* fea_num,
const uint32_t* index, const char* input, char* output,
int n, size_t grad_value_size, DynamicGradMerger& merger_,
const StreamType& stream);
template <typename ValType, typename T, typename StreamType>
void dy_mf_fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx,
long long len, size_t val_size,
const StreamType& stream);
 private:
  int block_size_{256};
};
......
@@ -44,6 +44,13 @@ void HeterPs::build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
  comm_->build_ps(num, h_keys, h_vals, len, chunk_size, stream_num);
}
void HeterPs::build_ps(int num, FeatureKey* h_keys, char* pool, size_t len,
size_t feature_value_size, size_t chunk_size,
int stream_num) {
comm_->build_ps(num, h_keys, pool, len, feature_value_size, chunk_size,
stream_num);
}
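Note: this overload registers pre-materialized values. Keys are copied to the device in chunks and the per-GPU pointer table stores, for each key, the address of its slot inside `pool` rather than a copy of the value, so `pool` must stay alive and device-accessible for the lifetime of the table. A hedged sketch of the call pattern (buffer names and the chunk/stream counts are illustrative; the real caller is the PSGPUWrapper build path, which is not part of this hunk):

    size_t value_bytes = feature_value_size;  // TYPEALIGN'd per-record size
    char* d_pool = nullptr;                   // device-visible value pool
    cudaMalloc(&d_pool, len * value_bytes);
    // ... write one FeatureValue record (mf_dim-dependent length) per key ...
    heter_ps->build_ps(dev_idx, h_keys, d_pool, len, value_bytes,
                       /*chunk_size=*/500000, /*stream_num=*/2);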
int HeterPs::get_index_by_devid(int devid) {
  return comm_->get_index_by_devid(devid);
}
@@ -72,6 +79,10 @@ void HeterPs::set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
  comm_->set_nccl_comm_and_size(inner_comms, inter_comms, comm_size);
}
void HeterPs::set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
comm_->set_multi_mf_dim(multi_mf_dim, max_mf_dim);
}
}  // end namespace framework
}  // end namespace paddle
#endif
@@ -37,11 +37,14 @@ class HeterPs : public HeterPsBase {
                   size_t len) override;
  void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, size_t len,
                size_t chunk_size, int stream_num) override;
+  void build_ps(int num, FeatureKey* h_keys, char* pool, size_t len,
+                size_t feature_value_size, size_t chunk_size,
+                int stream_num) override;
#if defined(PADDLE_WITH_CUDA)
  void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
                              const std::vector<ncclComm_t>& inter_comms,
                              int comm_size) override;
+  void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) override;
#endif
  void set_sparse_sgd(const OptimizerConfig& optimizer_config) override;
......
@@ -35,11 +35,15 @@ class HeterPsBase {
                        size_t len) = 0;
  virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
                        size_t len, size_t chunk_size, int stream_num) = 0;
+  virtual void build_ps(int num, FeatureKey* h_keys, char* pool, size_t len,
+                        size_t feature_value_size, size_t chunk_size,
+                        int stream_num) = 0;
  virtual int get_index_by_devid(int devid) = 0;
#if defined(PADDLE_WITH_CUDA)
  virtual void set_nccl_comm_and_size(
      const std::vector<ncclComm_t>& inner_comms,
      const std::vector<ncclComm_t>& inter_comms, int comm_size) = 0;
+  virtual void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) = 0;
#endif
  virtual void end_pass() = 0;
  virtual void show_one_table(int gpu_num) = 0;
......
@@ -107,6 +107,8 @@ class HeterPsResource {
  int get_index_by_devid(int devid);
  int dev_id(int num);
  void set_multi_mf(int multi_mf_dim, int max_mf_dim);
+  int multi_mf() { return multi_mf_dim_; }
+  int max_mf_dim() { return max_mf_dim_; }
  ppStream local_stream(int dev_num, int stream_num);
  ppStream remote_stream(int dev_num, int stream_num);
......
@@ -125,20 +125,21 @@ class Optimizer {
    if (optimizer_config.mf_create_thresholds <=
        optimizer_config.nonclk_coeff * (ptr->show - ptr->clk) +
            optimizer_config.clk_coeff * ptr->clk) {
-      // ptr->mf_size = ptr->mf_dim + 1;
-      ptr->mf_size = MF_DIM + 1;
+      ptr->mf_size = ptr->mf_dim + 1;
+      // ptr->mf_size = MF_DIM + 1;
      ptr->mf[0] = 0;
      int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
      curandState state;
      curand_init(clock64(), tid_x, 0, &state);
-      for (int i = 0; i < MF_DIM; ++i) {
+      for (int i = 0; i < ptr->mf_dim; ++i) {
        ptr->mf[i + 1] =
            (curand_uniform(&state)) * optimizer_config.mf_initial_range;
      }
    }
  } else {
-    update_mf(optimizer_config, MF_DIM, &(ptr->mf[1]), ptr->mf[0], grad.mf_g,
+    update_mf(optimizer_config, ptr->mf_dim, &(ptr->mf[1]), ptr->mf[0],
+              grad.mf_g,
              grad.show);  // for local test
  }
}
......
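The optimizer now derives the embedding length from each value's own mf_dim instead of the compile-time MF_DIM: mf[0] keeps the accumulated g2sum and mf[1..mf_dim] hold the embedding, so mf_size becomes mf_dim + 1 once the show/click score crosses mf_create_thresholds. A minimal host-side sketch of that creation rule (the struct and helper below are illustrative stand-ins, not the real CUDA code):

#include <cstdlib>
#include <vector>

struct DymfValueSketch {            // hypothetical host mirror of FeatureValue
  float show = 0.f, clk = 0.f;
  int mf_dim = 0;                   // per-slot embedding width
  int mf_size = 0;                  // stays 0 until the embedding is created
  std::vector<float> mf;            // mf[0] = g2sum, mf[1..mf_dim] = embedding
};

void maybe_create_mf(DymfValueSketch& v, float nonclk_coeff, float clk_coeff,
                     float mf_create_thresholds, float mf_initial_range) {
  if (v.mf_size != 0) return;       // already materialized
  float score = nonclk_coeff * (v.show - v.clk) + clk_coeff * v.clk;
  if (mf_create_thresholds <= score) {
    v.mf_size = v.mf_dim + 1;       // one extra slot for g2sum
    v.mf.assign(v.mf_size, 0.f);    // mf[0] = 0, as in ptr->mf[0] = 0 above
    for (int i = 0; i < v.mf_dim; ++i)
      v.mf[i + 1] = (std::rand() / float(RAND_MAX)) * mf_initial_range;
  }
}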
...@@ -31,7 +31,6 @@ limitations under the License. */ ...@@ -31,7 +31,6 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <deque> #include <deque>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
...@@ -112,12 +111,8 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) { ...@@ -112,12 +111,8 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
} else { } else {
gpu_task->init(thread_keys_shard_num_, device_num, multi_mf_dim_); gpu_task->init(thread_keys_shard_num_, device_num, multi_mf_dim_);
} }
auto& local_keys = gpu_task->feature_keys_;
auto& local_ptr = gpu_task->value_ptr_;
std::vector<std::thread> threads; std::vector<std::thread> threads;
// data should be in input channel
if (!multi_mf_dim_) { if (!multi_mf_dim_) {
thread_keys_.resize(thread_keys_thread_num_); thread_keys_.resize(thread_keys_thread_num_);
for (int i = 0; i < thread_keys_thread_num_; i++) { for (int i = 0; i < thread_keys_thread_num_; i++) {
...@@ -141,11 +136,9 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) { ...@@ -141,11 +136,9 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
std::string data_set_name = std::string(typeid(*dataset_).name()); std::string data_set_name = std::string(typeid(*dataset_).name());
if (data_set_name.find("SlotRecordDataset") != std::string::npos) { if (data_set_name.find("SlotRecordDataset") != std::string::npos) {
VLOG(0) << "ps_gpu_wrapper use SlotRecordDataset";
SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(dataset_); SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(dataset_);
auto input_channel = dataset->GetInputChannel(); auto input_channel = dataset->GetInputChannel();
VLOG(0) << "yxf::buildtask::inputslotchannle size: " VLOG(0) << "psgpu wrapperinputslotchannle size: " << input_channel->Size();
<< input_channel->Size();
const std::deque<SlotRecord>& vec_data = input_channel->GetData(); const std::deque<SlotRecord>& vec_data = input_channel->GetData();
total_len = vec_data.size(); total_len = vec_data.size();
len_per_thread = total_len / thread_keys_thread_num_; len_per_thread = total_len / thread_keys_thread_num_;
...@@ -176,21 +169,12 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) { ...@@ -176,21 +169,12 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
j < slot_offset[slot_offset_vector_[slot_idx] + 1]; j++) { j < slot_offset[slot_offset_vector_[slot_idx] + 1]; j++) {
int shard_id = feasign_v[j] % thread_keys_shard_num_; int shard_id = feasign_v[j] % thread_keys_shard_num_;
int dim_id = slot_index_vec_[slot_idx]; int dim_id = slot_index_vec_[slot_idx];
if (feasign_v[j] != 0) {
this->thread_dim_keys_[i][shard_id][dim_id].insert(feasign_v[j]); this->thread_dim_keys_[i][shard_id][dim_id].insert(feasign_v[j]);
} }
} }
} }
/*
for (auto iter = total_data.begin() + begin_index;
iter != total_data.begin() + end_index; iter++) {
const auto& ins = *iter;
const auto& feasign_v = ins->slot_uint64_feasigns_.slot_values;
for (const auto feasign : feasign_v) {
int shard_id = feasign % thread_keys_shard_num_;
this->thread_dim_keys_[i][shard_id][0].insert(feasign);
}
} }
*/
}; };
for (int i = 0; i < thread_keys_thread_num_; i++) { for (int i = 0; i < thread_keys_thread_num_; i++) {
if (!multi_mf_dim_) { if (!multi_mf_dim_) {
...@@ -264,12 +248,6 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) { ...@@ -264,12 +248,6 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
thread_dim_keys_[i][shard_num][dim_id].clear(); thread_dim_keys_[i][shard_num][dim_id].clear();
} }
}; };
// for (size_t i = 0; i < thread_keys_.size(); i++) {
// gpu_task->batch_add_keys(thread_keys_[i]);
// for (int j = 0; j < thread_keys_thread_num_; j++) {
// thread_keys_[i][j].clear();
// }
//}
for (int i = 0; i < thread_keys_shard_num_; ++i) { for (int i = 0; i < thread_keys_shard_num_; ++i) {
if (!multi_mf_dim_) { if (!multi_mf_dim_) {
threads.push_back(std::thread(merge_ins_func, i)); threads.push_back(std::thread(merge_ins_func, i));
...@@ -291,22 +269,17 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) { ...@@ -291,22 +269,17 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
timeline.Pause(); timeline.Pause();
VLOG(0) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds."; VLOG(0) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds.";
if (!multi_mf_dim_) {
for (int i = 0; i < thread_keys_shard_num_; i++) {
VLOG(0) << "GpuPs shard: " << i << " key len: " << local_keys[i].size();
local_ptr[i].resize(local_keys[i].size());
}
} else {
for (int i = 0; i < thread_keys_shard_num_; i++) { for (int i = 0; i < thread_keys_shard_num_; i++) {
for (int j = 0; j < multi_mf_dim_; j++) { for (int j = 0; j < multi_mf_dim_; j++) {
if (i == 0 && j == multi_mf_dim_ - 1) {
gpu_task->feature_dim_keys_[i][j].push_back(0);
}
VLOG(0) << "GpuPs shard: " << i << "mf dim: " << index_dim_vec_[j] VLOG(0) << "GpuPs shard: " << i << "mf dim: " << index_dim_vec_[j]
<< " key len: " << gpu_task->feature_dim_keys_[i][j].size(); << " key len: " << gpu_task->feature_dim_keys_[i][j].size();
gpu_task->value_dim_ptr_[i][j].resize( gpu_task->value_dim_ptr_[i][j].resize(
gpu_task->feature_dim_keys_[i][j].size()); gpu_task->feature_dim_keys_[i][j].size());
} }
} }
}
} }
void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) { void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
...@@ -353,85 +326,6 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) { ...@@ -353,85 +326,6 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
#endif #endif
timeline.Start(); timeline.Start();
auto ptl_func = [this, &local_keys, &local_ptr, &fleet_ptr](int i) {
size_t key_size = local_keys[i].size();
int32_t status = -1;
#ifdef PADDLE_WITH_PSLIB
// auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr(
// reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
// local_keys[i].data(), key_size);
int32_t cnt = 0;
while (true) {
auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr(
i, reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
local_keys[i].data(), key_size);
bool flag = true;
tt.wait();
try {
status = tt.get();
} catch (const std::future_error& e) {
VLOG(0) << "Caught a future_error with code" << e.code()
<< ", Message:" << e.what();
}
if (status != 0) {
VLOG(0) << "fleet pull sparse failed, status[" << status << "]";
sleep(sleep_seconds_before_fail_exit_);
flag = false;
cnt++;
}
if (cnt > 3) {
VLOG(0) << "fleet pull sparse failed, retry 3 times";
exit(-1);
}
if (flag) {
break;
}
}
#endif
#ifdef PADDLE_WITH_PSCORE
int32_t cnt = 0;
while (true) {
auto tt = fleet_ptr->worker_ptr_->PullSparsePtr(
reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
local_keys[i].data(), key_size);
bool flag = true;
tt.wait();
try {
status = tt.get();
} catch (const std::future_error& e) {
VLOG(0) << "Caught a future_error with code" << e.code()
<< ", Message:" << e.what();
}
if (status != 0) {
VLOG(0) << "fleet pull sparse failed, status[" << status << "]";
sleep(sleep_seconds_before_fail_exit_);
flag = false;
cnt++;
}
if (cnt > 3) {
VLOG(0) << "fleet pull sparse failed, retry 3 times";
exit(-1);
}
if (flag) {
break;
}
}
#endif
if (status != 0) {
LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
sleep(300);
exit(-1);
} else {
VLOG(3) << "FleetWrapper Pull sparse to local done with table size: "
<< local_keys[i].size();
}
};
auto ptl_dynamic_mf_func = [this, &local_dim_keys, &local_dim_ptr, auto ptl_dynamic_mf_func = [this, &local_dim_keys, &local_dim_ptr,
&fleet_ptr](int i, int j) { &fleet_ptr](int i, int j) {
...@@ -478,21 +372,18 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) { ...@@ -478,21 +372,18 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
} }
#endif #endif
}; };
if (!multi_mf_dim_) {
for (size_t i = 0; i < threads.size(); i++) {
threads[i] = std::thread(ptl_func, i);
}
} else {
threads.resize(thread_keys_shard_num_ * multi_mf_dim_); threads.resize(thread_keys_shard_num_ * multi_mf_dim_);
for (int i = 0; i < thread_keys_shard_num_; i++) { for (int i = 0; i < thread_keys_shard_num_; i++) {
for (int j = 0; j < multi_mf_dim_; j++) { for (int j = 0; j < multi_mf_dim_; j++) {
threads[i * multi_mf_dim_ + j] = std::thread(ptl_dynamic_mf_func, i, j); task_futures.emplace_back(
} pull_thread_pool_[i]->enqueue(ptl_dynamic_mf_func, i, j));
} }
} }
for (std::thread& t : threads) { for (auto& f : task_futures) {
t.join(); f.wait();
} }
task_futures.clear();
timeline.Pause(); timeline.Pause();
VLOG(0) << "pull sparse from CpuPS into GpuPS cost " << timeline.ElapsedSec() VLOG(0) << "pull sparse from CpuPS into GpuPS cost " << timeline.ElapsedSec()
<< " seconds."; << " seconds.";
...@@ -509,16 +400,9 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) { ...@@ -509,16 +400,9 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
std::vector<std::vector<std::pair<uint64_t, char*>>> pass_values; std::vector<std::vector<std::pair<uint64_t, char*>>> pass_values;
bool record_status = false; bool record_status = false;
#ifdef PADDLE_WITH_PSLIB
uint16_t pass_id = 0;
if (multi_node_) {
record_status = fleet_ptr->pslib_ptr_->_worker_ptr->take_sparse_record(
table_id_, pass_id, pass_values);
}
#endif
auto& device_task_keys = gpu_task->device_task_keys_; auto& device_task_keys = gpu_task->device_task_keys_;
auto& device_task_ptrs = gpu_task->device_task_ptr_; auto& device_task_ptrs = gpu_task->device_task_ptr_;
auto build_dynamic_mf_func = [this, device_num, &local_dim_keys, auto build_pull_dynamic_mf_func = [this, device_num, &local_dim_keys,
&local_dim_ptr, &device_dim_keys, &local_dim_ptr, &device_dim_keys,
&device_dim_ptr, &device_dim_ptr,
&device_dim_mutex](int i, int j) { &device_dim_mutex](int i, int j) {
...@@ -532,20 +416,16 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) { ...@@ -532,20 +416,16 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
task_ptrs[shard].push_back(local_dim_ptr[i][j][k]); task_ptrs[shard].push_back(local_dim_ptr[i][j][k]);
} }
for (int dev = 0; dev < device_num; dev++) { for (int dev = 0; dev < device_num; dev++) {
for (int dim = 0; dim < multi_mf_dim_; dim++) { device_dim_mutex[dev][j]->lock();
device_dim_mutex[dev][dim]->lock();
int len = task_keys[dev].size(); int len = task_keys[dev].size();
int cur = device_dim_keys[dev][dim].size(); int cur = device_dim_keys[dev][j].size();
device_dim_keys[dev][dim].resize(device_dim_keys[dev][dim].size() + device_dim_keys[dev][j].resize(device_dim_keys[dev][j].size() + len);
len); device_dim_ptr[dev][j].resize(device_dim_ptr[dev][j].size() + len);
device_dim_ptr[dev][dim].resize(device_dim_ptr[dev][dim].size() + len);
for (int k = 0; k < len; ++k) { for (int k = 0; k < len; ++k) {
device_dim_keys[dev][dim][cur + k] = task_keys[dev][k]; device_dim_keys[dev][j][cur + k] = task_keys[dev][k];
device_dim_ptr[dev][dim][cur + k] = task_ptrs[dev][k]; device_dim_ptr[dev][j][cur + k] = task_ptrs[dev][k];
}
device_dim_mutex[dev][dim]->unlock();
} }
device_dim_mutex[dev][j]->unlock();
} }
#endif #endif
}; };
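build_pull_dynamic_mf_func now appends straight into the (device, j) bucket it is filling, taking only that bucket's mutex rather than looping over every dim under every device lock. A rough host sketch of the bucketing, assuming the usual feasign % device_num placement used elsewhere in this wrapper:

#include <cstdint>
#include <mutex>
#include <vector>

void bucket_shard_keys(const std::vector<uint64_t>& shard_keys, int device_num,
                       int j,  // mf-dim bucket index
                       std::vector<std::vector<std::vector<uint64_t>>>& device_dim_keys,
                       std::vector<std::vector<std::mutex>>& device_dim_mutex) {
  std::vector<std::vector<uint64_t>> task_keys(device_num);
  for (uint64_t key : shard_keys)
    task_keys[key % device_num].push_back(key);   // assumed placement rule
  for (int dev = 0; dev < device_num; ++dev) {
    std::lock_guard<std::mutex> guard(device_dim_mutex[dev][j]);
    auto& dst = device_dim_keys[dev][j];
    dst.insert(dst.end(), task_keys[dev].begin(), task_keys[dev].end());
  }
}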
...@@ -697,7 +577,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) { ...@@ -697,7 +577,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
for (int i = 0; i < thread_keys_shard_num_; i++) { for (int i = 0; i < thread_keys_shard_num_; i++) {
for (int j = 0; j < multi_mf_dim_; j++) { for (int j = 0; j < multi_mf_dim_; j++) {
threads[i * multi_mf_dim_ + j] = threads[i * multi_mf_dim_ + j] =
std::thread(build_dynamic_mf_func, i, j); std::thread(build_pull_dynamic_mf_func, i, j);
} }
} }
for (std::thread& t : threads) { for (std::thread& t : threads) {
...@@ -727,22 +607,18 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) { ...@@ -727,22 +607,18 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
std::vector<size_t> feature_keys_count(device_num); std::vector<size_t> feature_keys_count(device_num);
size_t size_max = 0; size_t size_max = 0;
if (!multi_mf_dim_) {
for (int i = 0; i < device_num; i++) {
feature_keys_count[i] = gpu_task->device_keys_[i].size();
VLOG(0) << i << " card contains feasign nums: " << feature_keys_count[i];
size_max = std::max(size_max, feature_keys_count[i]);
}
} else {
for (int i = 0; i < device_num; i++) { for (int i = 0; i < device_num; i++) {
for (int j = 0; j < multi_mf_dim_; j++) { for (int j = 0; j < multi_mf_dim_; j++) {
feature_keys_count[i] += gpu_task->device_dim_ptr_[i][j].size(); feature_keys_count[i] += gpu_task->device_dim_ptr_[i][j].size();
VLOG(1) << i << " card with dynamic mf dim: " << index_dim_vec_[j]
<< " dim index: " << j << " contains feasign nums: "
<< gpu_task->device_dim_ptr_[i][j].size();
} }
VLOG(0) << i << " card with dynamic mf contains feasign nums: " VLOG(1) << i << " card with dynamic mf contains feasign nums total: "
<< feature_keys_count[i]; << feature_keys_count[i];
size_max = std::max(size_max, feature_keys_count[i]); size_max = std::max(size_max, feature_keys_count[i]);
} }
}
if (HeterPs_) { if (HeterPs_) {
delete HeterPs_; delete HeterPs_;
HeterPs_ = nullptr; HeterPs_ = nullptr;
...@@ -756,17 +632,73 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) { ...@@ -756,17 +632,73 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_); HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_);
#endif #endif
auto build_func = [this, &gpu_task, &feature_keys_count](int i) { auto build_dynamic_mf_func = [this, &gpu_task](int i, int j) {
VLOG(3) << "building table: " << i; this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_);
this->HeterPs_->build_ps(i, gpu_task->device_keys_[i].data(), int mf_dim = this->index_dim_vec_[j];
gpu_task->device_values_[i].data(), VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim;
feature_keys_count[i], 500000, 2); size_t feature_value_size =
// if (feature_keys_count[i] > 0) { TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float)));
// HeterPs_->show_one_table(i); auto& device_dim_keys = gpu_task->device_dim_keys_[i][j];
// } auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j];
size_t len = device_dim_keys.size();
CHECK(len == device_dim_ptrs.size());
this->mem_pools_[i * this->multi_mf_dim_ + j] =
new MemoryPool(len, feature_value_size);
auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j];
for (size_t k = 0; k < len; k++) {
FeatureValue* val = (FeatureValue*)(mem_pool->mem_address(k));
float* ptr_val = device_dim_ptrs[k]->data();
size_t dim = device_dim_ptrs[k]->size();
#ifdef PADDLE_WITH_PSLIB
val->delta_score =
ptr_val[paddle::ps::DownpourCtrDymfAccessor::
DownpourCtrDymfFeatureValue::delta_score_index()];
val->show = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
DownpourCtrDymfFeatureValue::show_index()];
val->clk = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
DownpourCtrDymfFeatureValue::click_index()];
val->slot = int(ptr_val[paddle::ps::DownpourCtrDymfAccessor::
DownpourCtrDymfFeatureValue::slot_index()]);
val->lr = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
DownpourCtrDymfFeatureValue::embed_w_index()];
val->lr_g2sum =
ptr_val[paddle::ps::DownpourCtrDymfAccessor::
DownpourCtrDymfFeatureValue::embed_g2sum_index()];
val->cpu_ptr = (uint64_t)(device_dim_ptrs[k]);
ptr_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
mf_dim_index()] = float(mf_dim);
val->mf_dim = mf_dim;
#endif
if (dim > 8) { // CpuPS alreay expand as mf_dim
val->mf_size = mf_dim + 1;
for (int x = 0; x < val->mf_dim + 1; x++) {
val->mf[x] = ptr_val[x + 8];
}
} else {
val->mf_size = 0;
for (int x = 0; x < val->mf_dim + 1; x++) {
val->mf[x] = 0;
}
}
}
platform::CUDADeviceGuard guard(resource_->dev_id(i));
this->hbm_pools_[i * this->multi_mf_dim_ + j] = new HBMMemoryPool(mem_pool);
auto& cur_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j];
this->HeterPs_->build_ps(i, device_dim_keys.data(), cur_pool->mem(), len,
feature_value_size, 500000, 2);
if (device_dim_keys.size() > 0) {
VLOG(0) << "show ptr table: " << i
<< " table kv size: " << device_dim_keys.size()
<< "dim: " << mf_dim << " len: " << len;
this->HeterPs_->show_one_table(i);
}
delete mem_pool;
}; };
for (size_t i = 0; i < threads.size(); i++) { threads.resize(device_num * multi_mf_dim_);
threads[i] = std::thread(build_func, i); for (int i = 0; i < device_num; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
threads[i + j * device_num] = std::thread(build_dynamic_mf_func, i, j);
}
} }
for (std::thread& t : threads) { for (std::thread& t : threads) {
t.join(); t.join();
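With dynamic mf the GPU table is no longer built from an array of fixed-size FeatureValue: each (device, dim) bucket gets a flat byte pool whose record stride is TYPEALIGN(8, sizeof(FeatureValue) + (mf_dim + 1) * sizeof(float)), the CPU-side values are packed into that pool, and the pool plus its stride is handed to the new build_ps overload. The addressing arithmetic, isolated into a small sketch (MemoryPool/HBMMemoryPool internals are assumed):

#include <cstddef>
#include <vector>

constexpr std::size_t TypeAlign8(std::size_t len) {   // same rounding as TYPEALIGN(8, len)
  return (len + 7) / 8 * 8;
}

struct PoolSketch {
  std::vector<char> buf;
  std::size_t value_size;
  PoolSketch(std::size_t capacity, std::size_t header_bytes /* ~ sizeof(FeatureValue) */,
             int mf_dim)
      : value_size(TypeAlign8(header_bytes + (mf_dim + 1) * sizeof(float))) {
    buf.resize(capacity * value_size);
  }
  void* mem_address(std::size_t k) { return buf.data() + k * value_size; }  // k-th record
};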
...@@ -788,7 +720,6 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) { ...@@ -788,7 +720,6 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
if (is_shuffle) { if (is_shuffle) {
dataset_->LocalShuffle(); dataset_->LocalShuffle();
} }
std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get(); std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
gpu_task->Reset(); gpu_task->Reset();
data_ready_channel_->Put(gpu_task); data_ready_channel_->Put(gpu_task);
...@@ -874,17 +805,86 @@ void PSGPUWrapper::EndPass() { ...@@ -874,17 +805,86 @@ void PSGPUWrapper::EndPass() {
size_t keysize_max = 0; size_t keysize_max = 0;
// in case of feasign_num = 0, skip dump_to_cpu // in case of feasign_num = 0, skip dump_to_cpu
for (size_t i = 0; i < heter_devices_.size(); i++) { for (size_t i = 0; i < heter_devices_.size(); i++) {
keysize_max = std::max(keysize_max, current_task_->device_keys_[i].size()); for (int j = 0; j < multi_mf_dim_; j++) {
keysize_max =
std::max(keysize_max, current_task_->device_dim_keys_[i][j].size());
}
}
auto dump_pool_to_cpu_func = [this](int i, int j) {
PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(this->resource_->dev_id(i)));
auto& hbm_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j];
auto& device_keys = this->current_task_->device_dim_keys_[i][j];
size_t len = device_keys.size();
int mf_dim = this->index_dim_vec_[j];
VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim;
size_t feature_value_size =
TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float)));
char* test_build_values = (char*)malloc(feature_value_size * len);
cudaMemcpy(test_build_values, hbm_pool->mem(), feature_value_size * len,
cudaMemcpyDeviceToHost);
CHECK(len == hbm_pool->capacity());
#ifdef PADDLE_WITH_PSLIB
uint64_t unuse_key = std::numeric_limits<uint64_t>::max();
for (size_t i = 0; i < len; ++i) {
if (device_keys[i] == unuse_key) {
continue;
}
size_t offset = i * feature_value_size;
FeatureValue* gpu_val = (FeatureValue*)(test_build_values + offset);
auto* downpour_value =
(paddle::ps::DownpourFixedFeatureValue*)(gpu_val->cpu_ptr);
int downpour_value_size = downpour_value->size();
if (gpu_val->mf_size > 0 && downpour_value_size == 8) {
downpour_value->resize(gpu_val->mf_dim + 1 + downpour_value_size);
}
float* cpu_val = downpour_value->data();
cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
delta_score_index()] = gpu_val->delta_score;
cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
show_index()] = gpu_val->show;
cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
click_index()] = gpu_val->clk;
cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
embed_w_index()] = gpu_val->lr;
cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
embed_g2sum_index()] = gpu_val->lr_g2sum;
cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
slot_index()] = gpu_val->slot;
if (gpu_val->mf_size > 0) {
for (int x = 0; x < gpu_val->mf_dim + 1; x++) {
cpu_val[x + 8] = gpu_val->mf[x];
}
}
}
#endif
free(test_build_values);
};
if (multi_mf_dim_) {
VLOG(0) << "psgpu wrapper dump pool: multi_mf_dim_: " << multi_mf_dim_;
size_t device_num = heter_devices_.size();
std::vector<std::thread> threads(device_num * multi_mf_dim_);
for (size_t i = 0; i < device_num; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
threads[i + j * device_num] = std::thread(dump_pool_to_cpu_func, i, j);
}
}
for (std::thread& t : threads) {
t.join();
}
} }
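EndPass now walks every (device, dim) HBM pool instead of device_keys_: each pool is copied back to host with one cudaMemcpy of feature_value_size * len bytes, records whose key equals the unused-key sentinel are skipped, and every surviving record is written back through its cpu_ptr into the DownpourFixedFeatureValue, resizing that CPU value from 8 floats to 8 + mf_dim + 1 when the embedding was created on the GPU. The record walk over the copied-back buffer, restated as a host sketch (the trimmed struct and accessor offsets are assumptions, not the real layout):

#include <cstddef>
#include <cstdint>
#include <limits>

// hypothetical, trimmed-down view of one record in the dumped pool
struct DumpedVal { std::uint64_t cpu_ptr; int mf_dim; int mf_size; };

template <typename Fn>
void for_each_live_record(const char* host_buf, const std::uint64_t* keys,
                          std::size_t len, std::size_t value_size, Fn&& write_back) {
  const std::uint64_t unuse_key = std::numeric_limits<std::uint64_t>::max();
  for (std::size_t i = 0; i < len; ++i) {
    if (keys[i] == unuse_key) continue;   // never filled, nothing to dump
    const auto* v = reinterpret_cast<const DumpedVal*>(host_buf + i * value_size);
    write_back(keys[i], *v);              // copy fields back into the CPU table
  }
}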
if (keysize_max != 0) { if (keysize_max != 0) {
HeterPs_->end_pass(); HeterPs_->end_pass();
} }
for (size_t i = 0; i < hbm_pools_.size(); i++) {
delete hbm_pools_[i];
}
gpu_task_pool_.Push(current_task_); gpu_task_pool_.Push(current_task_);
current_task_ = nullptr; current_task_ = nullptr;
gpu_free_channel_->Put(current_task_); gpu_free_channel_->Put(current_task_);
timer.Pause(); timer.Pause();
VLOG(0) << "EndPass end, cost time: " << timer.ElapsedSec() << "s"; VLOG(1) << "EndPass end, cost time: " << timer.ElapsedSec() << "s";
} }
void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
...@@ -936,8 +936,6 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, ...@@ -936,8 +936,6 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
pull_gpups_timer.Start(); pull_gpups_timer.Start();
HeterPs_->pull_sparse(devid_2_index, total_keys, total_values_gpu, HeterPs_->pull_sparse(devid_2_index, total_keys, total_values_gpu,
static_cast<int>(total_length)); static_cast<int>(total_length));
// PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
// "PullSparseGPU failed in GPUPS."));
pull_gpups_timer.Pause(); pull_gpups_timer.Pause();
VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
...@@ -945,6 +943,98 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, ...@@ -945,6 +943,98 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len, this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len,
static_cast<int>(slot_lengths.size()), hidden_size, static_cast<int>(slot_lengths.size()), hidden_size,
total_length); total_length);
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GpuPs: PullSparse Only Support CUDAPlace Now."));
}
all_timer.Pause();
VLOG(3) << "GpuPs PullSparse total costs: " << all_timer.ElapsedSec()
<< " s, of which GPUPS costs: " << pull_gpups_timer.ElapsedSec()
<< " s";
VLOG(3) << "End PullSparse";
}
void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
const int table_id,
const std::vector<const uint64_t*>& keys,
const std::vector<float*>& values,
const std::vector<int64_t>& slot_lengths,
const std::vector<int>& slot_dim,
const int hidden_size) {
VLOG(3) << "Begine Gpu Ps PullSparse";
platform::Timer all_timer;
platform::Timer pull_gpups_timer;
all_timer.Start();
size_t total_length =
std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
size_t feature_value_size = 0;
feature_value_size = TYPEALIGN(
8, sizeof(FeatureValue) + sizeof(float) * (index_dim_vec_.back() + 1));
VLOG(0) << "yxf pull sparse feature_value_size: " << feature_value_size;
#ifdef PADDLE_WITH_CUDA
VLOG(3) << "Begine Gpu Ps PullSparse";
auto buf = memory::Alloc(place, total_length * feature_value_size);
FeatureValue* total_values_gpu = reinterpret_cast<FeatureValue*>(buf->ptr());
#endif
#ifdef PADDLE_WITH_XPU_KP
VLOG(3) << "Begine Xpu Ps PullSparse";
FeatureValue* total_values_gpu = nullptr;
xpu_malloc(reinterpret_cast<void**>(&total_values_gpu),
total_length * feature_value_size);
#endif
if (platform::is_cpu_place(place)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Warning:: CPUPlace is not supported in GpuPs now."));
} else if (platform::is_gpu_place(place)) {
VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
int device_id = place.GetDeviceId();
int devid_2_index = HeterPs_->get_index_by_devid(device_id);
LoDTensor& total_keys_tensor = keys_tensor[devid_2_index];
uint64_t* total_keys =
reinterpret_cast<uint64_t*>(total_keys_tensor.mutable_data<int64_t>(
{int64_t(total_length), 1}, place));
// construct slot_level lod info
auto slot_lengths_lod = slot_lengths;
for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
slot_lengths_lod[i] += slot_lengths_lod[i - 1];
}
auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*));
auto buf_length =
memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
uint64_t** gpu_keys = reinterpret_cast<uint64_t**>(buf_key->ptr());
int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*),
cudaMemcpyHostToDevice);
cudaMemcpy(gpu_len, slot_lengths_lod.data(),
slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
auto buf_dim = memory::Alloc(place, slot_dim.size() * sizeof(int));
int* gpu_dim = reinterpret_cast<int*>(buf_dim->ptr());
cudaMemcpy(gpu_dim, slot_dim.data(), slot_dim.size() * sizeof(int),
cudaMemcpyHostToDevice);
this->CopyKeys(place, gpu_keys, total_keys, gpu_len,
static_cast<int>(slot_lengths.size()),
static_cast<int>(total_length));
VLOG(3) << "Begin call PullSparseGPU in GPUPS, dev: " << devid_2_index
<< " len: " << total_length;
pull_gpups_timer.Start();
HeterPs_->pull_sparse(devid_2_index, total_keys, total_values_gpu,
total_length);
VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
<< "]";
this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len,
static_cast<int>(slot_lengths.size()), hidden_size,
total_length, gpu_dim);
pull_gpups_timer.Pause();
#endif #endif
} else if (platform::is_xpu_place(place)) { } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
...@@ -1013,7 +1103,10 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place, ...@@ -1013,7 +1103,10 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
// #ifdef PADDLE_WITH_CUDA // #ifdef PADDLE_WITH_CUDA
VLOG(3) << "Begin GPUPS PushSparseGrad"; VLOG(3) << "Begin GPUPS PushSparseGrad";
auto buf = memory::Alloc(place, total_length * sizeof(FeaturePushValue)); size_t grad_value_size =
TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
auto buf = memory::Alloc(place, total_length * grad_value_size);
VLOG(3) << "Push Sparse Max mf dimention: " << max_mf_dim_;
FeaturePushValue* total_grad_values_gpu = FeaturePushValue* total_grad_values_gpu =
reinterpret_cast<FeaturePushValue*>(buf->ptr()); reinterpret_cast<FeaturePushValue*>(buf->ptr());
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
...@@ -1027,8 +1120,13 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place, ...@@ -1027,8 +1120,13 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
uint64_t* total_keys = uint64_t* total_keys =
reinterpret_cast<uint64_t*>(cached_total_keys_tensor.data<int64_t>()); reinterpret_cast<uint64_t*>(cached_total_keys_tensor.data<int64_t>());
VLOG(3) << "Begin copy grad tensor to gpups struct"; VLOG(3) << "Begin copy grad tensor to gpups struct";
if (!multi_mf_dim_) {
this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths, this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths,
hidden_size, total_length, batch_size); hidden_size, total_length, batch_size);
} else {
this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths,
total_length, batch_size, grad_value_size);
}
VLOG(3) << "Begin call PushSparseGPU in GPUPS, dev: " << devid_2_index VLOG(3) << "Begin call PushSparseGPU in GPUPS, dev: " << devid_2_index
<< " len: " << total_length; << " len: " << total_length;
......
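Because FeaturePushValue now ends in a flexible mf_g array, the push buffer is allocated with a fixed per-record stride padded to the widest embedding of the pass, grad_value_size = TYPEALIGN(8, sizeof(FeaturePushValue) + max_mf_dim_ * sizeof(float)), and the multi-mf CopyForPush overload indexes records by that stride. A back-of-the-envelope helper (the header size below is illustrative; the real sizeof(FeaturePushValue) depends on compiler padding):

#include <cstddef>
#include <cstdio>

std::size_t grad_stride(std::size_t push_header_bytes, int max_mf_dim) {
  std::size_t raw = push_header_bytes +
                    static_cast<std::size_t>(max_mf_dim) * sizeof(float);
  return (raw + 7) / 8 * 8;               // TYPEALIGN(8, raw)
}

int main() {
  // e.g. an assumed 24-byte header with max_mf_dim = 64 -> 280 bytes per record
  std::printf("stride = %zu bytes\n", grad_stride(24, 64));
  return 0;
}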
...@@ -61,6 +61,45 @@ __global__ void PullCopy(float** dest, const FeatureValue* src, ...@@ -61,6 +61,45 @@ __global__ void PullCopy(float** dest, const FeatureValue* src,
} }
} }
__global__ void PullCopy(float** dest, const FeatureValue* src,
const int64_t* len, int slot_num, int total_len,
uint64_t** keys, uint64_t max_val_size, int* gpu_dim) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[x - 1] : 0);
FeatureValue* feature_value_ptr =
(FeatureValue*)((char*)src + uint64_t(i) * uint64_t(max_val_size));
int mf_dim = gpu_dim[x] - 3;
if (*(keys[x] + y) == 0) {
*(dest[x] + y * (mf_dim + 3)) = 0;
*(dest[x] + y * (mf_dim + 3) + 1) = 0;
*(dest[x] + y * (mf_dim + 3) + 2) = 0;
} else {
*(dest[x] + y * (mf_dim + 3)) = feature_value_ptr->show;
*(dest[x] + y * (mf_dim + 3) + 1) = feature_value_ptr->clk;
*(dest[x] + y * (mf_dim + 3) + 2) = feature_value_ptr->lr;
}
if ((feature_value_ptr)->mf_size == 0 || *(keys[x] + y) == 0) {
for (int j = 0; j < mf_dim; j++) {
*(dest[x] + y * (mf_dim + 3) + 3 + j) = 0;
}
} else {
for (int j = 0; j < mf_dim; j++) {
*(dest[x] + y * (mf_dim + 3) + 3 + j) = feature_value_ptr->mf[1 + j];
}
}
}
}
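The new PullCopy resolves each flat thread index i to a (slot, offset) pair by binary searching the inclusive prefix sums in len, then reads the record through a byte offset of i * max_val_size because pooled values have variable width; the slot's mf_dim is recovered as gpu_dim[x] - 3 (tensor width minus show/clk/lr). The index math, restated on the host:

#include <cstdint>
#include <utility>

std::pair<int, int64_t> locate_slot(const int64_t* len, int slot_num, int64_t i) {
  int low = 0, high = slot_num - 1;
  while (low < high) {                 // lower_bound over the prefix sums
    int mid = (low + high) / 2;
    if (i < len[mid]) high = mid; else low = mid + 1;
  }
  int x = low;                               // slot index
  int64_t y = i - (x ? len[x - 1] : 0);      // row within slot x
  return {x, y};
}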
__global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys, __global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys,
const int64_t* len, int slot_num, const int64_t* len, int slot_num,
int total_len) { int total_len) {
...@@ -105,6 +144,35 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, int64_t* len, ...@@ -105,6 +144,35 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, int64_t* len,
} }
} }
__global__ void PushCopyWithPool(FeaturePushValue* dest, float** src,
int64_t* len, int slot_num, uint64_t total_len,
int bs, int* slot_vector, int* mf_dim_vector,
size_t grad_value_size) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
while (low < high) {
int mid = (low + high) / 2;
if (i < len[mid])
high = mid;
else
low = mid + 1;
}
int x = low;
int y = i - (x ? len[low - 1] : 0);
FeaturePushValue* cur =
(FeaturePushValue*)((char*)dest + i * grad_value_size);
cur->slot = slot_vector[x];
int mf_dim = mf_dim_vector[x];
cur->mf_dim = mf_dim;
cur->show = *(src[x] + y * (mf_dim + 3));
cur->clk = *(src[x] + y * (mf_dim + 3) + 1);
cur->lr_g = *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs;
for (int j = 0; j < cur->mf_dim; j++) {
cur->mf_g[j] = *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs;
}
}
}
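Both copy kernels share the same per-slot row layout: each row of a slot tensor is [show, clk, embed (lr on pull, lr_g on push), mf_0 .. mf_{mf_dim-1}], i.e. mf_dim + 3 floats, with the push side additionally scaling gradients by -1 * batch size. A tiny accessor that spells out those offsets (names are illustrative):

struct SlotRowView {
  float* base;
  int mf_dim;
  float& show()    { return base[0]; }
  float& clk()     { return base[1]; }
  float& embed()   { return base[2]; }        // lr when pulling, lr_g when pushing
  float& mf(int j) { return base[3 + j]; }    // j in [0, mf_dim)
};

inline SlotRowView row_at(float* slot_tensor, int mf_dim, int y) {
  return {slot_tensor + static_cast<long long>(y) * (mf_dim + 3), mf_dim};
}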
PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; } PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; }
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
...@@ -128,6 +196,26 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, ...@@ -128,6 +196,26 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
cudaStreamSynchronize(stream); cudaStreamSynchronize(stream);
} }
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
uint64_t** gpu_keys,
const std::vector<float*>& values,
const FeatureValue* total_values_gpu,
const int64_t* gpu_len, const int slot_num,
const int hidden_size,
const int64_t total_length, int* gpu_dim) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
cudaMemcpyHostToDevice);
PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
gpu_values, total_values_gpu, gpu_len, slot_num, total_length, gpu_keys,
val_type_size_, gpu_dim);
cudaStreamSynchronize(stream);
}
void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place, void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
uint64_t** origin_keys, uint64_t* total_keys, uint64_t** origin_keys, uint64_t* total_keys,
const int64_t* gpu_len, int slot_num, const int64_t* gpu_len, int slot_num,
...@@ -177,6 +265,45 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, ...@@ -177,6 +265,45 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
cudaStreamSynchronize(stream); cudaStreamSynchronize(stream);
} }
void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
FeaturePushValue* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const uint64_t total_length,
const int batch_size, size_t grad_value_size) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
auto slot_lengths_lod = slot_lengths;
for (int i = 1; i < slot_lengths_lod.size(); i++) {
slot_lengths_lod[i] += slot_lengths_lod[i - 1];
}
auto buf_grad_value =
memory::Alloc(place, grad_values.size() * sizeof(float*));
auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
auto buf_slot_vector =
memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
auto buf_mf_dim_vector =
memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());
int* d_mf_dim_vector = reinterpret_cast<int*>(buf_mf_dim_vector->ptr());
cudaMemcpy(gpu_values, grad_values.data(),
grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice);
cudaMemcpy(gpu_len, slot_lengths_lod.data(),
slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
cudaMemcpy(d_slot_vector, slot_vector_.data(),
slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
cudaMemcpy(d_mf_dim_vector, slot_mf_dim_vector_.data(),
slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
PushCopyWithPool<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
total_grad_values_gpu, gpu_values, gpu_len, slot_lengths.size(),
total_length, batch_size, d_slot_vector, d_mf_dim_vector,
grad_value_size);
cudaStreamSynchronize(stream);
}
void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff, void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff,
float min_bound, float max_bound, float min_bound, float max_bound,
float learning_rate, float initial_g2sum, float learning_rate, float initial_g2sum,
......
...@@ -27,6 +27,7 @@ limitations under the License. */ ...@@ -27,6 +27,7 @@ limitations under the License. */
#include <vector> #include <vector>
#ifdef PADDLE_WITH_GLOO #ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h> #include <gloo/broadcast.h>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h" #include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif #endif
#include "paddle/fluid/distributed/ps/thirdparty/round_robin.h" #include "paddle/fluid/distributed/ps/thirdparty/round_robin.h"
...@@ -54,6 +55,9 @@ limitations under the License. */ ...@@ -54,6 +55,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
#include "afs_api.h" #include "afs_api.h"
#endif #endif
#ifdef PADDLE_WITH_PSLIB
#include "downpour_accessor.h" // NOLINT
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -95,12 +99,21 @@ class PSGPUWrapper { ...@@ -95,12 +99,21 @@ class PSGPUWrapper {
PSGPUWrapper() { PSGPUWrapper() {
HeterPs_ = NULL; HeterPs_ = NULL;
sleep_seconds_before_fail_exit_ = 300; sleep_seconds_before_fail_exit_ = 300;
pull_thread_pool_.resize(thread_keys_shard_num_);
for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
pull_thread_pool_[i].reset(new ::ThreadPool(1));
}
hbm_thread_pool_.resize(thread_keys_shard_num_); hbm_thread_pool_.resize(thread_keys_shard_num_);
for (size_t i = 0; i < hbm_thread_pool_.size(); i++) { for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
hbm_thread_pool_[i].reset(new ::ThreadPool(1)); hbm_thread_pool_[i].reset(new ::ThreadPool(1));
} }
} }
void PullSparse(const paddle::platform::Place& place, const int table_id,
const std::vector<const uint64_t*>& keys,
const std::vector<float*>& values,
const std::vector<int64_t>& slot_lengths,
const std::vector<int>& slot_dim, const int hidden_size);
void PullSparse(const paddle::platform::Place& place, const int table_id, void PullSparse(const paddle::platform::Place& place, const int table_id,
const std::vector<const uint64_t*>& keys, const std::vector<const uint64_t*>& keys,
const std::vector<float*>& values, const std::vector<float*>& values,
...@@ -119,13 +132,23 @@ class PSGPUWrapper { ...@@ -119,13 +132,23 @@ class PSGPUWrapper {
const FeatureValue* total_values_gpu, const int64_t* gpu_len, const FeatureValue* total_values_gpu, const int64_t* gpu_len,
const int slot_num, const int hidden_size, const int slot_num, const int hidden_size,
const int64_t total_length); const int64_t total_length);
void CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys,
const std::vector<float*>& values,
const FeatureValue* total_values_gpu, const int64_t* gpu_len,
const int slot_num, const int hidden_size,
const int64_t total_length, int* gpu_dim);
void CopyForPush(const paddle::platform::Place& place, void CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values, const std::vector<const float*>& grad_values,
FeaturePushValue* total_grad_values_gpu, FeaturePushValue* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths, const std::vector<int64_t>& slot_lengths,
const int hidden_size, const int64_t total_length, const int hidden_size, const int64_t total_length,
const int batch_size); const int batch_size);
void CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
FeaturePushValue* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const uint64_t total_length, const int batch_size,
size_t grad_value_size);
void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task); void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
void PreBuildTask(std::shared_ptr<HeterContext> gpu_task); void PreBuildTask(std::shared_ptr<HeterContext> gpu_task);
...@@ -428,6 +451,7 @@ class PSGPUWrapper { ...@@ -428,6 +451,7 @@ class PSGPUWrapper {
std::shared_ptr<HeterContext> current_task_ = nullptr; std::shared_ptr<HeterContext> current_task_ = nullptr;
std::thread pre_build_threads_; std::thread pre_build_threads_;
bool running_ = false; bool running_ = false;
std::vector<std::shared_ptr<ThreadPool>> pull_thread_pool_;
std::vector<std::shared_ptr<ThreadPool>> hbm_thread_pool_; std::vector<std::shared_ptr<ThreadPool>> hbm_thread_pool_;
protected: protected:
......
...@@ -26,6 +26,7 @@ template <typename T> ...@@ -26,6 +26,7 @@ template <typename T>
static void PullGpuPSSparseFunctor(const framework::ExecutionContext &ctx) { static void PullGpuPSSparseFunctor(const framework::ExecutionContext &ctx) {
auto inputs = ctx.MultiInput<framework::Tensor>("Ids"); auto inputs = ctx.MultiInput<framework::Tensor>("Ids");
auto outputs = ctx.MultiOutput<framework::Tensor>("Out"); auto outputs = ctx.MultiOutput<framework::Tensor>("Out");
auto embedding_size_vec = ctx.Attr<std::vector<int>>("size");
const auto slot_size = inputs.size(); const auto slot_size = inputs.size();
std::vector<const uint64_t *> all_keys(slot_size); std::vector<const uint64_t *> all_keys(slot_size);
// GpuPSPS only supports float now // GpuPSPS only supports float now
...@@ -44,7 +45,7 @@ static void PullGpuPSSparseFunctor(const framework::ExecutionContext &ctx) { ...@@ -44,7 +45,7 @@ static void PullGpuPSSparseFunctor(const framework::ExecutionContext &ctx) {
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
auto gpu_ps_ptr = paddle::framework::PSGPUWrapper::GetInstance(); auto gpu_ps_ptr = paddle::framework::PSGPUWrapper::GetInstance();
gpu_ps_ptr->PullSparse(ctx.GetPlace(), 0, all_keys, all_values, slot_lengths, gpu_ps_ptr->PullSparse(ctx.GetPlace(), 0, all_keys, all_values, slot_lengths,
0); embedding_size_vec, 0);
#endif #endif
} }
......
...@@ -737,7 +737,7 @@ def _pull_gpups_sparse(input, ...@@ -737,7 +737,7 @@ def _pull_gpups_sparse(input,
for i in range(len(inputs)) for i in range(len(inputs))
] ]
w = helper.create_parameter( w = helper.create_parameter(
attr=helper.param_attr, shape=[11], dtype=dtype, is_bias=False) attr=helper.param_attr, shape=[size[0]], dtype=dtype, is_bias=False)
helper.append_op( helper.append_op(
type='pull_gpups_sparse', type='pull_gpups_sparse',
inputs={'Ids': inputs, inputs={'Ids': inputs,
......