未验证 提交 d9c174d1 编写于 作者: Y yaoxuefeng 提交者: GitHub

add hashtable dynamic mf support (#38493)

add hashtable dynamic mf support
上级 7411dab5
......@@ -27,6 +27,8 @@ limitations under the License. */
#include "thrust/pair.h"
// #include "cudf/concurrent_unordered_map.cuh.h"
#include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/mem_pool.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
......@@ -53,8 +55,11 @@ class HashTable {
HashTable& operator=(const HashTable&) = delete;
void insert(const KeyType* d_keys, const ValType* d_vals, size_t len,
gpuStream_t stream);
void insert(const KeyType* d_keys, size_t len, char* pool, size_t start_index,
gpuStream_t stream);
void get(const KeyType* d_keys, ValType* d_vals, size_t len,
gpuStream_t stream);
void get(const KeyType* d_keys, char* d_vals, size_t len, gpuStream_t stream);
void show();
void dump_to_cpu(int devid, cudaStream_t stream);
......@@ -62,8 +67,20 @@ class HashTable {
void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
Sgd sgd, gpuStream_t stream);
template <typename Sgd>
void update(const KeyType* d_keys, const char* d_grads, size_t len, Sgd sgd,
gpuStream_t stream);
int size() { return container_->size(); }
// Records the serialized byte sizes used by the dynamic-mf paths:
// pull_feature_value_size — bytes per value written by get(..., char*, ...);
// push_grad_value_size — bytes per gradient consumed by update(..., const char*, ...).
void set_feature_value_size(size_t pull_feature_value_size,
                            size_t push_grad_value_size) {
  push_grad_value_size_ = push_grad_value_size;
  pull_feature_value_size_ = pull_feature_value_size;
  VLOG(3) << "hashtable set pull value size: " << pull_feature_value_size_
          << " push value size: " << push_grad_value_size_;
}
std::unique_ptr<RWLock> rwlock_{nullptr};
private:
......@@ -71,6 +88,9 @@ class HashTable {
int BLOCK_SIZE_{256};
float LOAD_FACTOR{0.75f};
size_t capacity_;
size_t max_mf_dim_ = 8;
size_t pull_feature_value_size_;
size_t push_grad_value_size_;
};
} // end namespace framework
} // end namespace paddle
......
......@@ -42,6 +42,23 @@ __global__ void insert_kernel(Table* table,
}
}
// Inserts keys whose values live in a pre-allocated device memory pool:
// key i is mapped to the pool chunk starting at byte (start_index + i) * 80.
// Launch: 1-D grid/block, one thread per key; threads past `len` exit early.
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              size_t len, char* pool, int start_index) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    kv.first = keys[i];
    // NOTE(review): the per-value stride is hard-coded to 80 bytes; it must
    // match the pool's real serialized value size (pull_feature_value_size_)
    // or lookups will read the wrong chunk — confirm and keep in sync.
    // `typename` added below: mapped_type is a dependent type, so the cast
    // is ill-formed without it on conforming compilers.
    kv.second = (typename Table::mapped_type)(pool + (start_index + i) * 80);
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}
template <typename Table>
__global__ void search_kernel(Table* table,
const typename Table::key_type* const keys,
......@@ -56,6 +73,20 @@ __global__ void search_kernel(Table* table,
}
}
// Dynamic-mf lookup kernel: for each key i, finds its entry in `table` and
// writes the value into the flat byte buffer `vals` at offset
// i * pull_feature_value_size. Missing keys leave their output slot untouched.
// Launch: 1-D grid/block, one thread per key; threads past `len` do nothing.
template <typename Table>
__global__ void dy_mf_search_kernel(Table* table,
const typename Table::key_type* const keys,
char* const vals, size_t len,
size_t pull_feature_value_size) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
auto it = table->find(keys[i]);
if (it != table->end()) {
// NOTE(review): this assignment copies only sizeof(FeatureValue) bytes.
// If the dynamic mf payload extends past the fixed FeatureValue header
// (pull_feature_value_size > sizeof(FeatureValue)), the tail is NOT
// copied — confirm against the pool layout.
*(FeatureValue*)(vals + i * pull_feature_value_size) = *(it->second);
}
}
}
template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
const typename Table::key_type* const keys,
......@@ -70,6 +101,23 @@ __global__ void update_kernel(Table* table,
}
}
// Dynamic-mf gradient-apply kernel: for each key i, locates its table entry
// and lets the optimizer functor `sgd` merge the i-th serialized gradient
// (grad_value_size bytes, packed in `grads`) into the stored value.
// Launch: 1-D grid/block, one thread per key; threads past `len` exit early.
template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    const char* const grads, size_t len,
                                    Sgd sgd, size_t grad_value_size) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      // Const-correct view of the serialized gradient: sgd only reads it.
      const FeaturePushValue* cur =
          (const FeaturePushValue*)(grads + i * grad_value_size);
      sgd.dy_mf_update_value((it.getter())->second, *cur);
    } else {
      // Fixed format specifier: the key is a (typically 64-bit) feature id,
      // so %d truncated it; print the full value and terminate the line.
      printf("yxf::push miss key: %llu\n", (unsigned long long)(keys[i]));
    }
  }
}
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
container_ = new TableContainer<KeyType, ValType>(capacity);
......@@ -97,6 +145,17 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
d_vals, len);
}
// Dynamic-mf batch lookup: fills d_vals (a flat byte buffer, one
// pull_feature_value_size_-byte slot per key) with the values of d_keys.
// Asynchronous with respect to the host; work is enqueued on `stream`.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
                                      size_t len, gpuStream_t stream) {
  if (len == 0) {
    return;
  }
  // Ceil-divide so the final partial block is still launched.
  const int grid_size = (len + BLOCK_SIZE_ - 1) / BLOCK_SIZE_;
  dy_mf_search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_vals, len, pull_feature_value_size_);
}
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
const ValType* d_vals, size_t len,
......@@ -109,6 +168,21 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
d_vals, len);
}
// Dynamic-mf batch insert: maps key d_keys[i] to the pre-allocated pool
// chunk at index (start_index + i). No-op when len == 0 or the pool pointer
// is null. Asynchronous; work is enqueued on `stream`.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
                                         char* pool, size_t start_index,
                                         gpuStream_t stream) {
  if (len == 0 || pool == NULL) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
                                                       pool, start_index);
}
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, cudaStream_t stream) {
container_->prefetch(cudaCpuDeviceId, stream);
......@@ -166,6 +240,20 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
d_grads, len, sgd);
}
// Dynamic-mf batch update: applies the serialized gradients in d_grads
// (push_grad_value_size_ bytes apiece) to the stored values of d_keys via
// the optimizer functor `sgd`. Asynchronous; work is enqueued on `stream`.
template <typename KeyType, typename ValType>
template <typename Sgd>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads, size_t len,
                                         Sgd sgd, gpuStream_t stream) {
  if (len == 0) {
    return;
  }
  // Ceil-divide so the final partial block is still launched.
  const int grid_size = (len + BLOCK_SIZE_ - 1) / BLOCK_SIZE_;
  dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_grads, len, sgd, push_grad_value_size_);
}
} // end namespace framework
} // end namespace paddle
#endif
......@@ -96,6 +96,40 @@ class Optimizer {
update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show);
}
}
// In-place dynamic-mf SGD update: merges one pushed gradient `grad` into the
// stored feature value `*ptr` (show/click counters, learning-rate state, and
// — once created — the mf embedding).
__device__ void dy_mf_update_value(ValType* ptr, const GradType& grad) {
ptr->slot = grad.slot;
ptr->show += grad.show;
ptr->clk += grad.clk;
// Weighted show/click contribution of this batch accumulated into the score.
ptr->delta_score +=
optimizer_config::nonclk_coeff * (grad.show - grad.clk) +
optimizer_config::clk_coeff * grad.clk;
update_lr(ptr->lr, ptr->lr_g2sum, grad.lr_g, grad.show);
// use MF_DIM temporarily
// ptr->mf_dim = grad.mf_dim;
// mf_size == 0 means the embedding has not been materialized yet.
if (ptr->mf_size == 0) {
// Create the embedding only once the accumulated weighted show/click
// signal reaches the configured creation threshold.
if (optimizer_config::mf_create_thresholds <=
optimizer_config::nonclk_coeff * (ptr->show - ptr->clk) +
optimizer_config::clk_coeff * ptr->clk) {
// ptr->mf_size = ptr->mf_dim + 1;
ptr->mf_size = MF_DIM + 1;
// mf[0] is the extra slot passed separately to update_mf (presumably a
// gradient-sum accumulator — confirm against update_mf); weights start
// at mf[1].
ptr->mf[0] = 0;
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
curandState state;
// NOTE(review): seeding with clock64() makes the random initialization
// non-reproducible across runs — confirm this is intended.
curand_init(clock64(), tid_x, 0, &state);
// Initialize each weight uniformly in [0, mf_initial_range).
for (int i = 0; i < MF_DIM; ++i) {
ptr->mf[i + 1] =
(curand_uniform(&state)) * optimizer_config::mf_initial_range;
}
}
} else {
update_mf(MF_DIM, &(ptr->mf[1]), ptr->mf[0], grad.mf_g,
grad.show); // for local test
}
}
};
} // end namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册