// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include "paddle/fluid/distributed/ps.pb.h" #include "paddle/fluid/distributed/service/env.h" #include "paddle/fluid/distributed/service/sendrecv.pb.h" #include "paddle/fluid/distributed/table/accessor.h" namespace paddle { namespace distributed { class PSEnvironment; class PsRequestMessage; class PsResponseMessage; class ValueAccessor; struct Region; using paddle::distributed::PsRequestMessage; using paddle::distributed::PsResponseMessage; typedef std::function PSClientCallBack; class PSClientClosure : public google::protobuf::Closure { public: PSClientClosure(PSClientCallBack callback) : _callback(callback) {} virtual ~PSClientClosure() {} virtual void set_promise_value(int value) { for (auto &promise : _promises) { promise->set_value(value); } } void add_promise(std::shared_ptr> &promise) { _promises.push_back(promise); } protected: PSClientCallBack _callback; std::vector>> _promises; }; class PSClient { public: PSClient() {} virtual ~PSClient() {} PSClient(PSClient &&) = delete; PSClient(const PSClient &) = delete; virtual int32_t configure( const PSParameter &config, const std::map> ®ions, PSEnvironment &_env, size_t client_id) final; virtual int32_t create_client2client_connection( int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) = 0; // 触发table数据退场 virtual std::future shrink(uint32_t table_id) = 0; // 全量table进行数据load virtual std::future load(const std::string &epoch, const std::string &mode) = 0; // 指定table数据load virtual std::future load(uint32_t table_id, const std::string &epoch, const std::string &mode) = 0; // 全量table数据save value_accessor根据mode,可能有不同的save条件 virtual std::future save(const std::string &epoch, const std::string &mode) = 0; // 指定table数据save value_accessor根据mode,可能有不同的save条件 virtual std::future save(uint32_t table_id, const std::string &epoch, const std::string &mode) = 0; //清空table数据 virtual std::future clear() = 0; virtual std::future clear(uint32_t table_id) = 0; // pull dense的参数部分,并分块填充到本地网络参数中 // start和num用于拉取部分参数 // future结束前keys和values缓冲区不能再次使用 // client将values按照区块拆包后送交多个sender // sender聚集同一区块的请求,累计多个填充buffer // server将参数区块中配置的某一维提取返回 // 返回数据解包后填充到累计的多个buffer中 virtual std::future pull_dense(Region *regions, size_t region_num, size_t table_id) = 0; //保留 // firstly push dense param for parameter server // this is neccessary because dense weight initialized in trainer on cold // start virtual std::future push_dense_param(const Region *regions, size_t region_num, size_t table_id) = 0; // 使用keys进行pull请求,结果填充values // keys和values的个数均为num个,每个value占用select_size空间 // future结束前keys和values缓冲区不能再次使用 // 整合多个线程请求的keys,聚集并分散发送到server // 返回结果后,遍历buffer并对values赋值 virtual std::future pull_sparse(float **select_values, size_t table_id, const uint64_t *keys, size_t num) = 0; virtual std::future print_table_stat(uint32_t table_id) = 0; // 确保所有积攒中的请求都发起发送 virtual std::future flush() = 0; // server优雅退出 virtual std::future stop_server() = 0; // server profilera virtual std::future start_profiler() = 0; virtual std::future stop_profiler() = 0; virtual std::future barrier(size_t table_id, uint32_t barrier_type) = 0; virtual std::future pull_geo_param(size_t table_id, std::vector *values, std::vector *keys, int pserver_idx) = 0; virtual std::future push_global_step(int table_id, int64_t *total_send_data, void *done) = 0; // recv table from server and save it in LodTensor virtual int32_t recv_and_save_table(const uint64_t table_id, const std::string &path) = 0; virtual void finalize_worker() = 0; // client to client, 消息发送 virtual std::future send_client2client_msg(int msg_type, int to_client_id, const std::string &msg) { LOG(FATAL) << "Did not implement"; std::promise promise; std::future fut = promise.get_future(); promise.set_value(-1); return fut; } // client2client消息处理,std::function ret (msg_type, from_client_id, msg) typedef std::function MsgHandlerFunc; virtual int registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler) { _msg_handler_map[msg_type] = handler; return 0; } virtual int handle_client2client_msg(int msg_type, int from_client_id, const std::string &msg) { auto itr = _msg_handler_map.find(msg_type); if (itr == _msg_handler_map.end()) { LOG(WARNING) << "unknown client2client_msg type:" << msg_type; return -1; } return itr->second(msg_type, from_client_id, msg); } virtual ValueAccessor *table_accessor(size_t table_id) { auto itr = _table_accessors.find(table_id); if (itr == _table_accessors.end()) { return NULL; } return itr->second.get(); } virtual size_t get_server_nums() = 0; virtual std::future push_dense_raw_gradient( int table_id, float *total_send_data, size_t total_send_data_size, void *done) = 0; virtual std::future push_sparse_raw_gradient( size_t table_id, const uint64_t *keys, const float **update_values, size_t num, void *done) = 0; virtual std::future push_sparse_raw_gradient_partial( size_t table_id, const uint64_t *keys, const float **update_values, uint32_t num, void *done, int pserver_idx) = 0; virtual std::future push_sparse_param(size_t table_id, const uint64_t *keys, const float **update_values, size_t num, void *done) = 0; protected: virtual int32_t initialize() = 0; size_t _client_id; PSParameter _config; std::map> _dense_pull_regions; PSEnvironment *_env; std::unordered_map> _table_accessors; std::unordered_map _msg_handler_map; //处理client2client消息 }; REGISTER_PSCORE_REGISTERER(PSClient); class PSClientFactory { public: static PSClient *create(const PSParameter &config); }; } // namespace distributed } // namespace paddle