// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License 0// // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "brpc/channel.h" #include "brpc/controller.h" #include "brpc/server.h" #include "paddle/fluid/distributed/ps/service/ps_client.h" namespace paddle { namespace distributed { class Table; class PsLocalClient : public PSClient { public: PsLocalClient() {} virtual ~PsLocalClient() { _running = false; } virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms, int pslib_connect_timeout_ms, int max_retry) { return 0; } ::std::future Shrink(uint32_t table_id, const std::string threshold) override; ::std::future Load(const std::string& epoch, const std::string& mode) override; ::std::future Load(uint32_t table_id, const std::string& epoch, const std::string& mode) override; ::std::future Save(const std::string& epoch, const std::string& mode) override; ::std::future Save(uint32_t table_id, const std::string& epoch, const std::string& mode) override; ::std::future Clear() override; ::std::future Clear(uint32_t table_id) override; ::std::future StopServer() override; void FinalizeWorker() override {} virtual ::std::future PullDense(Region* regions, size_t region_num, size_t table_id); virtual ::std::future PushDense(const Region* regions, size_t region_num, size_t table_id); virtual ::std::future PushDenseParam(const Region* regions, size_t region_num, size_t table_id); virtual ::std::future PullSparse(float** select_values, size_t table_id, const uint64_t* keys, size_t num, bool is_training) { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); return fut; } virtual ::std::future PullSparsePtr(int shard_id, char** select_values, size_t table_id, const uint64_t* keys, size_t num, uint16_t pass_id); virtual ::std::future PrintTableStat(uint32_t table_id); virtual ::std::future SaveCacheTable(uint32_t table_id, uint16_t pass_id, size_t threshold); virtual ::std::future PushSparse(size_t table_id, const uint64_t* keys, const float** update_values, size_t num); virtual ::std::future Flush(); // server profilera virtual std::future StartProfiler() { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); return fut; } virtual std::future StopProfiler() { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); return fut; } virtual std::future Barrier(size_t table_id, uint32_t barrier_type) { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); return fut; } virtual std::future PullGeoParam(size_t table_id, std::vector* values, std::vector* keys, int pserver_idx) { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); return fut; } virtual std::future PushGlobalStep(int table_id, int64_t* total_send_data, void* done) { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); return fut; } // recv table from server and save it in LodTensor virtual int32_t RecvAndSaveTable(const uint64_t table_id, const std::string& path) { return 0; } ::std::future SendClient2ClientMsg(int msg_type, int to_client_id, const std::string& msg) override { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); return fut; } virtual size_t GetServerNums() { return 1; } std::future PushDenseRawGradient(int table_id, float* total_send_data, size_t total_send_data_size, void* callback) override; std::future PushSparseRawGradient(size_t table_id, const uint64_t* keys, const float** update_values, size_t num, void* callback) override; std::future PushSparseRawGradientPartial(size_t table_id, const uint64_t* keys, const float** update_values, uint32_t num, void* done, int pserver_idx) override { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); return fut; } std::future PushSparseParam(size_t table_id, const uint64_t* keys, const float** update_values, size_t num, void* done) override { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); return fut; } private: int32_t Initialize() override; std::future done() { std::shared_ptr> prom = std::make_shared>(); std::future fut = prom->get_future(); prom->set_value(0); return fut; } inline uint32_t DenseDimPerShard(uint32_t dense_dim_total, uint32_t shard_num) { return dense_dim_total / shard_num + 1; } inline std::unordered_map>* GetTable() { return &_table_map; } inline Table* GetTable(size_t table_id) { auto itr = _table_map.find(table_id); if (itr != _table_map.end()) { return itr->second.get(); } LOG(ERROR) << "table not found " << table_id; return NULL; } std::unordered_map> _table_map; bool _running = false; bool _flushing = false; private: float _mae = 0; float _mse = 0; uint16_t _push_times = 0; }; } // namespace distributed } // namespace paddle