ps_local_client.h 7.9 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cstdint>
#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
T
Thunderbrook 已提交
19 20 21 22 23 24 25 26 27 28

namespace paddle {
namespace distributed {

class Table;

class PsLocalClient : public PSClient {
 public:
  PsLocalClient() {}
  virtual ~PsLocalClient() { _running = false; }
Z
zhaocaibei123 已提交
29 30 31
  virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms,
                                                int pslib_connect_timeout_ms,
                                                int max_retry) {
T
Thunderbrook 已提交
32 33 34
    return 0;
  }

Z
zhaocaibei123 已提交
35
  virtual ::std::future<int32_t> Shrink(uint32_t table_id,
T
Thunderbrook 已提交
36
                                        const std::string threshold) override;
Z
zhaocaibei123 已提交
37
  virtual ::std::future<int32_t> Load(const std::string& epoch,
T
Thunderbrook 已提交
38
                                      const std::string& mode) override;
Z
zhaocaibei123 已提交
39
  virtual ::std::future<int32_t> Load(uint32_t table_id,
T
Thunderbrook 已提交
40 41 42
                                      const std::string& epoch,
                                      const std::string& mode) override;

Z
zhaocaibei123 已提交
43
  virtual ::std::future<int32_t> Save(const std::string& epoch,
T
Thunderbrook 已提交
44
                                      const std::string& mode) override;
Z
zhaocaibei123 已提交
45
  virtual ::std::future<int32_t> Save(uint32_t table_id,
T
Thunderbrook 已提交
46 47 48
                                      const std::string& epoch,
                                      const std::string& mode) override;

Z
zhaocaibei123 已提交
49 50
  virtual ::std::future<int32_t> Clear() override;
  virtual ::std::future<int32_t> Clear(uint32_t table_id) override;
T
Thunderbrook 已提交
51

Z
zhaocaibei123 已提交
52
  virtual ::std::future<int32_t> StopServer() override;
T
Thunderbrook 已提交
53

Z
zhaocaibei123 已提交
54
  virtual void FinalizeWorker() override {}
55 56
  virtual ::std::future<int32_t> PullDense(Region* regions,
                                           size_t region_num,
Z
zhaocaibei123 已提交
57
                                           size_t table_id);
T
Thunderbrook 已提交
58

Z
zhaocaibei123 已提交
59
  virtual ::std::future<int32_t> PushDense(const Region* regions,
60 61
                                           size_t region_num,
                                           size_t table_id);
Y
yaoxuefeng 已提交
62

Z
zhaocaibei123 已提交
63 64 65
  virtual ::std::future<int32_t> PushDenseParam(const Region* regions,
                                                size_t region_num,
                                                size_t table_id);
Y
yaoxuefeng 已提交
66

Z
zhaocaibei123 已提交
67 68
  virtual ::std::future<int32_t> PullSparse(float** select_values,
                                            size_t table_id,
69 70
                                            const uint64_t* keys,
                                            size_t num,
Z
zhaocaibei123 已提交
71
                                            bool is_training) {
T
Thunderbrook 已提交
72 73 74 75 76 77 78
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
79 80 81 82
  virtual ::std::future<int32_t> PullSparsePtr(char** select_values,
                                               size_t table_id,
                                               const uint64_t* keys,
                                               size_t num);
T
Thunderbrook 已提交
83

Z
zhaocaibei123 已提交
84
  virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id) {
T
Thunderbrook 已提交
85 86 87 88 89 90
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }
Z
zhaocaibei123 已提交
91 92 93 94
  virtual ::std::future<int32_t> PushSparse(size_t table_id,
                                            const uint64_t* keys,
                                            const float** update_values,
                                            size_t num);
T
Thunderbrook 已提交
95

Z
zhaocaibei123 已提交
96
  virtual ::std::future<int32_t> Flush();
T
Thunderbrook 已提交
97
  // server profilera
Z
zhaocaibei123 已提交
98
  virtual std::future<int32_t> StartProfiler() {
T
Thunderbrook 已提交
99 100 101 102 103 104 105
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  };

Z
zhaocaibei123 已提交
106
  virtual std::future<int32_t> StopProfiler() {
T
Thunderbrook 已提交
107 108 109 110 111 112 113
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
114
  virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type) {
T
Thunderbrook 已提交
115 116 117 118 119 120 121
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
122 123 124 125
  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float>* values,
                                            std::vector<uint64_t>* keys,
                                            int pserver_idx) {
T
Thunderbrook 已提交
126 127 128 129 130 131 132
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
133 134 135
  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t* total_send_data,
                                              void* done) {
T
Thunderbrook 已提交
136 137 138 139 140 141 142 143
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

  // recv table from server and save it in LodTensor
Z
zhaocaibei123 已提交
144 145
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string& path) {
T
Thunderbrook 已提交
146 147 148
    return 0;
  }

Z
zhaocaibei123 已提交
149
  virtual ::std::future<int32_t> SendClient2ClientMsg(
T
Thunderbrook 已提交
150 151 152 153 154 155 156
      int msg_type, int to_client_id, const std::string& msg) override {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }
Z
zhaocaibei123 已提交
157
  virtual size_t GetServerNums() { return 1; }
T
Thunderbrook 已提交
158

Z
zhaocaibei123 已提交
159 160 161 162
  virtual std::future<int32_t> PushDenseRawGradient(int table_id,
                                                    float* total_send_data,
                                                    size_t total_send_data_size,
                                                    void* callback) override;
T
Thunderbrook 已提交
163

Z
zhaocaibei123 已提交
164
  virtual std::future<int32_t> PushSparseRawGradient(
165 166 167 168 169
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      size_t num,
      void* callback) override;
T
Thunderbrook 已提交
170

Z
zhaocaibei123 已提交
171
  virtual std::future<int32_t> PushSparseRawGradientPartial(
172 173 174 175 176 177
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      uint32_t num,
      void* done,
      int pserver_idx) override {
T
Thunderbrook 已提交
178 179 180 181 182 183 184
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
185 186 187 188 189
  virtual std::future<int32_t> PushSparseParam(size_t table_id,
                                               const uint64_t* keys,
                                               const float** update_values,
                                               size_t num,
                                               void* done) override {
T
Thunderbrook 已提交
190 191 192 193 194 195 196 197
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

 private:
Z
zhaocaibei123 已提交
198
  virtual int32_t Initialize() override;
T
Thunderbrook 已提交
199 200 201 202 203 204 205 206 207

  std::future<int32_t> done() {
    std::shared_ptr<std::promise<int32_t>> prom =
        std::make_shared<std::promise<int32_t>>();
    std::future<int32_t> fut = prom->get_future();
    prom->set_value(0);
    return fut;
  }

Z
zhaocaibei123 已提交
208 209
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
T
Thunderbrook 已提交
210 211 212
    return dense_dim_total / shard_num + 1;
  }

Z
zhaocaibei123 已提交
213
  inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
T
Thunderbrook 已提交
214 215 216
    return &_table_map;
  }

Z
zhaocaibei123 已提交
217
  inline Table* GetTable(size_t table_id) {
T
Thunderbrook 已提交
218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
    }
    LOG(ERROR) << "table not found " << table_id;
    return NULL;
  }

  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;

  bool _running = false;
  bool _flushing = false;

 private:
  float _mae = 0;
  float _mse = 0;
  uint16_t _push_times = 0;
};
236 237
}  // namespace distributed
}  // namespace paddle