ps_local_client.cc 11.1 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
16

17
#include "paddle/fluid/distributed/ps/table/table.h"
T
Thunderbrook 已提交
18 19 20 21 22

//#define pslib_debug_dense_compress

namespace paddle {
namespace distributed {
Z
zhaocaibei123 已提交
23
int32_t PsLocalClient::Initialize() {
T
Thunderbrook 已提交
24
  const auto& downpour_param = _config.server_param().downpour_server_param();
Z
zhaocaibei123 已提交
25
  TableManager::Instance().Initialize();
Z
zhangchunle 已提交
26
  for (int i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
T
Thunderbrook 已提交
27 28
    auto* table = CREATE_PSCORE_CLASS(
        Table, downpour_param.downpour_table_param(i).table_class());
Z
zhaocaibei123 已提交
29 30
    table->SetShard(0, 1);
    table->Initialize(downpour_param.downpour_table_param(i),
T
Thunderbrook 已提交
31 32 33 34 35 36
                      _config.fs_client_param());
    _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
  }
  return 0;
}

Z
zhaocaibei123 已提交
37
// Shrink is not implemented for the local client yet (TODO). Neither
// table_id nor threshold is used; the call only reports completion via done().
::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
                                             const std::string threshold) {
  return done();  // TODO: implement table shrinking
}

Z
zhaocaibei123 已提交
43
::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
T
Thunderbrook 已提交
44 45
                                           const std::string& mode) {
  // TODO
T
Thunderbrook 已提交
46
  for (auto& it : _table_map) {
Z
zhaocaibei123 已提交
47
    Load(it.first, epoch, mode);
T
Thunderbrook 已提交
48
  }
T
Thunderbrook 已提交
49 50
  return done();
}
Z
zhaocaibei123 已提交
51
::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
T
Thunderbrook 已提交
52 53 54
                                           const std::string& epoch,
                                           const std::string& mode) {
  // TODO
Z
zhaocaibei123 已提交
55 56
  auto* table_ptr = GetTable(table_id);
  table_ptr->Load(epoch, mode);
T
Thunderbrook 已提交
57 58 59
  return done();
}

Z
zhaocaibei123 已提交
60
// Saves every table owned by this client by delegating to the per-table
// Save(table_id, epoch, mode) overload for each entry of _table_map.
// The per-table futures are discarded; the returned future comes from done().
::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
                                           const std::string& mode) {
  for (const auto& entry : _table_map) {
    Save(entry.first, epoch, mode);
  }
  return done();
}
Z
zhaocaibei123 已提交
68
::std::future<int32_t> PsLocalClient::Save(uint32_t table_id,
T
Thunderbrook 已提交
69 70 71
                                           const std::string& epoch,
                                           const std::string& mode) {
  // TODO
Z
zhaocaibei123 已提交
72 73 74
  auto* table_ptr = GetTable(table_id);
  table_ptr->Flush();
  table_ptr->Save(epoch, mode);
T
Thunderbrook 已提交
75 76 77
  return done();
}

Z
zhaocaibei123 已提交
78
::std::future<int32_t> PsLocalClient::Clear() {
T
Thunderbrook 已提交
79 80 81
  // TODO
  return done();
}
Z
zhaocaibei123 已提交
82
::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
T
Thunderbrook 已提交
83 84 85 86
  // TODO
  return done();
}

Z
zhaocaibei123 已提交
87
::std::future<int32_t> PsLocalClient::Flush() {
T
Thunderbrook 已提交
88 89 90 91
  // no need
  return done();
}

Z
zhaocaibei123 已提交
92
::std::future<int32_t> PsLocalClient::StopServer() {
T
Thunderbrook 已提交
93 94 95 96
  // no need
  return done();
}

Z
zhaocaibei123 已提交
97 98 99 100 101
// Pulls the dense values of table_id and scatters them, in order, across the
// caller-supplied regions.
//
// The table is read into a staging buffer of DenseDimPerShard(fea_dim, 1)
// floats. Its bytes are then copied into regions[0], regions[1], ... until
// either the buffer or the regions are exhausted. Zero-sized regions are
// skipped, and bytes that do not fit into the regions are dropped silently.
::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);

  const uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1);
  std::vector<float> pulled(num_per_shard);

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.pull_context.values = pulled.data();
  table_context.num = pulled.size();
  table_ptr->Pull(table_context);

  // Bytes of the staging buffer still waiting to be copied out.
  size_t bytes_left = static_cast<size_t>(num_per_shard) * sizeof(float);
  PADDLE_ENFORCE_EQ(
      bytes_left, pulled.size() * sizeof(float),
      platform::errors::PreconditionNotMet("pull dense size error."));

  const char* src = reinterpret_cast<const char*>(pulled.data());
  for (size_t r = 0; r < region_num && bytes_left > 0; ++r) {
    auto& dst = regions[r];
    // Fill the whole region, or only what is left of the buffer.
    const size_t chunk = dst.size < bytes_left ? dst.size : bytes_left;
    if (chunk == 0) {
      continue;  // empty region: nothing to copy into it
    }
    memcpy(dst.data, src, chunk);
    src += chunk;
    bytes_left -= chunk;
  }

  return done();
}

Z
zhaocaibei123 已提交
149 150 151 152 153
::std::future<int32_t> PsLocalClient::PushDenseParam(const Region* regions,
                                                     size_t region_num,
                                                     size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
154 155

  std::vector<float> region_buffer;
156 157
  region_buffer.resize(DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1),
                       0);
T
Thunderbrook 已提交
158 159 160 161 162 163
  for (size_t i = 0, offset = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
    offset += data_num;
  }

164 165 166 167 168 169 170
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = region_buffer.data();
  table_context.push_context.is_param = true;
  table_context.num = region_buffer.size();

  table_ptr->Push(table_context);
Z
zhaocaibei123 已提交
171
  // table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());
T
Thunderbrook 已提交
172 173 174 175

  return done();
}

Z
zhaocaibei123 已提交
176
::std::future<int32_t> PsLocalClient::PushDenseRawGradient(
T
Thunderbrook 已提交
177 178 179 180 181 182
    int table_id, float* total_send_data, size_t total_send_data_size,
    void* callback) {
  VLOG(1) << "wxx push_dense_raw_gradient";

  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);

Z
zhaocaibei123 已提交
183
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
184

185 186 187 188 189 190 191
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = total_send_data;
  table_context.num = total_send_data_size;
  //  table_ptr->PushDense(total_send_data, total_send_data_size);
  table_ptr->Push(table_context);

T
Thunderbrook 已提交
192 193 194 195
  delete closure;
  return done();
}

Z
zhaocaibei123 已提交
196 197 198 199 200
::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
201 202

  std::vector<float> region_buffer;
203 204
  region_buffer.resize(
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1));
T
Thunderbrook 已提交
205 206 207 208 209 210 211 212 213 214 215 216
  size_t data_size = region_buffer.size();
  for (size_t i = 0, offset = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    PADDLE_ENFORCE_LE(
        offset + data_num, data_size,
        platform::errors::PreconditionNotMet(
            "invalid dense size, cur pos[%d] data_num[%d] size[%d]", offset,
            data_num, data_size));
    memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
    offset += data_num;
  }

217 218 219 220 221 222
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = region_buffer.data();
  table_context.num = region_buffer.size();
  //  table_ptr->PushDense(total_send_data, total_send_data_size);
  table_ptr->Push(table_context);
T
Thunderbrook 已提交
223 224 225 226

  return done();
}

Z
zhaocaibei123 已提交
227
//::std::future<int32_t> PsLocalClient::PullSparse(float** select_values,
T
Thunderbrook 已提交
228 229 230 231 232 233 234 235 236
//                                                  size_t table_id,
//                                                  const uint64_t* keys,
//                                                  size_t num) {
//  // FIXME
//  // auto timer =
//  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
//  // auto local_timer =
//  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
//  // Split the keys into per-shard requests and record each key's original value pointer
Z
zhaocaibei123 已提交
237 238
//  auto* accessor = GetTableAccessor(table_id);
//  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
239 240
//  size_t value_size = accessor->select_size();
//
Z
zhaocaibei123 已提交
241
//  // table_ptr->PullSparse(keys, num);
T
Thunderbrook 已提交
242 243
//  std::vector<float> res_data;
//  res_data.resize(num * value_size / sizeof(float));
Z
zhaocaibei123 已提交
244
//  table_ptr->PullSparse(res_data.data(), keys, num);
T
Thunderbrook 已提交
245 246 247 248 249 250 251 252 253 254 255 256
//  // memcpy(select_values[0], res_data->data(), res_data->size() *
//  // sizeof(float));
//  size_t offset = 0;
//  for (int i = 0; i < num; ++i) {
//    memcpy(select_values[i], (char*)res_data.data() + offset, value_size);
//    offset += value_size;
//  }
//
//  // return fut;
//  return done();
//}

Z
zhaocaibei123 已提交
257 258 259 260
::std::future<int32_t> PsLocalClient::PullSparsePtr(char** select_values,
                                                    size_t table_id,
                                                    const uint64_t* keys,
                                                    size_t num) {
T
Thunderbrook 已提交
261 262 263 264 265 266
  // FIXME
  // auto timer =
  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
  // auto local_timer =
  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
  //将key拆分到各shard请求,并记录原始对应value指针
Z
zhaocaibei123 已提交
267
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
268

269 270 271 272 273 274 275 276 277
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.keys = keys;
  table_context.pull_context.ptr_values = select_values;
  table_context.use_ptr = true;
  table_context.num = num;

  //  table_ptr->PullSparsePtr(select_values, keys, num);
  table_ptr->Pull(table_context);
T
Thunderbrook 已提交
278 279 280 281

  return done();
}

Z
zhaocaibei123 已提交
282
::std::future<int32_t> PsLocalClient::PushSparseRawGradient(
T
Thunderbrook 已提交
283 284 285
    size_t table_id, const uint64_t* keys, const float** update_values,
    size_t num, void* callback) {
  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
Z
zhaocaibei123 已提交
286 287
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
288

289 290 291 292 293 294 295 296 297
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.ptr_values = update_values;
  table_context.num = num;
  table_context.use_ptr = true;

  // table_ptr->PushSparse(keys, update_values, num);
  table_ptr->Push(table_context);
T
Thunderbrook 已提交
298 299 300 301
  delete closure;
  return done();
}

Z
zhaocaibei123 已提交
302 303 304 305 306 307
::std::future<int32_t> PsLocalClient::PushSparse(size_t table_id,
                                                 const uint64_t* keys,
                                                 const float** update_values,
                                                 size_t num) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
308

309 310 311 312 313 314 315 316 317
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.ptr_values = update_values;
  table_context.num = num;
  table_context.use_ptr = true;

  //  table_ptr->PushSparse(keys, update_values, num);
  table_ptr->Push(table_context);
T
Thunderbrook 已提交
318 319
  return done();
}
320 321
}  // namespace distributed
}  // namespace paddle