ps_local_client.cc 12.3 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
16

17
#include "paddle/fluid/distributed/ps/table/table.h"
T
Thunderbrook 已提交
18

19
// #define pslib_debug_dense_compress
T
Thunderbrook 已提交
20 21 22

namespace paddle {
namespace distributed {
Z
zhaocaibei123 已提交
23
int32_t PsLocalClient::Initialize() {
T
Thunderbrook 已提交
24
  const auto& downpour_param = _config.server_param().downpour_server_param();
Z
zhaocaibei123 已提交
25
  TableManager::Instance().Initialize();
Z
zhangchunle 已提交
26
  for (int i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
T
Thunderbrook 已提交
27 28
    auto* table = CREATE_PSCORE_CLASS(
        Table, downpour_param.downpour_table_param(i).table_class());
Z
zhaocaibei123 已提交
29 30
    table->SetShard(0, 1);
    table->Initialize(downpour_param.downpour_table_param(i),
T
Thunderbrook 已提交
31 32 33 34 35 36
                      _config.fs_client_param());
    _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
  }
  return 0;
}

Z
zhaocaibei123 已提交
37
::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
T
Thunderbrook 已提交
38
                                             const std::string threshold) {
39
  // TODO  // NOLINT
T
Thunderbrook 已提交
40 41 42
  return done();
}

Z
zhaocaibei123 已提交
43
::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
T
Thunderbrook 已提交
44
                                           const std::string& mode) {
45
  // TODO  // NOLINT
T
Thunderbrook 已提交
46
  for (auto& it : _table_map) {
Z
zhaocaibei123 已提交
47
    Load(it.first, epoch, mode);
T
Thunderbrook 已提交
48
  }
T
Thunderbrook 已提交
49 50
  return done();
}
Z
zhaocaibei123 已提交
51
::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
T
Thunderbrook 已提交
52 53
                                           const std::string& epoch,
                                           const std::string& mode) {
54
  // TODO  // NOLINT
Z
zhaocaibei123 已提交
55 56
  auto* table_ptr = GetTable(table_id);
  table_ptr->Load(epoch, mode);
T
Thunderbrook 已提交
57 58 59
  return done();
}

Z
zhaocaibei123 已提交
60
::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
T
Thunderbrook 已提交
61
                                           const std::string& mode) {
62
  // TODO  // NOLINT
T
Thunderbrook 已提交
63
  for (auto& it : _table_map) {
Z
zhaocaibei123 已提交
64
    Save(it.first, epoch, mode);
T
Thunderbrook 已提交
65 66 67
  }
  return done();
}
Z
zhaocaibei123 已提交
68
::std::future<int32_t> PsLocalClient::Save(uint32_t table_id,
T
Thunderbrook 已提交
69 70
                                           const std::string& epoch,
                                           const std::string& mode) {
71
  // TODO  // NOLINT
Z
zhaocaibei123 已提交
72 73 74
  auto* table_ptr = GetTable(table_id);
  table_ptr->Flush();
  table_ptr->Save(epoch, mode);
T
Thunderbrook 已提交
75 76 77
  return done();
}

Z
zhaocaibei123 已提交
78
::std::future<int32_t> PsLocalClient::Clear() {
79
  // TODO  // NOLINT
T
Thunderbrook 已提交
80 81
  return done();
}
Z
zhaocaibei123 已提交
82
::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
83
  // TODO  // NOLINT
T
Thunderbrook 已提交
84 85 86
  return done();
}

Z
zhaocaibei123 已提交
87
::std::future<int32_t> PsLocalClient::Flush() {
T
Thunderbrook 已提交
88 89 90 91
  // no need
  return done();
}

Z
zhaocaibei123 已提交
92
::std::future<int32_t> PsLocalClient::StopServer() {
T
Thunderbrook 已提交
93 94 95 96
  // no need
  return done();
}

Z
zhaocaibei123 已提交
97 98 99 100 101
// Pulls the dense values of table `table_id` into a temporary float buffer
// and copies the buffer's bytes, in order, into the caller's regions.
//
//   regions    - destination byte ranges (`data` pointer plus `size` in bytes)
//   region_num - number of entries in `regions`
//   table_id   - table to pull from
//
// Copying stops as soon as the buffer is exhausted or the regions run out,
// whichever comes first. Returns done().
::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);

  const uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1);
  std::vector<float> pulled_values(num_per_shard);

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.pull_context.values = pulled_values.data();
  table_context.num = pulled_values.size();
  table_ptr->Pull(table_context);

  // Number of buffer bytes that still have to be handed out to regions.
  size_t bytes_left = static_cast<size_t>(num_per_shard) * sizeof(float);
  PADDLE_ENFORCE_EQ(
      bytes_left,
      pulled_values.size() * sizeof(float),
      platform::errors::PreconditionNotMet("pull dense size error."));

  const auto* src_bytes =
      reinterpret_cast<const uint8_t*>(pulled_values.data());
  size_t src_offset = 0;     // read position inside the pulled buffer (bytes)
  size_t region_idx = 0;     // region currently being filled
  size_t region_offset = 0;  // write position inside that region (bytes)

  while (bytes_left > 0 && region_idx < region_num) {
    auto& region = regions[region_idx];
    const auto room = region.size - region_offset;

    if (room == 0) {
      // Nothing can be written to this region; move on to the next one.
      ++region_idx;
      region_offset = 0;
      continue;
    }

    const bool fits_in_region = room >= bytes_left;
    const size_t chunk = fits_in_region ? bytes_left : room;
    memcpy(region.data + region_offset, src_bytes + src_offset, chunk);

    if (fits_in_region) {
      // The rest of the buffer fit into this region; the loop ends here.
      region_offset += bytes_left;
      bytes_left = 0;
    } else {
      // Region filled up; continue with the next region.
      bytes_left -= room;
      src_offset += room;
      ++region_idx;
      region_offset = 0;
    }
  }

  return done();
}

Z
zhaocaibei123 已提交
154 155 156 157 158
::std::future<int32_t> PsLocalClient::PushDenseParam(const Region* regions,
                                                     size_t region_num,
                                                     size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
159 160

  std::vector<float> region_buffer;
161 162
  region_buffer.resize(DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1),
                       0);
T
Thunderbrook 已提交
163 164 165 166 167 168
  for (size_t i = 0, offset = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
    offset += data_num;
  }

169 170 171 172 173 174 175
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = region_buffer.data();
  table_context.push_context.is_param = true;
  table_context.num = region_buffer.size();

  table_ptr->Push(table_context);
Z
zhaocaibei123 已提交
176
  // table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());
T
Thunderbrook 已提交
177 178 179 180

  return done();
}

Z
zhaocaibei123 已提交
181
::std::future<int32_t> PsLocalClient::PushDenseRawGradient(
182 183 184
    int table_id,
    float* total_send_data,
    size_t total_send_data_size,
T
Thunderbrook 已提交
185 186 187 188 189
    void* callback) {
  VLOG(1) << "wxx push_dense_raw_gradient";

  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);

Z
zhaocaibei123 已提交
190
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
191

192 193 194 195 196 197 198
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = total_send_data;
  table_context.num = total_send_data_size;
  //  table_ptr->PushDense(total_send_data, total_send_data_size);
  table_ptr->Push(table_context);

T
Thunderbrook 已提交
199 200 201 202
  delete closure;
  return done();
}

Z
zhaocaibei123 已提交
203 204 205 206 207
::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
208 209

  std::vector<float> region_buffer;
210 211
  region_buffer.resize(
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1));
T
Thunderbrook 已提交
212 213 214 215
  size_t data_size = region_buffer.size();
  for (size_t i = 0, offset = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    PADDLE_ENFORCE_LE(
216 217
        offset + data_num,
        data_size,
T
Thunderbrook 已提交
218
        platform::errors::PreconditionNotMet(
219 220 221 222
            "invalid dense size, cur pos[%d] data_num[%d] size[%d]",
            offset,
            data_num,
            data_size));
T
Thunderbrook 已提交
223 224 225 226
    memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
    offset += data_num;
  }

227 228 229 230 231 232
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = region_buffer.data();
  table_context.num = region_buffer.size();
  //  table_ptr->PushDense(total_send_data, total_send_data_size);
  table_ptr->Push(table_context);
T
Thunderbrook 已提交
233 234 235 236

  return done();
}

237
// ::std::future<int32_t> PsLocalClient::PullSparse(float** select_values,
T
Thunderbrook 已提交
238 239 240 241 242 243 244 245 246
//                                                  size_t table_id,
//                                                  const uint64_t* keys,
//                                                  size_t num) {
//  // FIXME
//  // auto timer =
//  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
//  // auto local_timer =
//  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
//  // Split the keys into per-shard requests and record the original value
//  // pointer that corresponds to each key.
Z
zhaocaibei123 已提交
247 248
//  auto* accessor = GetTableAccessor(table_id);
//  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
249 250
//  size_t value_size = accessor->select_size();
//
Z
zhaocaibei123 已提交
251
//  // table_ptr->PullSparse(keys, num);
T
Thunderbrook 已提交
252 253
//  std::vector<float> res_data;
//  res_data.resize(num * value_size / sizeof(float));
Z
zhaocaibei123 已提交
254
//  table_ptr->PullSparse(res_data.data(), keys, num);
T
Thunderbrook 已提交
255 256 257 258 259 260 261 262 263 264 265 266
//  // memcpy(select_values[0], res_data->data(), res_data->size() *
//  // sizeof(float));
//  size_t offset = 0;
//  for (int i = 0; i < num; ++i) {
//    memcpy(select_values[i], (char*)res_data.data() + offset, value_size);
//    offset += value_size;
//  }
//
//  // return fut;
//  return done();
//}

L
lxsbupt 已提交
267 268
::std::future<int32_t> PsLocalClient::PullSparsePtr(int shard_id,
                                                    char** select_values,
Z
zhaocaibei123 已提交
269 270
                                                    size_t table_id,
                                                    const uint64_t* keys,
L
lxsbupt 已提交
271 272
                                                    size_t num,
                                                    uint16_t pass_id) {
T
Thunderbrook 已提交
273 274 275 276 277
  // FIXME
  // auto timer =
  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
  // auto local_timer =
  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
278
  // 将key拆分到各shard请求,并记录原始对应value指针
Z
zhaocaibei123 已提交
279
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
280

281 282 283 284 285 286
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.keys = keys;
  table_context.pull_context.ptr_values = select_values;
  table_context.use_ptr = true;
  table_context.num = num;
L
lxsbupt 已提交
287 288
  table_context.shard_id = shard_id;
  table_context.pass_id = pass_id;
289 290 291

  //  table_ptr->PullSparsePtr(select_values, keys, num);
  table_ptr->Pull(table_context);
T
Thunderbrook 已提交
292 293 294 295

  return done();
}

L
lxsbupt 已提交
296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317
// Logs the statistics pair reported by the table's PrintTableStat(): the
// first value is logged as "feasign size", the second as "mf size".
// Returns done().
::std::future<int32_t> PsLocalClient::PrintTableStat(uint32_t table_id) {
  const std::pair<int64_t, int64_t> stat =
      GetTable(table_id)->PrintTableStat();
  VLOG(0) << "table id: " << table_id << ", feasign size: " << stat.first
          << ", mf size: " << stat.second;
  return done();
}

::std::future<int32_t> PsLocalClient::SaveCacheTable(uint32_t table_id,
                                                     uint16_t pass_id,
                                                     size_t threshold) {
  auto* table_ptr = GetTable(table_id);
  std::pair<int64_t, int64_t> ret = table_ptr->PrintTableStat();
  VLOG(0) << "table id: " << table_id << ", feasign size: " << ret.first
          << ", mf size: " << ret.second;
  if (ret.first > (int64_t)threshold) {
    VLOG(0) << "run cache table";
    table_ptr->CacheTable(pass_id);
  }
  return done();
}

Z
zhaocaibei123 已提交
318
::std::future<int32_t> PsLocalClient::PushSparseRawGradient(
319 320 321 322 323
    size_t table_id,
    const uint64_t* keys,
    const float** update_values,
    size_t num,
    void* callback) {
T
Thunderbrook 已提交
324
  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
Z
zhaocaibei123 已提交
325
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
326

327 328 329 330 331 332 333 334 335
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.ptr_values = update_values;
  table_context.num = num;
  table_context.use_ptr = true;

  // table_ptr->PushSparse(keys, update_values, num);
  table_ptr->Push(table_context);
T
Thunderbrook 已提交
336 337 338 339
  delete closure;
  return done();
}

Z
zhaocaibei123 已提交
340 341 342 343 344
::std::future<int32_t> PsLocalClient::PushSparse(size_t table_id,
                                                 const uint64_t* keys,
                                                 const float** update_values,
                                                 size_t num) {
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
345

346 347 348 349 350 351 352 353 354
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.ptr_values = update_values;
  table_context.num = num;
  table_context.use_ptr = true;

  //  table_ptr->PushSparse(keys, update_values, num);
  table_ptr->Push(table_context);
T
Thunderbrook 已提交
355 356
  return done();
}
357 358
}  // namespace distributed
}  // namespace paddle