// ps_local_client.h
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cstddef>
#include <cstdint>
#include <future>
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"

namespace paddle {
namespace distributed {

class Table;

class PsLocalClient : public PSClient {
 public:
  PsLocalClient() {}
  virtual ~PsLocalClient() { _running = false; }
Z
zhaocaibei123 已提交
29 30 31
  virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms,
                                                int pslib_connect_timeout_ms,
                                                int max_retry) {
T
Thunderbrook 已提交
32 33 34
    return 0;
  }

35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
  ::std::future<int32_t> Shrink(uint32_t table_id,
                                const std::string threshold) override;
  ::std::future<int32_t> Load(const std::string& epoch,
                              const std::string& mode) override;
  ::std::future<int32_t> Load(uint32_t table_id,
                              const std::string& epoch,
                              const std::string& mode) override;

  ::std::future<int32_t> Save(const std::string& epoch,
                              const std::string& mode) override;
  ::std::future<int32_t> Save(uint32_t table_id,
                              const std::string& epoch,
                              const std::string& mode) override;

  ::std::future<int32_t> Clear() override;
  ::std::future<int32_t> Clear(uint32_t table_id) override;

  ::std::future<int32_t> StopServer() override;

  void FinalizeWorker() override {}
55 56
  virtual ::std::future<int32_t> PullDense(Region* regions,
                                           size_t region_num,
Z
zhaocaibei123 已提交
57
                                           size_t table_id);
T
Thunderbrook 已提交
58

Z
zhaocaibei123 已提交
59
  virtual ::std::future<int32_t> PushDense(const Region* regions,
60 61
                                           size_t region_num,
                                           size_t table_id);
Y
yaoxuefeng 已提交
62

Z
zhaocaibei123 已提交
63 64 65
  virtual ::std::future<int32_t> PushDenseParam(const Region* regions,
                                                size_t region_num,
                                                size_t table_id);
Y
yaoxuefeng 已提交
66

Z
zhaocaibei123 已提交
67 68
  virtual ::std::future<int32_t> PullSparse(float** select_values,
                                            size_t table_id,
69 70
                                            const uint64_t* keys,
                                            size_t num,
Z
zhaocaibei123 已提交
71
                                            bool is_training) {
T
Thunderbrook 已提交
72 73 74 75 76 77 78
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

L
lxsbupt 已提交
79 80
  virtual ::std::future<int32_t> PullSparsePtr(int shard_id,
                                               char** select_values,
Z
zhaocaibei123 已提交
81 82
                                               size_t table_id,
                                               const uint64_t* keys,
L
lxsbupt 已提交
83 84
                                               size_t num,
                                               uint16_t pass_id);
T
Thunderbrook 已提交
85

L
lxsbupt 已提交
86 87 88 89 90
  virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id);

  virtual ::std::future<int32_t> SaveCacheTable(uint32_t table_id,
                                                uint16_t pass_id,
                                                size_t threshold);
T
Thunderbrook 已提交
91

Z
zhaocaibei123 已提交
92 93 94 95
  virtual ::std::future<int32_t> PushSparse(size_t table_id,
                                            const uint64_t* keys,
                                            const float** update_values,
                                            size_t num);
T
Thunderbrook 已提交
96

Z
zhaocaibei123 已提交
97
  virtual ::std::future<int32_t> Flush();
T
Thunderbrook 已提交
98
  // server profilera
Z
zhaocaibei123 已提交
99
  virtual std::future<int32_t> StartProfiler() {
T
Thunderbrook 已提交
100 101 102 103 104
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
105
  }
T
Thunderbrook 已提交
106

Z
zhaocaibei123 已提交
107
  virtual std::future<int32_t> StopProfiler() {
T
Thunderbrook 已提交
108 109 110 111 112 113 114
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
115
  virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type) {
T
Thunderbrook 已提交
116 117 118 119 120 121 122
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
123 124 125 126
  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float>* values,
                                            std::vector<uint64_t>* keys,
                                            int pserver_idx) {
T
Thunderbrook 已提交
127 128 129 130 131 132 133
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
134 135 136
  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t* total_send_data,
                                              void* done) {
T
Thunderbrook 已提交
137 138 139 140 141 142 143 144
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

  // recv table from server and save it in LodTensor
Z
zhaocaibei123 已提交
145 146
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string& path) {
T
Thunderbrook 已提交
147 148 149
    return 0;
  }

150 151 152
  ::std::future<int32_t> SendClient2ClientMsg(int msg_type,
                                              int to_client_id,
                                              const std::string& msg) override {
T
Thunderbrook 已提交
153 154 155 156 157 158
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }
Z
zhaocaibei123 已提交
159
  virtual size_t GetServerNums() { return 1; }
T
Thunderbrook 已提交
160

161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177
  std::future<int32_t> PushDenseRawGradient(int table_id,
                                            float* total_send_data,
                                            size_t total_send_data_size,
                                            void* callback) override;

  std::future<int32_t> PushSparseRawGradient(size_t table_id,
                                             const uint64_t* keys,
                                             const float** update_values,
                                             size_t num,
                                             void* callback) override;

  std::future<int32_t> PushSparseRawGradientPartial(size_t table_id,
                                                    const uint64_t* keys,
                                                    const float** update_values,
                                                    uint32_t num,
                                                    void* done,
                                                    int pserver_idx) override {
T
Thunderbrook 已提交
178 179 180 181 182 183 184
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

185 186 187 188 189
  std::future<int32_t> PushSparseParam(size_t table_id,
                                       const uint64_t* keys,
                                       const float** update_values,
                                       size_t num,
                                       void* done) override {
T
Thunderbrook 已提交
190 191 192 193 194 195 196 197
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

 private:
198
  int32_t Initialize() override;
T
Thunderbrook 已提交
199 200 201 202 203 204 205 206 207

  std::future<int32_t> done() {
    std::shared_ptr<std::promise<int32_t>> prom =
        std::make_shared<std::promise<int32_t>>();
    std::future<int32_t> fut = prom->get_future();
    prom->set_value(0);
    return fut;
  }

Z
zhaocaibei123 已提交
208 209
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
T
Thunderbrook 已提交
210 211 212
    return dense_dim_total / shard_num + 1;
  }

Z
zhaocaibei123 已提交
213
  inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
T
Thunderbrook 已提交
214 215 216
    return &_table_map;
  }

Z
zhaocaibei123 已提交
217
  inline Table* GetTable(size_t table_id) {
T
Thunderbrook 已提交
218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
    }
    LOG(ERROR) << "table not found " << table_id;
    return NULL;
  }

  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;

  bool _running = false;
  bool _flushing = false;

 private:
  float _mae = 0;
  float _mse = 0;
  uint16_t _push_times = 0;
};
236 237
}  // namespace distributed
}  // namespace paddle