/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"

// Optimizer hyper-parameters. Each symbol is a pointer into global memory and
// is defined in another translation unit; device code below copies the
// pointed-to float into a `__local__` variable with GM2LM before using it.
namespace optimizer_config {
extern _global_ptr_ float* nonclk_coeff;
extern _global_ptr_ float* clk_coeff;

extern _global_ptr_ float* min_bound;
extern _global_ptr_ float* max_bound;
extern _global_ptr_ float* learning_rate;
extern _global_ptr_ float* initial_g2sum;
extern _global_ptr_ float* initial_range;

extern _global_ptr_ float* mf_create_thresholds;
extern _global_ptr_ float* mf_learning_rate;
extern _global_ptr_ float* mf_initial_g2sum;
extern _global_ptr_ float* mf_initial_range;
extern _global_ptr_ float* mf_min_bound;
extern _global_ptr_ float* mf_max_bound;
}  // namespace optimizer_config

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_XPU_KP)

// Applies one scaled-gradient step to the scalar weight `w`.
//
//   w      : weight, updated in place and clamped to [min_bound, max_bound].
//   g2sum  : accumulated squared gradient, increased by (g / scale)^2.
//   g      : raw gradient.
//   scale  : divisor applied to `g` before use. A zero `scale` is NOT guarded
//            here; callers pass `grad.show`.
__device__ void update_lr(float& w, float& g2sum, float g,  // NOLINT
                          float scale) {
  __local__ float local_learning_rate;
  __local__ float local_initial_g2sum;
  __local__ float local_min_bound;
  __local__ float local_max_bound;

  GM2LM(optimizer_config::learning_rate, &local_learning_rate, sizeof(float));
  GM2LM(optimizer_config::initial_g2sum, &local_initial_g2sum, sizeof(float));
  GM2LM(optimizer_config::min_bound, &local_min_bound, sizeof(float));
  GM2LM(optimizer_config::max_bound, &local_max_bound, sizeof(float));

  // The step size shrinks as g2sum grows relative to initial_g2sum.
  double ratio = local_learning_rate *
                 sqrt(local_initial_g2sum / (local_initial_g2sum + g2sum));
  double scaled_grad = g / scale;

  w += scaled_grad * ratio;

  if (w < local_min_bound) w = local_min_bound;
  if (w > local_max_bound) w = local_max_bound;

  g2sum += scaled_grad * scaled_grad;
}

// Applies one scaled-gradient step to the `n`-element weight vector `w`.
//
//   n      : number of elements in `w` and `g`; nothing is done when n <= 0.
//   w      : weights, updated in place and clamped to
//            [mf_min_bound, mf_max_bound].
//   g2sum  : shared accumulator, increased by the MEAN of the squared scaled
//            gradients over the n elements.
//   g      : raw gradients.
//   scale  : divisor applied to every gradient (not guarded against zero).
__device__ void update_mf(int n, float* w, float& g2sum,  // NOLINT
                          const float* g, float scale) {
  // Without this guard the `/ n` below would turn g2sum into NaN for n == 0.
  if (n <= 0) {
    return;
  }

  __local__ float local_mf_learning_rate;
  __local__ float local_mf_initial_g2sum;
  __local__ float local_mf_min_bound;
  __local__ float local_mf_max_bound;

  GM2LM(optimizer_config::mf_learning_rate, &local_mf_learning_rate,
        sizeof(float));
  GM2LM(optimizer_config::mf_initial_g2sum, &local_mf_initial_g2sum,
        sizeof(float));
  GM2LM(optimizer_config::mf_min_bound, &local_mf_min_bound, sizeof(float));
  GM2LM(optimizer_config::mf_max_bound, &local_mf_max_bound, sizeof(float));

  double add_g2sum = 0;
  // The ratio is computed once from the g2sum on entry and shared by all
  // elements.
  double ratio =
      local_mf_learning_rate *
      sqrt(local_mf_initial_g2sum / (local_mf_initial_g2sum + g2sum));

  for (int i = 0; i < n; ++i) {
    double scaled_grad = g[i] / scale;
    w[i] += scaled_grad * ratio;

    if (w[i] < local_mf_min_bound) w[i] = local_mf_min_bound;
    if (w[i] > local_mf_max_bound) w[i] = local_mf_max_bound;

    add_g2sum += scaled_grad * scaled_grad;
  }

  g2sum += add_g2sum / n;
}

// Placeholder: always returns the constant 0.1, it does not generate random
// numbers.
__device__ float xpu_rand_uniform() { return 0.1f; }

// Folds the pushed gradient `grad` into the stored feature value `val`.
//
// Accumulates show/clk/delta_score, updates the scalar `lr` weight, and then
// either creates the mf vector (when `val.mf_size == 0` and the show/click
// score reaches mf_create_thresholds) or updates the existing one.
// `val.mf[0]` is used as the mf g2sum accumulator and `val.mf[1..MF_DIM]` as
// the weights.
template <typename ValType, typename GradType>
__device__ void update_value(ValType& val, const GradType& grad) {  // NOLINT
  val.slot = grad.slot;
  val.show += grad.show;
  val.clk += grad.clk;

  __local__ float local_nonclk_coeff;
  __local__ float local_clk_coeff;

  __local__ float local_mf_create_thresholds;
  __local__ float local_mf_initial_range;

  GM2LM(optimizer_config::nonclk_coeff, &local_nonclk_coeff, sizeof(float));
  GM2LM(optimizer_config::clk_coeff, &local_clk_coeff, sizeof(float));
  GM2LM(optimizer_config::mf_create_thresholds, &local_mf_create_thresholds,
        sizeof(float));

  val.delta_score += local_nonclk_coeff * (grad.show - grad.clk) +
                     local_clk_coeff * grad.clk;

  update_lr(val.lr, val.lr_g2sum, grad.lr_g, grad.show);

  if (val.mf_size == 0) {
    if (local_mf_create_thresholds <=
        local_nonclk_coeff * (val.show - val.clk) +
            local_clk_coeff * val.clk) {
      // Load the init range before use; previously this local was read
      // without ever being filled from global memory.
      GM2LM(optimizer_config::mf_initial_range, &local_mf_initial_range,
            sizeof(float));

      val.mf_size = MF_DIM + 1;
      val.mf[0] = 0;

      for (int i = 0; i < MF_DIM; ++i) {
        val.mf[i + 1] = (xpu_rand_uniform()) * local_mf_initial_range;
      }
    }
  } else {
    update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show);
  }
}

// Stages chunks of `keys` / `vals` (both of length `len`) into local buffers
// for insertion into `table`.
//
// Work split: the flat worker index is `core_num() * cluster_id() + core_id()`
// and each worker handles chunks of `len_per_loop` (<= 150) consecutive
// elements, stepping by `nthreads * len_per_loop`.
//
// NOTE: the actual `table->insert` call is still commented out, so this
// kernel currently only performs the loads.
template <typename KeyType, typename ValType, typename Table>
__global__ void insert_kernel(Table* table, const KeyType* const keys,
                              const ValType* const vals, long long len) {
  int cid = core_id();
  int ncores = core_num();
  if (cid >= ncores) {
    return;
  }
  int thread_id = ncores * cluster_id() + cid;
  int nthreads = ncores * cluster_num();

  const int buf_size = 150;
  __local__ KeyType local_keys[buf_size];
  __local__ ValType local_vals[buf_size];
  int len_per_loop = min(buf_size, roundup_div(len, nthreads));

  // 64-bit offsets: `len` is long long, so int arithmetic could overflow.
  const long long stride = static_cast<long long>(nthreads) * len_per_loop;
  for (long long i = static_cast<long long>(thread_id) * len_per_loop;
       i < len; i += stride) {
    int read_len = min(len_per_loop, len - i);
    // Read this worker's chunk (offset i), not the start of the arrays.
    GM2LM(keys + i, local_keys, read_len * sizeof(KeyType));
    GM2LM(vals + i, local_vals, read_len * sizeof(ValType));
    for (int k = 0; k < read_len; k++) {
      // auto status = table->insert(local_keys[k], local_vals[k]);
      // assert(status != false && "error: insert fails: table is full");
    }
  }
}

// Looks up `keys[0..len)` in `table` and writes results to `vals[0..len)`.
//
// Uses the same work split as insert_kernel (chunks of <= 150 elements per
// worker). The caller's current `vals` chunk is loaded first so that entries
// with no lookup result are written back unchanged.
//
// NOTE: the actual `table->find` call is still commented out, so `vals` is
// currently written back with the contents it already had.
template <typename KeyType, typename ValType, typename Table>
__global__ void search_kernel(Table* table, const KeyType* const keys,
                              ValType* const vals, long long len) {
  int cid = core_id();
  int ncores = core_num();
  if (cid >= ncores) {
    return;
  }
  int thread_id = ncores * cluster_id() + cid;
  int nthreads = ncores * cluster_num();

  const int buf_size = 150;
  __local__ KeyType local_keys[buf_size];
  __local__ ValType local_vals[buf_size];

  int len_per_loop = min(buf_size, roundup_div(len, nthreads));

  // 64-bit offsets: `len` is long long, so int arithmetic could overflow.
  const long long stride = static_cast<long long>(nthreads) * len_per_loop;
  for (long long i = static_cast<long long>(thread_id) * len_per_loop;
       i < len; i += stride) {
    int read_len = min(len_per_loop, len - i);
    // Read this worker's chunk (offset i), not the start of the array.
    GM2LM(keys + i, local_keys, read_len * sizeof(KeyType));
    // Preload existing outputs so the write-back below never publishes
    // uninitialized local memory.
    GM2LM(vals + i, local_vals, read_len * sizeof(ValType));
    for (int k = 0; k < read_len; k++) {
      // ValType* val = table->find(local_keys[k]);
      // if (val != NULL) {
      //   local_vals[k] = *val;
      // }
    }
    LM2GM(local_vals, vals + i, read_len * sizeof(ValType));
  }
}

// Stages chunks of `keys` / `grads` (both of length `len`) into local buffers
// so each gradient can be applied to the matching entry of `table`.
//
// Uses the same work split as insert_kernel, but with chunks of <= 250
// elements per worker.
//
// NOTE: the lookup and `update_value` call are still commented out, so this
// kernel currently only performs the loads.
template <typename KeyType, typename ValType, typename Table,
          typename GradType>
__global__ void update_kernel(Table* table, const KeyType* const keys,
                              const GradType* const grads, long long len) {
  int cid = core_id();
  int ncores = core_num();
  if (cid >= ncores) {
    return;
  }
  int thread_id = ncores * cluster_id() + cid;
  int nthreads = ncores * cluster_num();

  const int buf_size = 250;
  __local__ KeyType local_keys[buf_size];
  __local__ GradType local_grads[buf_size];

  int len_per_loop = min(buf_size, roundup_div(len, nthreads));

  // 64-bit offsets: `len` is long long, so int arithmetic could overflow.
  const long long stride = static_cast<long long>(nthreads) * len_per_loop;
  for (long long i = static_cast<long long>(thread_id) * len_per_loop;
       i < len; i += stride) {
    int read_len = min(len_per_loop, len - i);

    // Read this worker's chunk (offset i), not the start of the arrays.
    GM2LM(keys + i, local_keys, read_len * sizeof(KeyType));
    GM2LM(grads + i, local_grads, read_len * sizeof(GradType));

    for (int k = 0; k < read_len; k++) {
      // ValType* val = table->find(local_keys[k]);
      // if (val != NULL) {
      //   update_value(*val, local_grads[k]);
      // }
    }
  }
}

// Builds an XPUCacheArray of the requested capacity on the host and copies
// the object byte-for-byte into a freshly allocated device buffer
// (`container_`).
//
// NOTE(review): the return codes of xpu_malloc / xpu_memcpy are not checked,
// and `tmp_container` is destroyed when this constructor returns. Confirm
// that XPUCacheArray's destructor does not release resources the device copy
// still refers to.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  auto tmp_container = XPUCacheArray<KeyType, ValType>(capacity);
  xpu_malloc(reinterpret_cast<void**>(&container_),
             sizeof(XPUCacheArray<KeyType, ValType>));
  xpu_memcpy(container_, &tmp_container,
             sizeof(XPUCacheArray<KeyType, ValType>), XPU_HOST_TO_DEVICE);
  rwlock_.reset(new phi::RWLock);
}

// Releases the device buffer allocated in the constructor.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  xpu_free((void*)container_);  // NOLINT
}

// Forwards to the container's print().
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  container_->print();
}

// Launches search_kernel on `stream` to look up `len` keys; no-op when
// len == 0. `d_keys` and `d_vals` are handed straight to the kernel.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
                                      size_t len, StreamType stream) {
  if (len == 0) {
    return;
  }
  long long c_len = (long long)len;  // NOLINT
  search_kernel<KeyType, ValType,
                XPUCacheArray<KeyType, ValType>><<<4, 64, stream>>>(
      container_, d_keys, d_vals, c_len);
}

// Raw-buffer overload of get(); not implemented yet (does nothing).
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
                                      size_t len, StreamType stream) {
  if (len == 0) {
    return;
  }
  // TODO(zhangminxu): to be implemented
}

// Launches insert_kernel on `stream` for `len` key/value pairs; no-op when
// len == 0.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals, size_t len,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  long long c_len = (long long)len;  // NOLINT
  insert_kernel<KeyType, ValType,
                XPUCacheArray<KeyType, ValType>><<<4, 64, stream>>>(
      container_, d_keys, d_vals, c_len);
}

// Not implemented yet (does nothing).
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
  // TODO(zhangminxu): to be implemented
}

// Launches update_kernel on `stream` to apply `len` gradients; no-op when
// len == 0.
template <typename KeyType, typename ValType>
template <typename GradType, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const GradType* d_grads, size_t len,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  long long c_len = (long long)len;  // NOLINT
  update_kernel<KeyType, ValType, XPUCacheArray<KeyType, ValType>,
                GradType><<<4, 64, stream>>>(container_, d_keys, d_grads,
                                             c_len);
}

// Raw-buffer overload of update(); not implemented yet (does nothing).
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads, size_t len,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  // TODO(zhangminxu): to be implemented
}

// Explicit instantiations for the key/value types used with XPU streams.
template class HashTable<unsigned long, paddle::framework::FeatureValue>;

template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
    XPUStream>(const unsigned long* d_keys,
               paddle::framework::FeatureValue* d_vals, size_t len,
               XPUStream stream);

// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<XPUStream>(
//    const unsigned long* d_keys, char* d_vals, size_t len, XPUStream stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
    XPUStream>(const unsigned long* d_keys,
               const paddle::framework::FeatureValue* d_vals, size_t len,
               XPUStream stream);

// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::insert<
//    XPUStream>(const unsigned long* d_keys, size_t len, char* pool,
//               size_t start_index, XPUStream stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::
    dump_to_cpu<XPUStream>(int devid, XPUStream stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
    paddle::framework::FeaturePushValue, XPUStream>(
    const unsigned long* d_keys,
    const paddle::framework::FeaturePushValue* d_grads, size_t len,
    XPUStream stream);

// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
//    XPUStream>(const unsigned long* d_keys, const char* d_grads,
//               size_t len, XPUStream stream);

#endif
}  // end namespace framework
}  // end namespace paddle
#endif