ps_local_client.h 7.7 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
18
#include "paddle/fluid/distributed/ps/service/ps_client.h"
T
Thunderbrook 已提交
19 20 21 22 23 24 25 26 27 28

namespace paddle {
namespace distributed {

class Table;

class PsLocalClient : public PSClient {
 public:
  PsLocalClient() {}
  virtual ~PsLocalClient() { _running = false; }
Z
zhaocaibei123 已提交
29 30 31
  virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms,
                                                int pslib_connect_timeout_ms,
                                                int max_retry) {
T
Thunderbrook 已提交
32 33 34
    return 0;
  }

Z
zhaocaibei123 已提交
35
  virtual ::std::future<int32_t> Shrink(uint32_t table_id,
T
Thunderbrook 已提交
36
                                        const std::string threshold) override;
Z
zhaocaibei123 已提交
37
  virtual ::std::future<int32_t> Load(const std::string& epoch,
T
Thunderbrook 已提交
38
                                      const std::string& mode) override;
Z
zhaocaibei123 已提交
39
  virtual ::std::future<int32_t> Load(uint32_t table_id,
T
Thunderbrook 已提交
40 41 42
                                      const std::string& epoch,
                                      const std::string& mode) override;

Z
zhaocaibei123 已提交
43
  virtual ::std::future<int32_t> Save(const std::string& epoch,
T
Thunderbrook 已提交
44
                                      const std::string& mode) override;
Z
zhaocaibei123 已提交
45
  virtual ::std::future<int32_t> Save(uint32_t table_id,
T
Thunderbrook 已提交
46 47 48
                                      const std::string& epoch,
                                      const std::string& mode) override;

Z
zhaocaibei123 已提交
49 50
  virtual ::std::future<int32_t> Clear() override;
  virtual ::std::future<int32_t> Clear(uint32_t table_id) override;
T
Thunderbrook 已提交
51

Z
zhaocaibei123 已提交
52
  virtual ::std::future<int32_t> StopServer() override;
T
Thunderbrook 已提交
53

Z
zhaocaibei123 已提交
54 55 56
  virtual void FinalizeWorker() override {}
  virtual ::std::future<int32_t> PullDense(Region* regions, size_t region_num,
                                           size_t table_id);
T
Thunderbrook 已提交
57

Z
zhaocaibei123 已提交
58 59
  virtual ::std::future<int32_t> PushDense(const Region* regions,
                                           size_t region_num, size_t table_id);
Y
yaoxuefeng 已提交
60

Z
zhaocaibei123 已提交
61 62 63
  virtual ::std::future<int32_t> PushDenseParam(const Region* regions,
                                                size_t region_num,
                                                size_t table_id);
Y
yaoxuefeng 已提交
64

Z
zhaocaibei123 已提交
65 66 67 68
  virtual ::std::future<int32_t> PullSparse(float** select_values,
                                            size_t table_id,
                                            const uint64_t* keys, size_t num,
                                            bool is_training) {
T
Thunderbrook 已提交
69 70 71 72 73 74 75
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
76 77 78 79
  virtual ::std::future<int32_t> PullSparsePtr(char** select_values,
                                               size_t table_id,
                                               const uint64_t* keys,
                                               size_t num);
T
Thunderbrook 已提交
80

Z
zhaocaibei123 已提交
81
  virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id) {
T
Thunderbrook 已提交
82 83 84 85 86 87
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }
Z
zhaocaibei123 已提交
88 89 90 91
  virtual ::std::future<int32_t> PushSparse(size_t table_id,
                                            const uint64_t* keys,
                                            const float** update_values,
                                            size_t num);
T
Thunderbrook 已提交
92

Z
zhaocaibei123 已提交
93
  virtual ::std::future<int32_t> Flush();
T
Thunderbrook 已提交
94
  // server profilera
Z
zhaocaibei123 已提交
95
  virtual std::future<int32_t> StartProfiler() {
T
Thunderbrook 已提交
96 97 98 99 100 101 102
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  };

Z
zhaocaibei123 已提交
103
  virtual std::future<int32_t> StopProfiler() {
T
Thunderbrook 已提交
104 105 106 107 108 109 110
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
111
  virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type) {
T
Thunderbrook 已提交
112 113 114 115 116 117 118
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
119 120 121 122
  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float>* values,
                                            std::vector<uint64_t>* keys,
                                            int pserver_idx) {
T
Thunderbrook 已提交
123 124 125 126 127 128 129
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
130 131 132
  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t* total_send_data,
                                              void* done) {
T
Thunderbrook 已提交
133 134 135 136 137 138 139 140
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

  // recv table from server and save it in LodTensor
Z
zhaocaibei123 已提交
141 142
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string& path) {
T
Thunderbrook 已提交
143 144 145
    return 0;
  }

Z
zhaocaibei123 已提交
146
  virtual ::std::future<int32_t> SendClient2ClientMsg(
T
Thunderbrook 已提交
147 148 149 150 151 152 153
      int msg_type, int to_client_id, const std::string& msg) override {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }
Z
zhaocaibei123 已提交
154
  virtual size_t GetServerNums() { return 1; }
T
Thunderbrook 已提交
155

Z
zhaocaibei123 已提交
156 157 158 159
  virtual std::future<int32_t> PushDenseRawGradient(int table_id,
                                                    float* total_send_data,
                                                    size_t total_send_data_size,
                                                    void* callback) override;
T
Thunderbrook 已提交
160

Z
zhaocaibei123 已提交
161
  virtual std::future<int32_t> PushSparseRawGradient(
T
Thunderbrook 已提交
162 163 164
      size_t table_id, const uint64_t* keys, const float** update_values,
      size_t num, void* callback) override;

Z
zhaocaibei123 已提交
165
  virtual std::future<int32_t> PushSparseRawGradientPartial(
T
Thunderbrook 已提交
166 167 168 169 170 171 172 173 174
      size_t table_id, const uint64_t* keys, const float** update_values,
      uint32_t num, void* done, int pserver_idx) override {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

Z
zhaocaibei123 已提交
175 176 177 178 179
  virtual std::future<int32_t> PushSparseParam(size_t table_id,
                                               const uint64_t* keys,
                                               const float** update_values,
                                               size_t num,
                                               void* done) override {
T
Thunderbrook 已提交
180 181 182 183 184 185 186 187
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);

    return fut;
  }

 private:
Z
zhaocaibei123 已提交
188
  virtual int32_t Initialize() override;
T
Thunderbrook 已提交
189 190 191 192 193 194 195 196 197

  std::future<int32_t> done() {
    std::shared_ptr<std::promise<int32_t>> prom =
        std::make_shared<std::promise<int32_t>>();
    std::future<int32_t> fut = prom->get_future();
    prom->set_value(0);
    return fut;
  }

Z
zhaocaibei123 已提交
198 199
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
T
Thunderbrook 已提交
200 201 202
    return dense_dim_total / shard_num + 1;
  }

Z
zhaocaibei123 已提交
203
  inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
T
Thunderbrook 已提交
204 205 206
    return &_table_map;
  }

Z
zhaocaibei123 已提交
207
  inline Table* GetTable(size_t table_id) {
T
Thunderbrook 已提交
208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
    }
    LOG(ERROR) << "table not found " << table_id;
    return NULL;
  }

  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;

  bool _running = false;
  bool _flushing = false;

 private:
  float _mae = 0;
  float _mse = 0;
  uint16_t _push_times = 0;
};
}
}