ps_local_client.cc 10.8 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
T
Thunderbrook 已提交
17 18 19

namespace paddle {
namespace distributed {
Z
zhaocaibei123 已提交
20
int32_t PsLocalClient::Initialize() {
T
Thunderbrook 已提交
21
  const auto& downpour_param = _config.server_param().downpour_server_param();
Z
zhaocaibei123 已提交
22
  TableManager::Instance().Initialize();
Z
zhangchunle 已提交
23
  for (int i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
T
Thunderbrook 已提交
24 25
    auto* table = CREATE_PSCORE_CLASS(
        Table, downpour_param.downpour_table_param(i).table_class());
Z
zhaocaibei123 已提交
26 27
    table->SetShard(0, 1);
    table->Initialize(downpour_param.downpour_table_param(i),
T
Thunderbrook 已提交
28 29 30 31 32 33
                      _config.fs_client_param());
    _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
  }
  return 0;
}

Z
zhaocaibei123 已提交
34
// Shrink request for `table_id`. The local client does not touch any table
// here; it returns the future produced by done() right away.
// `threshold` is accepted for interface compatibility and ignored.
::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
                                             const std::string threshold) {
  ::std::future<int32_t> completed = done();
  return completed;
}

Z
zhaocaibei123 已提交
39
::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
T
Thunderbrook 已提交
40
                                           const std::string& mode) {
T
Thunderbrook 已提交
41
  for (auto& it : _table_map) {
Z
zhaocaibei123 已提交
42
    Load(it.first, epoch, mode);
T
Thunderbrook 已提交
43
  }
T
Thunderbrook 已提交
44 45
  return done();
}
Z
zhaocaibei123 已提交
46
::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
T
Thunderbrook 已提交
47 48
                                           const std::string& epoch,
                                           const std::string& mode) {
Z
zhaocaibei123 已提交
49 50
  auto* table_ptr = GetTable(table_id);
  table_ptr->Load(epoch, mode);
T
Thunderbrook 已提交
51 52 53
  return done();
}

Z
zhaocaibei123 已提交
54
// Saves every registered table for `epoch` using `mode`.
// Delegates to the per-table Save overload once per `_table_map` entry.
// The per-table futures are discarded; the returned future comes from done().
::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
                                           const std::string& mode) {
  for (const auto& entry : _table_map) {
    const auto registered_id = entry.first;
    Save(registered_id, epoch, mode);
  }
  return done();
}
Z
zhaocaibei123 已提交
61
::std::future<int32_t> PsLocalClient::Save(uint32_t table_id,
T
Thunderbrook 已提交
62 63
                                           const std::string& epoch,
                                           const std::string& mode) {
Z
zhaocaibei123 已提交
64 65 66
  auto* table_ptr = GetTable(table_id);
  table_ptr->Flush();
  table_ptr->Save(epoch, mode);
T
Thunderbrook 已提交
67 68 69
  return done();
}

70
// Clear-all request: no table is modified; completes via done().
::std::future<int32_t> PsLocalClient::Clear() {
  ::std::future<int32_t> completed = done();
  return completed;
}
Z
zhaocaibei123 已提交
71
::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
T
Thunderbrook 已提交
72 73 74
  return done();
}

Z
zhaocaibei123 已提交
75
::std::future<int32_t> PsLocalClient::Flush() {
T
Thunderbrook 已提交
76 77 78 79
  // no need
  return done();
}

Z
zhaocaibei123 已提交
80
::std::future<int32_t> PsLocalClient::StopServer() {
T
Thunderbrook 已提交
81 82 83 84
  // no need
  return done();
}

Z
zhaocaibei123 已提交
85 86 87 88 89
// Copies the dense values of table `table_id` into the caller's `regions`.
//
// The table's dense values (DenseDimPerShard with a shard count of 1) are
// pulled into a temporary float buffer. The buffer's bytes are then
// scattered over `regions[0 .. region_num)` in order:
//   - zero-sized regions are skipped;
//   - each region receives as many bytes as fit, starting at its beginning;
//   - copying stops when the buffer is exhausted or the regions run out.
// Unused region space is left untouched, and buffer bytes that do not fit
// are dropped.
::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);

  const uint32_t dense_dim =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1);
  std::vector<float> pulled(dense_dim);

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.pull_context.values = pulled.data();
  table_context.num = pulled.size();
  table_ptr->Pull(table_context);

  size_t bytes_left = static_cast<size_t>(dense_dim) * sizeof(float);
  PADDLE_ENFORCE_EQ(
      bytes_left,
      pulled.size() * sizeof(float),
      platform::errors::PreconditionNotMet("pull dense size error."));

  const uint8_t* src = reinterpret_cast<const uint8_t*>(pulled.data());
  for (size_t r = 0; bytes_left > 0 && r < region_num; ++r) {
    Region& dst = regions[r];
    if (dst.size == 0) {
      continue;  // An empty region cannot take any bytes.
    }
    // Fill this region as far as possible, but never beyond what is left.
    const size_t chunk = dst.size < bytes_left ? dst.size : bytes_left;
    memcpy(reinterpret_cast<void*>(dst.data), src, chunk);
    src += chunk;
    bytes_left -= chunk;
  }

  return done();
}

Z
zhaocaibei123 已提交
142 143 144 145 146
::std::future<int32_t> PsLocalClient::PushDenseParam(const Region* regions,
                                                     size_t region_num,
                                                     size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
147 148

  std::vector<float> region_buffer;
149 150
  region_buffer.resize(DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1),
                       0);
T
Thunderbrook 已提交
151 152 153 154 155 156
  for (size_t i = 0, offset = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
    offset += data_num;
  }

157 158 159 160 161 162 163
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = region_buffer.data();
  table_context.push_context.is_param = true;
  table_context.num = region_buffer.size();

  table_ptr->Push(table_context);
Z
zhaocaibei123 已提交
164
  // table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());
T
Thunderbrook 已提交
165 166 167 168

  return done();
}

Z
zhaocaibei123 已提交
169
::std::future<int32_t> PsLocalClient::PushDenseRawGradient(
170 171 172
    int table_id,
    float* total_send_data,
    size_t total_send_data_size,
T
Thunderbrook 已提交
173 174 175 176 177
    void* callback) {
  VLOG(1) << "wxx push_dense_raw_gradient";

  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);

Z
zhaocaibei123 已提交
178
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
179

180 181 182 183 184 185 186
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = total_send_data;
  table_context.num = total_send_data_size;
  //  table_ptr->PushDense(total_send_data, total_send_data_size);
  table_ptr->Push(table_context);

T
Thunderbrook 已提交
187 188 189 190
  delete closure;
  return done();
}

Z
zhaocaibei123 已提交
191 192 193 194 195
::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
196 197

  std::vector<float> region_buffer;
198 199
  region_buffer.resize(
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1));
T
Thunderbrook 已提交
200 201 202 203
  size_t data_size = region_buffer.size();
  for (size_t i = 0, offset = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    PADDLE_ENFORCE_LE(
204 205
        offset + data_num,
        data_size,
T
Thunderbrook 已提交
206
        platform::errors::PreconditionNotMet(
207 208 209 210
            "invalid dense size, cur pos[%d] data_num[%d] size[%d]",
            offset,
            data_num,
            data_size));
T
Thunderbrook 已提交
211 212 213 214
    memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
    offset += data_num;
  }

215 216 217 218 219 220
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = region_buffer.data();
  table_context.num = region_buffer.size();
  //  table_ptr->PushDense(total_send_data, total_send_data_size);
  table_ptr->Push(table_context);
T
Thunderbrook 已提交
221 222 223 224

  return done();
}

225 226 227 228 229 230 231 232
::std::future<int32_t> PsLocalClient::PullSparsePtr(
    int shard_id,
    char** select_values,
    size_t table_id,
    const uint64_t* keys,
    size_t num,
    uint16_t pass_id,
    const uint16_t& /**dim_id*/) {
T
Thunderbrook 已提交
233 234 235 236 237
  // FIXME
  // auto timer =
  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
  // auto local_timer =
  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
238
  // 将key拆分到各shard请求,并记录原始对应value指针
Z
zhaocaibei123 已提交
239
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
240

241 242 243 244 245 246
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.keys = keys;
  table_context.pull_context.ptr_values = select_values;
  table_context.use_ptr = true;
  table_context.num = num;
L
lxsbupt 已提交
247 248
  table_context.shard_id = shard_id;
  table_context.pass_id = pass_id;
249 250 251

  //  table_ptr->PullSparsePtr(select_values, keys, num);
  table_ptr->Pull(table_context);
T
Thunderbrook 已提交
252 253 254 255

  return done();
}

L
lxsbupt 已提交
256 257 258 259 260 261 262 263 264 265 266 267 268
// Logs (VLOG level 0) the pair returned by the table's PrintTableStat.
// The first element is reported as the feasign size and the second as the
// mf size.
::std::future<int32_t> PsLocalClient::PrintTableStat(uint32_t table_id) {
  const std::pair<int64_t, int64_t> stat =
      GetTable(table_id)->PrintTableStat();
  VLOG(0) << "table id: " << table_id << ", feasign size: " << stat.first
          << ", mf size: " << stat.second;
  return done();
}

::std::future<int32_t> PsLocalClient::SaveCacheTable(uint32_t table_id,
                                                     uint16_t pass_id,
                                                     size_t threshold) {
  auto* table_ptr = GetTable(table_id);
  std::pair<int64_t, int64_t> ret = table_ptr->PrintTableStat();
269
  VLOG(1) << "table id: " << table_id << ", feasign size: " << ret.first
L
lxsbupt 已提交
270 271
          << ", mf size: " << ret.second;
  if (ret.first > (int64_t)threshold) {
272
    VLOG(1) << "run cache table";
L
lxsbupt 已提交
273 274 275 276 277
    table_ptr->CacheTable(pass_id);
  }
  return done();
}

Z
zhaocaibei123 已提交
278
::std::future<int32_t> PsLocalClient::PushSparseRawGradient(
279 280 281 282 283
    size_t table_id,
    const uint64_t* keys,
    const float** update_values,
    size_t num,
    void* callback) {
T
Thunderbrook 已提交
284
  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
Z
zhaocaibei123 已提交
285
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
286

287 288 289 290 291 292 293 294 295
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.ptr_values = update_values;
  table_context.num = num;
  table_context.use_ptr = true;

  // table_ptr->PushSparse(keys, update_values, num);
  table_ptr->Push(table_context);
T
Thunderbrook 已提交
296 297 298 299
  delete closure;
  return done();
}

Z
zhaocaibei123 已提交
300 301 302 303 304
::std::future<int32_t> PsLocalClient::PushSparse(size_t table_id,
                                                 const uint64_t* keys,
                                                 const float** update_values,
                                                 size_t num) {
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
305

306 307 308 309 310 311 312 313 314
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.ptr_values = update_values;
  table_context.num = num;
  table_context.use_ptr = true;

  //  table_ptr->PushSparse(keys, update_values, num);
  table_ptr->Push(table_context);
T
Thunderbrook 已提交
315 316
  return done();
}
317 318
}  // namespace distributed
}  // namespace paddle