// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h" #include "paddle/fluid/distributed/ps/table/memory_sparse_table.h" #include "paddle/utils/flags.h" namespace paddle { namespace distributed { class MemRegion { public: MemRegion() { _cap = 2 * 1024 * 1024; _buf = reinterpret_cast(malloc(_cap)); _cur = 0; _file_idx = -1; } virtual ~MemRegion() { free(_buf); } bool buff_remain(int len) { if (_cap - _cur < len) { return false; } else { return true; } } char* acquire(int len) { if (_cap - _cur < len) { return nullptr; } else { char* ret = _buf + _cur; _cur += len; return ret; } } void reset() { _cur = 0; _file_idx = -1; } int _cap; int _cur; int _file_idx; char* _buf; }; class SSDSparseTable : public MemorySparseTable { public: typedef SparseTableShard shard_type; SSDSparseTable() {} virtual ~SSDSparseTable() {} int32_t Initialize() override; int32_t InitializeShard() override; // exchange data int32_t UpdateTable(); int32_t Pull(TableContext& context) override; int32_t Push(TableContext& context) override; int32_t PullSparse(float* pull_values, const uint64_t* keys, size_t num); int32_t PullSparsePtr(int shard_id, char** pull_values, const uint64_t* keys, size_t num, uint16_t pass_id); int32_t PushSparse(const uint64_t* keys, const float* values, size_t num); int32_t PushSparse(const uint64_t* keys, const float** values, size_t num); int32_t Flush() override { return 0; } int32_t Shrink(const std::string& param) override; void Clear() override { for (int i = 0; i < _real_local_shard_num; ++i) { _local_shards[i].clear(); } } int32_t Save(const std::string& path, const std::string& param) override; int32_t SaveWithString(const std::string& path, const std::string& param); int32_t SaveWithStringMultiOutput(const std::string& path, const std::string& param); int32_t SaveWithBinary(const std::string& path, const std::string& param); int32_t SaveCache( const std::string& path, const std::string& param, paddle::framework::Channel>& shuffled_channel) override; double GetCacheThreshold() override { return _local_show_threshold; } int64_t CacheShuffle( const std::string& path, const std::string& param, double cache_threshold, std::function( int msg_type, int to_pserver_id, std::string& msg)> send_msg_func, paddle::framework::Channel>& shuffled_channel, const std::vector& table_ptrs) override; // 加载path目录下数据 int32_t Load(const std::string& path, const std::string& param) override; int32_t LoadWithString(size_t file_start_idx, size_t end_idx, const std::vector& file_list, const std::string& param); int32_t LoadWithBinary(const std::string& path, int param); int64_t LocalSize(); std::pair PrintTableStat() override; int32_t CacheTable(uint16_t pass_id) override; private: RocksDBHandler* _db; int64_t _cache_tk_size; double _local_show_threshold{0.0}; std::vector> _fs_channel; std::mutex _table_mutex; }; } // namespace distributed } // namespace paddle