// ps_local_client.h
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cstdint>
#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"

namespace paddle {
namespace distributed {

class Table;

class PsLocalClient : public PSClient {
 public:
  PsLocalClient() {}
  virtual ~PsLocalClient() { _running = false; }
Z
zhaocaibei123 已提交
29 30 31
  virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms,
                                                int pslib_connect_timeout_ms,
                                                int max_retry) {
T
Thunderbrook 已提交
32 33 34
    return 0;
  }

35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
  virtual ::std::future<int32_t> Shrink(uint32_t table_id,
                                        const std::string threshold);
  virtual ::std::future<int32_t> Load(const std::string& epoch,
                                      const std::string& mode);
  virtual ::std::future<int32_t> Load(uint32_t table_id,
                                      const std::string& epoch,
                                      const std::string& mode);

  virtual ::std::future<int32_t> Save(const std::string& epoch,
                                      const std::string& mode);
  virtual ::std::future<int32_t> Save(uint32_t table_id,
                                      const std::string& epoch,
                                      const std::string& mode);

  virtual ::std::future<int32_t> Clear();
  virtual ::std::future<int32_t> Clear(uint32_t table_id);

  virtual ::std::future<int32_t> StopServer();

  virtual void FinalizeWorker() {}
55 56
  virtual ::std::future<int32_t> PullDense(Region* regions,
                                           size_t region_num,
Z
zhaocaibei123 已提交
57
                                           size_t table_id);
T
Thunderbrook 已提交
58

Z
zhaocaibei123 已提交
59
  virtual ::std::future<int32_t> PushDense(const Region* regions,
60 61
                                           size_t region_num,
                                           size_t table_id);
Y
yaoxuefeng 已提交
62

Z
zhaocaibei123 已提交
63 64 65
  virtual ::std::future<int32_t> PushDenseParam(const Region* regions,
                                                size_t region_num,
                                                size_t table_id);
Y
yaoxuefeng 已提交
66

Z
zhaocaibei123 已提交
67 68
  virtual ::std::future<int32_t> PullSparse(float** select_values,
                                            size_t table_id,
69 70
                                            const uint64_t* keys,
                                            size_t num,
Z
zhaocaibei123 已提交
71
                                            bool is_training) {
T
Thunderbrook 已提交
72 73 74 75 76 77 78
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

79
  virtual ::std::future<int32_t> PullSparsePtr(const int shard_id,
L
lxsbupt 已提交
80
                                               char** select_values,
Z
zhaocaibei123 已提交
81 82
                                               size_t table_id,
                                               const uint64_t* keys,
L
lxsbupt 已提交
83
                                               size_t num,
84 85
                                               uint16_t pass_id,
                                               const uint16_t& dim_id = 0);
T
Thunderbrook 已提交
86

L
lxsbupt 已提交
87 88 89 90 91
  virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id);

  virtual ::std::future<int32_t> SaveCacheTable(uint32_t table_id,
                                                uint16_t pass_id,
                                                size_t threshold);
T
Thunderbrook 已提交
92

Z
zhaocaibei123 已提交
93 94 95 96
  virtual ::std::future<int32_t> PushSparse(size_t table_id,
                                            const uint64_t* keys,
                                            const float** update_values,
                                            size_t num);
T
Thunderbrook 已提交
97

Z
zhaocaibei123 已提交
98
  virtual ::std::future<int32_t> Flush();
T
Thunderbrook 已提交
99
  // server profilera
Z
zhaocaibei123 已提交
100
  virtual std::future<int32_t> StartProfiler() {
T
Thunderbrook 已提交
101 102 103 104 105
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
106
  }
T
Thunderbrook 已提交
107

Z
zhaocaibei123 已提交
108
  virtual std::future<int32_t> StopProfiler() {
T
Thunderbrook 已提交
109 110 111 112 113 114 115
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
116
  virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type) {
T
Thunderbrook 已提交
117 118 119 120 121 122 123
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
124 125 126 127
  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float>* values,
                                            std::vector<uint64_t>* keys,
                                            int pserver_idx) {
T
Thunderbrook 已提交
128 129 130 131 132 133 134
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
135 136 137
  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t* total_send_data,
                                              void* done) {
T
Thunderbrook 已提交
138 139 140 141 142 143 144 145
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

  // recv table from server and save it in LodTensor
Z
zhaocaibei123 已提交
146 147
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string& path) {
T
Thunderbrook 已提交
148 149 150
    return 0;
  }

151 152 153
  virtual ::std::future<int32_t> SendClient2ClientMsg(int msg_type,
                                                      int to_client_id,
                                                      const std::string& msg) {
T
Thunderbrook 已提交
154 155 156 157 158 159
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }
Z
zhaocaibei123 已提交
160
  virtual size_t GetServerNums() { return 1; }
T
Thunderbrook 已提交
161

162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
  virtual std::future<int32_t> PushDenseRawGradient(int table_id,
                                                    float* total_send_data,
                                                    size_t total_send_data_size,
                                                    void* callback);

  virtual std::future<int32_t> PushSparseRawGradient(
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      size_t num,
      void* callback);

  virtual std::future<int32_t> PushSparseRawGradientPartial(
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      uint32_t num,
      void* done,
      int pserver_idx) {
T
Thunderbrook 已提交
181 182 183 184 185 186 187
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

188 189 190 191 192
  virtual std::future<int32_t> PushSparseParam(size_t table_id,
                                               const uint64_t* keys,
                                               const float** update_values,
                                               size_t num,
                                               void* done) {
T
Thunderbrook 已提交
193 194 195 196 197 198 199
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

200 201
 protected:
  virtual int32_t Initialize();
T
Thunderbrook 已提交
202 203 204 205 206 207 208 209 210

  std::future<int32_t> done() {
    std::shared_ptr<std::promise<int32_t>> prom =
        std::make_shared<std::promise<int32_t>>();
    std::future<int32_t> fut = prom->get_future();
    prom->set_value(0);
    return fut;
  }

Z
zhaocaibei123 已提交
211 212
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
T
Thunderbrook 已提交
213 214 215
    return dense_dim_total / shard_num + 1;
  }

Z
zhaocaibei123 已提交
216
  inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
T
Thunderbrook 已提交
217 218 219
    return &_table_map;
  }

Z
zhaocaibei123 已提交
220
  inline Table* GetTable(size_t table_id) {
T
Thunderbrook 已提交
221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
    }
    LOG(ERROR) << "table not found " << table_id;
    return NULL;
  }

  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;

  bool _running = false;
  bool _flushing = false;

 private:
  float _mae = 0;
  float _mse = 0;
  uint16_t _push_times = 0;
};
239 240
}  // namespace distributed
}  // namespace paddle