/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_HETERPS
#include <limits>
#include <thread>
#include <vector>

#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_CUDA)

// Combine functor handed to table->insert(). It always returns new_value and
// ignores old_value (presumably: on a duplicate key the incoming value
// replaces the stored one — semantics defined by the table container).
template <typename value_type>
struct ReplaceOp {
  __host__ __device__ value_type operator()(value_type new_value,
                                            value_type old_value) {
    return new_value;
  }
};

// Inserts the pairs (keys[i], vals[i]) for i in [0, len), one thread per pair.
// Launch: 1-D grid with at least `len` threads in total; surplus threads exit
// at the bounds check. A failed insert (table->end()) trips the device
// assert, which is compiled out under NDEBUG.
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const typename Table::mapped_type* const vals,
                              size_t len) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  // Widen before multiplying so the index cannot wrap at 2^32.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    kv.first = keys[i];
    kv.second = vals[i];
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

// Inserts keys[i] for i in [0, len), mapping each key to the address of slot
// (start_index + i) inside `pool`. Slots are a fixed number of bytes apart.
// mapped_type must be constructible from a char* by C-style cast
// (presumably a pointer type).
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              size_t len, char* pool, size_t start_index) {
  // NOTE(review): the pool stride is hard-coded; presumably it must match the
  // per-value byte size the pool was allocated with — confirm with the caller.
  constexpr size_t kPoolValueBytes = 80;
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    kv.first = keys[i];
    // `typename` is required: mapped_type is a dependent type name. Without it
    // any instantiation of this kernel fails to compile.
    kv.second = (typename Table::mapped_type)(
        pool + (start_index + i) * kPoolValueBytes);
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

// Looks up keys[i] for i in [0, len), one thread per key. On a hit the stored
// value is copied to vals[i]; on a miss vals[i] is left untouched.
template <typename Table>
__global__ void search_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              typename Table::mapped_type* const vals,
                              size_t len) {
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      vals[i] = it->second;
    }
  }
}

// Variable-size-value lookup. For each hit, the FeatureValue pointed to by the
// stored value is copied into `vals` at byte offset i * pull_feature_value_size.
// On a miss the output slot is left untouched.
// NOTE(review): only sizeof(FeatureValue) bytes are copied per slot even though
// the slot stride is pull_feature_value_size — confirm any trailing bytes are
// meant to be ignored.
template <typename Table>
__global__ void dy_mf_search_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    char* const vals, size_t len,
                                    size_t pull_feature_value_size) {
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      *(FeatureValue*)(vals + i * pull_feature_value_size) = *(it->second);
    }
  }
}

// Applies grads[i] to the stored value of keys[i] via sgd.update_value().
// One thread per key; keys that are not present are skipped silently.
template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const GradType* const grads, size_t len,
                              Sgd sgd) {
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      sgd.update_value((it.getter())->second, grads[i]);
    }
  }
}

// Variable-size-gradient update. The gradient for keys[i] is the
// FeaturePushValue at byte offset i * grad_value_size in `grads`, and it is
// applied with sgd.dy_mf_update_value(). Missing keys are reported with device
// printf (debug output, serialized and slow).
template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    const char* const grads, size_t len,
                                    Sgd sgd, size_t grad_value_size) {
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
      sgd.dy_mf_update_value((it.getter())->second, *cur);
    } else {
      // Keys may be 64-bit; "%d" would be undefined behavior. The trailing
      // newline keeps messages from different threads on separate lines.
      printf("yxf::push miss key: %llu\n",
             static_cast<unsigned long long>(keys[i]));
    }
  }
}

// Allocates the underlying container with room for `capacity` entries and
// creates the table's reader/writer lock.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  container_ = new TableContainer<KeyType, ValType>(capacity);
  rwlock_.reset(new phi::RWLock);
}

template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  delete container_;
}

// Debug helper: forwards to the container's print().
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  container_->print();
}

// Asynchronously looks up `len` device keys on `stream`. Hits are written to
// d_vals; misses leave d_vals untouched. No-op when len == 0.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
                                      size_t len, StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;  // ceil(len / BLOCK_SIZE_)
  search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                       d_vals, len);
}

// Variable-size-value lookup. Output slots are pull_feature_value_size_ bytes
// apart in d_vals. No-op when len == 0.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
                                      size_t len, StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  dy_mf_search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_vals, len, pull_feature_value_size_);
}

// Asynchronously inserts `len` (key, value) pairs on `stream`.
// No-op when len == 0.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals, size_t len,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                       d_vals, len);
}

// Asynchronously inserts `len` keys whose values are addresses inside `pool`,
// starting at slot `start_index`. No-op when len == 0 or pool is null.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
                                         char* pool, size_t start_index,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  if (pool == nullptr) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
                                                       pool, start_index);
}

// Copies every occupied slot's GPU-side value back into the CPU-side feature
// object referenced by its cpu_ptr. The work is split across 8 host threads.
// Slots whose key equals std::numeric_limits<KeyType>::max() are treated as
// empty and skipped.
// Assumes container_->data() is host-readable after prefetching to
// cudaCpuDeviceId (presumably managed memory) and that size() counts slots
// including empty ones — confirm.
// NOTE(review): there is no stream synchronization between the prefetch and
// the host reads below; callers presumably ensure prior work on `stream` has
// completed. `devid` is currently unused because the prefetch back to the
// device at the end is commented out.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
  container_->prefetch(cudaCpuDeviceId, stream);
  std::vector<std::thread> threads;
  const size_t num = container_->size();
  KeyType unuse_key = std::numeric_limits<KeyType>::max();
  thrust::pair<KeyType, ValType>* kv = container_->data();

  const int thread_num = 8;
  // size_t throughout: the slot count must not be narrowed to int.
  const size_t len_per_thread = num / thread_num;
  const size_t remain = num % thread_num;

  auto dump_func = [unuse_key, kv](size_t left, size_t right) {
    for (size_t i = left; i < right; i++) {
      if (kv[i].first == unuse_key) {
        continue;
      }
      ValType& gpu_val = kv[i].second;
#ifdef PADDLE_WITH_PSLIB
      auto* downpour_value =
          (paddle::ps::DownpourFixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      // A 7-float value has no embedding part yet; grow it to hold mf_size
      // extra floats.
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      cpu_val[1] = gpu_val.delta_score;
      cpu_val[2] = gpu_val.show;
      cpu_val[3] = gpu_val.clk;
      cpu_val[4] = gpu_val.lr;
      cpu_val[5] = gpu_val.lr_g2sum;
      cpu_val[6] = gpu_val.slot;
      if (gpu_val.mf_size > 0) {
        for (int x = 0; x < gpu_val.mf_size; x++) {
          cpu_val[x + 7] = gpu_val.mf[x];
        }
      }
#endif
#ifdef PADDLE_WITH_PSCORE
      auto* downpour_value =
          (paddle::distributed::FixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      // Field layout differs from the PSLIB branch above (slot is at index 0).
      cpu_val[2] = gpu_val.delta_score;
      cpu_val[3] = gpu_val.show;
      cpu_val[4] = gpu_val.clk;
      cpu_val[5] = gpu_val.lr;
      cpu_val[6] = gpu_val.lr_g2sum;
      cpu_val[0] = gpu_val.slot;
      if (gpu_val.mf_size > 0) {
        for (int x = 0; x < gpu_val.mf_size; x++) {
          cpu_val[x + 7] = gpu_val.mf[x];
        }
      }
#endif
    }
  };

  // The first `remain` threads take one extra slot so all `num` slots are
  // covered.
  threads.reserve(thread_num);
  size_t begin = 0;
  for (int i = 0; i < thread_num; i++) {
    const size_t chunk =
        len_per_thread + (static_cast<size_t>(i) < remain ? 1 : 0);
    threads.emplace_back(dump_func, begin, begin + chunk);
    begin += chunk;
  }
  for (std::thread& t : threads) {
    t.join();
  }
  // container_->prefetch(devid, stream);
}

// Asynchronously applies one GradType gradient per key with `sgd` on `stream`.
// No-op when len == 0.
template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const GradType* d_grads, size_t len,
                                         Sgd sgd, StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                       d_grads, len, sgd);
}

// Variable-size-gradient update. Gradients are push_grad_value_size_ bytes
// apart in d_grads. No-op when len == 0.
template <typename KeyType, typename ValType>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads, size_t len,
                                         Sgd sgd, StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_grads, len, sgd, push_grad_value_size_);
}

// Explicit instantiations.
template class HashTable<unsigned long, paddle::framework::FeatureValue>;
template class HashTable<long, int>;
template class HashTable<long, unsigned long>;
template class HashTable<long, unsigned int>;

template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
    cudaStream_t>(const unsigned long* d_keys,
                  paddle::framework::FeatureValue* d_vals, size_t len,
                  cudaStream_t stream);

template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
                                                      int* d_vals, size_t len,
                                                      cudaStream_t stream);
template void HashTable<long, unsigned long>::get<cudaStream_t>(
    const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, unsigned int>::get<cudaStream_t>(
    const long* d_keys, unsigned int* d_vals, size_t len, cudaStream_t stream);
// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<cudaStream_t>(
//    const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t
//    stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
    cudaStream_t>(const unsigned long* d_keys,
                  const paddle::framework::FeatureValue* d_vals, size_t len,
                  cudaStream_t stream);

template void HashTable<long, int>::insert<cudaStream_t>(const long* d_keys,
                                                         const int* d_vals,
                                                         size_t len,
                                                         cudaStream_t stream);
template void HashTable<long, unsigned long>::insert<cudaStream_t>(
    const long* d_keys, const unsigned long* d_vals, size_t len,
    cudaStream_t stream);
template void HashTable<long, unsigned int>::insert<cudaStream_t>(
    const long* d_keys, const unsigned int* d_vals, size_t len,
    cudaStream_t stream);

// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::insert<
//    cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool,
//                  size_t start_index, cudaStream_t stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::
    dump_to_cpu<cudaStream_t>(int devid, cudaStream_t stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
    paddle::framework::FeaturePushValue,
    Optimizer<paddle::framework::FeatureValue,
              paddle::framework::FeaturePushValue>,
    cudaStream_t>(const unsigned long* d_keys,
                  const paddle::framework::FeaturePushValue* d_grads,
                  size_t len,
                  Optimizer<paddle::framework::FeatureValue,
                            paddle::framework::FeaturePushValue>
                      sgd,
                  cudaStream_t stream);

// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
//    Optimizer<paddle::framework::FeatureValue,
//              paddle::framework::FeaturePushValue>,
//    cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t
//    len,
//                  Optimizer<paddle::framework::FeatureValue,
//                            paddle::framework::FeaturePushValue>
//                      sgd,
//                  cudaStream_t stream);

#endif
}  // end namespace framework
}  // end namespace paddle
#endif