/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_PSLIB

namespace paddle {
namespace framework {

// Value-combining functor passed to the table's insert(). It always returns
// the incoming value and ignores the stored one. Presumably the container
// invokes it when the key is already present, so re-inserting a key
// overwrites its value (confirm against the container's insert()).
template <typename value_type>
struct ReplaceOp {
  __host__ __device__ value_type operator()(value_type new_value,
                                            value_type /*old_value*/) const {
    return new_value;
  }
};

// Inserts keys[i] -> vals[i] into `table` for every i in [0, len).
// Launch: 1-D grid, one thread per key (threads with i >= len do nothing),
// no dynamic shared memory. `keys` and `vals` are dereferenced on the device.
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const typename Table::mapped_type* const vals,
                              size_t len) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  // Widen before multiplying: the 32-bit product wraps for len > 2^32.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    kv.first = keys[i];
    kv.second = vals[i];
    auto it = table->insert(kv, op);
    // insert() returning end() is treated as "no free slot". Device-side
    // assert is compiled out under NDEBUG, in which case the key is dropped
    // without any error being reported.
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

// Looks up keys[i] for every i in [0, len) and writes the stored value to
// vals[i]. Keys that are not found leave vals[i] untouched, so callers that
// need a default must initialise `vals` beforehand.
// Launch: 1-D grid, one thread per key, no dynamic shared memory.
template <typename Table>
__global__ void search_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              typename Table::mapped_type* const vals,
                              size_t len) {
  // Widen before multiplying: the 32-bit product wraps for len > 2^32.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      vals[i] = it->second;
    }
  }
}

// Applies sgd.update_value(stored_value, grads[i]) to the value stored for
// keys[i], for every i in [0, len). The stored value is passed through the
// iterator's getter() and updated in place. Keys that are not found are
// skipped.
// Launch: 1-D grid, one thread per key, no dynamic shared memory.
template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const GradType* const grads, size_t len,
                              Sgd sgd) {
  // Widen before multiplying: the 32-bit product wraps for len > 2^32.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      sgd.update_value((it.getter())->second, grads[i]);
    }
  }
}

// Allocates the underlying container with room for `capacity` entries.
// NOTE(review): template arguments of TableContainer are assumed to be
// <KeyType, ValType>; confirm against its declaration in the header.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  container_ = new TableContainer<KeyType, ValType>(capacity);
}

// Releases the container allocated in the constructor.
// NOTE(review): container_ is an owning raw pointer; confirm the class
// declaration disables copying, otherwise copies would double-delete.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  delete container_;
}

// Prints the table contents via the container's print().
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  container_->print();
}

// Batched lookup of `len` device-resident keys into device-resident `d_vals`.
// Enqueues search_kernel on `stream` and returns without synchronising; the
// results are valid only after `stream` has completed the kernel.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
                                      size_t len, cudaStream_t stream) {
  // Required, not just an optimisation: for len == 0, len - 1 below would wrap.
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;  // ceil(len / BLOCK_SIZE_)
  search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                       d_vals, len);
}

// Batched insert of `len` device-resident (key, value) pairs. Enqueues
// insert_kernel on `stream` and returns without synchronising.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals, size_t len,
                                         cudaStream_t stream) {
  // Required, not just an optimisation: for len == 0, len - 1 below would wrap.
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;  // ceil(len / BLOCK_SIZE_)
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                       d_vals, len);
}

// Batched in-place update of stored values using `sgd` and `len`
// device-resident gradients. Enqueues update_kernel on `stream` and returns
// without synchronising. `sgd` is copied by value into the kernel arguments.
template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const GradType* d_grads, size_t len,
                                         Sgd sgd, cudaStream_t stream) {
  // Required, not just an optimisation: for len == 0, len - 1 below would wrap.
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;  // ceil(len / BLOCK_SIZE_)
  update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                       d_grads, len, sgd);
}

}  // end namespace framework
}  // end namespace paddle
#endif  // PADDLE_WITH_PSLIB