memory_sparse_table.cc 28.0 KB
Newer Older
Z
zhaocaibei123 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"

17
#include <omp.h>
Z
zhaocaibei123 已提交
18

19
#include <sstream>
Z
zhaocaibei123 已提交
20 21 22

#include "boost/lexical_cast.hpp"
#include "glog/logging.h"
23 24
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/framework/io/fs.h"
Z
zhaocaibei123 已提交
25 26
#include "paddle/fluid/platform/enforce.h"

27 28
// ---------------------------------------------------------------------------
// Command-line switches (gflags) for the memory sparse table.
// ---------------------------------------------------------------------------

// Not read anywhere in the code shown in this file.
DEFINE_bool(pserver_print_missed_key_num_every_push,
            false,
            "pserver_print_missed_key_num_every_push");

// Read by PullSparse: when true, a key that is missing on pull is served as
// zeros and is not inserted (PushSparse inserts missing keys); when false,
// the pull itself inserts a freshly created value.
DEFINE_bool(pserver_create_value_when_push,
            true,
            "pserver create value when push");

// Read by PushSparse: when true, the accessor's CreateValue() decides whether
// an unseen key is inserted at all.
DEFINE_bool(pserver_enable_create_feasign_randomly,
            false,
            "pserver_enable_create_feasign_randomly");

// Retry budget for the per-shard load/save loops; exceeding it exits.
DEFINE_int32(pserver_table_save_max_retry, 3, "pserver_table_save_max_retry");

Z
zhaocaibei123 已提交
38 39 40
namespace paddle {
namespace distributed {

Z
zhaocaibei123 已提交
41
int32_t MemorySparseTable::Initialize() {
42
  _shards_task_pool.resize(_task_pool_size);
43
  for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
44
    _shards_task_pool[i].reset(new ::ThreadPool(1));
Z
zhaocaibei123 已提交
45
  }
46 47 48
  auto& profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_sparse_update_all");
  profiler.register_profiler("pserver_sparse_select_all");
Z
zhaocaibei123 已提交
49
  InitializeValue();
Z
zhaocaibei123 已提交
50 51 52 53
  VLOG(0) << "initalize MemorySparseTable succ";
  return 0;
}

Z
zhaocaibei123 已提交
54
int32_t MemorySparseTable::InitializeValue() {
55 56
  _sparse_table_shard_num = static_cast<int>(_config.shard_num());
  _avg_local_shard_num =
57
      sparse_local_shard_num(_sparse_table_shard_num, _shard_num);
58
  _real_local_shard_num = _avg_local_shard_num;
Z
zhangchunle 已提交
59 60
  if (static_cast<int>(_real_local_shard_num * (_shard_idx + 1)) >
      _sparse_table_shard_num) {
61 62 63 64
    _real_local_shard_num =
        _sparse_table_shard_num - _real_local_shard_num * _shard_idx;
    _real_local_shard_num =
        _real_local_shard_num < 0 ? 0 : _real_local_shard_num;
Z
zhaocaibei123 已提交
65
  }
66 67 68
  VLOG(1) << "memory sparse table _avg_local_shard_num: "
          << _avg_local_shard_num
          << " _real_local_shard_num: " << _real_local_shard_num;
Z
zhaocaibei123 已提交
69

70
  _local_shards.reset(new shard_type[_real_local_shard_num]);
Z
zhaocaibei123 已提交
71 72 73 74

  return 0;
}

Z
zhaocaibei123 已提交
75
int32_t MemorySparseTable::Load(const std::string& path,
Z
zhaocaibei123 已提交
76
                                const std::string& param) {
Z
zhaocaibei123 已提交
77
  std::string table_path = TableDir(path);
Z
zhaocaibei123 已提交
78 79 80 81
  auto file_list = _afs_client.list(table_path);

  std::sort(file_list.begin(), file_list.end());
  for (auto file : file_list) {
Z
zhaocaibei123 已提交
82
    VLOG(1) << "MemorySparseTable::Load() file list: " << file;
Z
zhaocaibei123 已提交
83 84 85
  }

  int load_param = atoi(param.c_str());
86
  size_t expect_shard_num = _sparse_table_shard_num;
Z
zhaocaibei123 已提交
87 88 89 90 91 92 93 94 95 96
  if (file_list.size() != expect_shard_num) {
    LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                 << " not equal to expect_shard_num:" << expect_shard_num;
    return -1;
  }
  if (file_list.size() == 0) {
    LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
    return -1;
  }

97
  size_t file_start_idx = _shard_idx * _avg_local_shard_num;
Z
zhaocaibei123 已提交
98

99
  size_t feature_value_size =
100
      _value_accesor->GetAccessorInfo().size / sizeof(float);
101 102 103 104

  int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
105
  for (int i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
106 107 108 109
    FsChannelConfig channel_config;
    channel_config.path = file_list[file_start_idx + i];
    VLOG(1) << "MemorySparseTable::load begin load " << channel_config.path
            << " into local shard " << i;
110
    channel_config.converter = _value_accesor->Converter(load_param).converter;
Z
zhaocaibei123 已提交
111
    channel_config.deconverter =
112
        _value_accesor->Converter(load_param).deconverter;
Z
zhaocaibei123 已提交
113 114 115 116 117 118 119 120 121 122

    bool is_read_failed = false;
    int retry_num = 0;
    int err_no = 0;
    do {
      is_read_failed = false;
      err_no = 0;
      std::string line_data;
      auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
      char* end = NULL;
123
      auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
124 125 126 127
      try {
        while (read_channel->read_line(line_data) == 0 &&
               line_data.size() > 1) {
          uint64_t key = std::strtoul(line_data.data(), &end, 10);
128 129
          auto& value = shard[key];
          value.resize(feature_value_size);
130
          int parse_size = _value_accesor->ParseFromString(++end, value.data());
131
          value.resize(parse_size);
Z
zhaocaibei123 已提交
132 133 134 135

          // for debug
          for (int ii = 0; ii < parse_size; ++ii) {
            VLOG(2) << "MemorySparseTable::load key: " << key << " value " << ii
136
                    << ": " << value.data()[ii] << " local_shard: " << i;
Z
zhaocaibei123 已提交
137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
          }
        }
        read_channel->close();
        if (err_no == -1) {
          ++retry_num;
          is_read_failed = true;
          LOG(ERROR)
              << "MemorySparseTable load failed after read, retry it! path:"
              << channel_config.path << " , retry_num=" << retry_num;
        }
      } catch (...) {
        ++retry_num;
        is_read_failed = true;
        LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
                   << channel_config.path << " , retry_num=" << retry_num;
      }
Z
zhaocaibei123 已提交
153
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
Z
zhaocaibei123 已提交
154 155 156 157 158 159 160
        LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
        exit(-1);
      }
    } while (is_read_failed);
  }
  LOG(INFO) << "MemorySparseTable load success, path from "
            << file_list[file_start_idx] << " to "
161
            << file_list[file_start_idx + _real_local_shard_num - 1];
Z
zhaocaibei123 已提交
162 163 164
  return 0;
}

Z
zhaocaibei123 已提交
165 166 167
int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
                                       const std::string& param) {
  std::string table_path = TableDir(path);
Z
zhaocaibei123 已提交
168
  auto file_list = paddle::framework::localfs_list(table_path);
169
  size_t expect_shard_num = _sparse_table_shard_num;
Z
zhaocaibei123 已提交
170 171 172 173 174 175 176 177 178 179
  if (file_list.size() != expect_shard_num) {
    LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                 << " not equal to expect_shard_num:" << expect_shard_num;
    return -1;
  }
  if (file_list.size() == 0) {
    LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
    return -1;
  }

180
  size_t file_start_idx = _shard_idx * _avg_local_shard_num;
Z
zhaocaibei123 已提交
181

182
  size_t feature_value_size =
183
      _value_accesor->GetAccessorInfo().size / sizeof(float);
Z
zhaocaibei123 已提交
184

185 186 187
  int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
188
  for (int i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
189 190 191 192 193 194 195 196 197
    bool is_read_failed = false;
    int retry_num = 0;
    int err_no = 0;
    do {
      is_read_failed = false;
      err_no = 0;
      std::string line_data;
      std::ifstream file(file_list[file_start_idx + i]);
      char* end = NULL;
198
      auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
199 200 201
      try {
        while (std::getline(file, line_data) && line_data.size() > 1) {
          uint64_t key = std::strtoul(line_data.data(), &end, 10);
202 203
          auto& value = shard[key];
          value.resize(feature_value_size);
204
          int parse_size = _value_accesor->ParseFromString(++end, value.data());
205
          value.resize(parse_size);
Z
zhaocaibei123 已提交
206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
        }
        file.close();
        if (err_no == -1) {
          ++retry_num;
          is_read_failed = true;
          LOG(ERROR)
              << "MemorySparseTable load failed after read, retry it! path:"
              << file_list[file_start_idx + i] << " , retry_num=" << retry_num;
        }
      } catch (...) {
        ++retry_num;
        is_read_failed = true;
        LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
                   << file_list[file_start_idx + i]
                   << " , retry_num=" << retry_num;
      }
Z
zhaocaibei123 已提交
222
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
Z
zhaocaibei123 已提交
223 224 225 226 227 228 229
        LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
        exit(-1);
      }
    } while (is_read_failed);
  }
  LOG(INFO) << "MemorySparseTable load success, path from "
            << file_list[file_start_idx] << " to "
230
            << file_list[file_start_idx + _real_local_shard_num - 1];
Z
zhaocaibei123 已提交
231 232 233
  return 0;
}

Z
zhaocaibei123 已提交
234
int32_t MemorySparseTable::Save(const std::string& dirname,
Z
zhaocaibei123 已提交
235 236 237 238
                                const std::string& param) {
  VLOG(0) << "MemorySparseTable::save dirname: " << dirname;
  int save_param =
      atoi(param.c_str());  // checkpoint:0  xbox delta:1  xbox base:2
Z
zhaocaibei123 已提交
239
  std::string table_path = TableDir(dirname);
Z
zhaocaibei123 已提交
240 241 242 243
  _afs_client.remove(paddle::string::format_string(
      "%s/part-%03d-*", table_path.c_str(), _shard_idx));
  std::atomic<uint32_t> feasign_size_all{0};

244
  size_t file_start_idx = _avg_local_shard_num * _shard_idx;
Z
zhaocaibei123 已提交
245

246 247 248
  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
249
  for (int i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
250 251 252
    FsChannelConfig channel_config;
    if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) {
      channel_config.path =
253 254 255 256 257 258 259 260 261
          paddle::string::format_string("%s/part-%03d-%05d.gz",
                                        table_path.c_str(),
                                        _shard_idx,
                                        file_start_idx + i);
    } else {
      channel_config.path = paddle::string::format_string("%s/part-%03d-%05d",
                                                          table_path.c_str(),
                                                          _shard_idx,
                                                          file_start_idx + i);
Z
zhaocaibei123 已提交
262
    }
263
    channel_config.converter = _value_accesor->Converter(save_param).converter;
Z
zhaocaibei123 已提交
264
    channel_config.deconverter =
265
        _value_accesor->Converter(save_param).deconverter;
Z
zhaocaibei123 已提交
266 267 268 269
    bool is_write_failed = false;
    int feasign_size = 0;
    int retry_num = 0;
    int err_no = 0;
270
    auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
271 272 273 274 275 276
    do {
      err_no = 0;
      feasign_size = 0;
      is_write_failed = false;
      auto write_channel =
          _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
277
      for (auto it = shard.begin(); it != shard.end(); ++it) {
278 279
        if (_value_accesor->Save(it.value().data(), save_param)) {
          std::string format_value = _value_accesor->ParseToString(
280
              it.value().data(), it.value().size());
281 282
          if (0 != write_channel->write_line(paddle::string::format_string(
                       "%lu %s", it.key(), format_value.c_str()))) {
283 284 285 286 287 288
            ++retry_num;
            is_write_failed = true;
            LOG(ERROR)
                << "MemorySparseTable save prefix failed, retry it! path:"
                << channel_config.path << " , retry_num=" << retry_num;
            break;
Z
zhaocaibei123 已提交
289
          }
290
          ++feasign_size;
Z
zhaocaibei123 已提交
291 292 293 294 295 296 297 298 299 300 301 302 303
        }
      }
      write_channel->close();
      if (err_no == -1) {
        ++retry_num;
        is_write_failed = true;
        LOG(ERROR)
            << "MemorySparseTable save prefix failed after write, retry it! "
            << "path:" << channel_config.path << " , retry_num=" << retry_num;
      }
      if (is_write_failed) {
        _afs_client.remove(channel_config.path);
      }
Z
zhaocaibei123 已提交
304
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
Z
zhaocaibei123 已提交
305 306 307 308 309
        LOG(ERROR) << "MemorySparseTable save prefix failed reach max limit!";
        exit(-1);
      }
    } while (is_write_failed);
    feasign_size_all += feasign_size;
310
    for (auto it = shard.begin(); it != shard.end(); ++it) {
311
      _value_accesor->UpdateStatAfterSave(it.value().data(), save_param);
Z
zhaocaibei123 已提交
312 313 314 315 316 317 318 319
    }
    LOG(INFO) << "MemorySparseTable save prefix success, path: "
              << channel_config.path;
  }
  // int32 may overflow need to change return value
  return 0;
}

Z
zhaocaibei123 已提交
320 321 322
int32_t MemorySparseTable::SaveLocalFS(const std::string& dirname,
                                       const std::string& param,
                                       const std::string& prefix) {
Z
zhaocaibei123 已提交
323 324
  int save_param =
      atoi(param.c_str());  // checkpoint:0  xbox delta:1  xbox base:2
Z
zhaocaibei123 已提交
325
  std::string table_path = TableDir(dirname);
Z
zhaocaibei123 已提交
326
  int feasign_cnt = 0;
327 328 329 330 331 332 333
  size_t file_start_idx = _avg_local_shard_num * _shard_idx;

  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  std::atomic<uint32_t> feasign_size_all{0};

  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
334
  for (int i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
335
    feasign_cnt = 0;
336
    auto& shard = _local_shards[i];
337 338 339 340 341 342
    std::string file_name =
        paddle::string::format_string("%s/part-%s-%03d-%05d",
                                      table_path.c_str(),
                                      prefix.c_str(),
                                      _shard_idx,
                                      file_start_idx + i);
Z
zhaocaibei123 已提交
343 344
    std::ofstream os;
    os.open(file_name);
345
    for (auto it = shard.begin(); it != shard.end(); ++it) {
346 347 348
      if (_value_accesor->Save(it.value().data(), save_param)) {
        std::string format_value =
            _value_accesor->ParseToString(it.value().data(), it.value().size());
349 350 351 352 353
        std::string out_line = paddle::string::format_string(
            "%lu %s\n", it.key(), format_value.c_str());
        // VLOG(2) << out_line.c_str();
        os.write(out_line.c_str(), sizeof(char) * out_line.size());
        ++feasign_cnt;
Z
zhaocaibei123 已提交
354 355 356 357 358 359 360 361 362
      }
    }
    os.close();
    LOG(INFO) << "MemorySparseTable save prefix success, path:" << file_name
              << "feasign_cnt: " << feasign_cnt;
  }
  return 0;
}

Z
zhaocaibei123 已提交
363
int64_t MemorySparseTable::LocalSize() {
364
  int64_t local_size = 0;
365
  for (int i = 0; i < _real_local_shard_num; ++i) {
366 367 368 369
    local_size += _local_shards[i].size();
  }
  return local_size;
}
Z
zhaocaibei123 已提交
370

Z
zhaocaibei123 已提交
371
int64_t MemorySparseTable::LocalMFSize() {
372 373 374
  std::vector<int64_t> size_arr(_real_local_shard_num, 0);
  std::vector<std::future<int>> tasks(_real_local_shard_num);
  int64_t ret_size = 0;
375
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
376 377 378 379 380 381
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this, shard_id, &size_arr]() -> int {
              auto& local_shard = _local_shards[shard_id];
              for (auto it = local_shard.begin(); it != local_shard.end();
                   ++it) {
382
                if (_value_accesor->HasMF(it.value().size())) {
383 384 385 386 387 388
                  size_arr[shard_id] += 1;
                }
              }
              return 0;
            });
  }
389
  for (int i = 0; i < _real_local_shard_num; ++i) {
390 391 392 393
    tasks[i].wait();
  }
  for (auto x : size_arr) {
    ret_size += x;
Z
zhaocaibei123 已提交
394
  }
395 396
  return ret_size;
}
Z
zhaocaibei123 已提交
397

Z
zhaocaibei123 已提交
398 399 400
// Returns {LocalSize(), LocalMFSize()}: the local key count and the number of
// local values that have an mf part. Nothing is printed here.
std::pair<int64_t, int64_t> MemorySparseTable::PrintTableStat() {
  const int64_t key_count = LocalSize();
  const int64_t mf_count = LocalMFSize();
  return std::make_pair(key_count, mf_count);
}

Y
yaoxuefeng 已提交
404 405 406 407 408
// Generic pull entry point; CHECK-fails unless `context.value_type` is
// Sparse.
//
// With `use_ptr` it hands out value pointers via PullSparsePtr(). Otherwise
// it copies the selected floats via PullSparse().
int32_t MemorySparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  if (!context.use_ptr) {
    float* out_values = context.pull_context.values;
    const PullSparseValue& request = context.pull_context.pull_value;
    return PullSparse(out_values, request);
  }
  char** out_ptrs = context.pull_context.ptr_values;
  const uint64_t* request_keys = context.pull_context.keys;
  return PullSparsePtr(out_ptrs, request_keys, context.num);
}

// Generic push entry point; CHECK-fails unless `context.value_type` is
// Sparse.
//
// With `use_ptr` the updates arrive as one pointer per key
// (`push_context.ptr_values`). Otherwise they arrive as a contiguous array
// (`push_context.values`). Both forms go to the matching PushSparse overload.
int32_t MemorySparseTable::Push(TableContext& context) {
  CHECK(context.value_type == Sparse);
  const auto& push = context.push_context;
  if (context.use_ptr) {
    return PushSparse(push.keys, push.ptr_values, context.num);
  }
  return PushSparse(push.keys, push.values, context.num);
}

Z
zhaocaibei123 已提交
429 430
int32_t MemorySparseTable::PullSparse(float* pull_values,
                                      const PullSparseValue& pull_value) {
431 432
  CostTimer timer("pserver_sparse_select_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);
Z
zhaocaibei123 已提交
433

434 435 436 437
  const size_t value_size =
      _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_size =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
438
  size_t select_value_size =
439
      _value_accesor->GetAccessorInfo().select_size / sizeof(float);
Z
zhaocaibei123 已提交
440 441 442
  // std::atomic<uint32_t> missed_keys{0};

  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
443
      _real_local_shard_num);
Z
zhaocaibei123 已提交
444 445
  size_t num = pull_value.numel_;
  for (size_t i = 0; i < num; ++i) {
446 447
    int shard_id = (pull_value.feasigns_[i] % _sparse_table_shard_num) %
                   _avg_local_shard_num;
Z
zhaocaibei123 已提交
448 449
    task_keys[shard_id].push_back({pull_value.feasigns_[i], i});
  }
450
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
Z
zhaocaibei123 已提交
451
    tasks[shard_id] =
452
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
453 454 455 456 457 458
            [this,
             shard_id,
             &task_keys,
             value_size,
             pull_values,
             mf_value_size,
Z
zhaocaibei123 已提交
459
             select_value_size]() -> int {
460
              auto& local_shard = _local_shards[shard_id];
Z
zhaocaibei123 已提交
461 462 463 464 465 466
              float data_buffer[value_size];  // NOLINT
              float* data_buffer_ptr = data_buffer;

              auto& keys = task_keys[shard_id];
              for (size_t i = 0; i < keys.size(); i++) {
                uint64_t key = keys[i].first;
467
                auto itr = local_shard.find(key);
Z
zhaocaibei123 已提交
468
                size_t data_size = value_size - mf_value_size;
469
                if (itr == local_shard.end()) {
Z
zhaocaibei123 已提交
470
                  // ++missed_keys;
471
                  if (FLAGS_pserver_create_value_when_push) {
Z
zhaocaibei123 已提交
472 473
                    memset(data_buffer, 0, sizeof(float) * data_size);
                  } else {
474 475 476
                    auto& feature_value = local_shard[key];
                    feature_value.resize(data_size);
                    float* data_ptr = feature_value.data();
477
                    _value_accesor->Create(&data_buffer_ptr, 1);
478 479
                    memcpy(
                        data_ptr, data_buffer_ptr, data_size * sizeof(float));
Z
zhaocaibei123 已提交
480 481
                  }
                } else {
482
                  data_size = itr.value().size();
483 484
                  memcpy(data_buffer_ptr,
                         itr.value().data(),
Z
zhaocaibei123 已提交
485 486
                         data_size * sizeof(float));
                }
487
                for (size_t mf_idx = data_size; mf_idx < value_size; ++mf_idx) {
Z
zhaocaibei123 已提交
488 489 490 491
                  data_buffer[mf_idx] = 0.0;
                }
                auto offset = keys[i].second;
                float* select_data = pull_values + select_value_size * offset;
492 493
                _value_accesor->Select(
                    &select_data, (const float**)&data_buffer_ptr, 1);
Z
zhaocaibei123 已提交
494 495 496 497 498 499 500 501 502 503 504 505
              }

              return 0;
            });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

Z
zhaocaibei123 已提交
506
int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
507 508
                                         const uint64_t* keys,
                                         size_t num) {
509
  CostTimer timer("pscore_sparse_select_all");
510 511 512
  size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_size =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
513 514 515 516 517 518 519 520 521

  std::vector<std::future<int>> tasks(_real_local_shard_num);
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
    task_keys[shard_id].push_back({keys[i], i});
  }
  // std::atomic<uint32_t> missed_keys{0};
522
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
523 524
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
525 526 527 528 529
            [this,
             shard_id,
             &task_keys,
             pull_values,
             value_size,
530 531 532 533 534
             mf_value_size]() -> int {
              auto& keys = task_keys[shard_id];
              auto& local_shard = _local_shards[shard_id];
              float data_buffer[value_size];
              float* data_buffer_ptr = data_buffer;
535
              for (size_t i = 0; i < keys.size(); ++i) {
536 537 538 539 540 541 542 543 544
                uint64_t key = keys[i].first;
                auto itr = local_shard.find(key);
                size_t data_size = value_size - mf_value_size;
                FixedFeatureValue* ret = NULL;
                if (itr == local_shard.end()) {
                  // ++missed_keys;
                  auto& feature_value = local_shard[key];
                  feature_value.resize(data_size);
                  float* data_ptr = feature_value.data();
545
                  _value_accesor->Create(&data_buffer_ptr, 1);
546 547 548 549 550 551 552 553 554 555 556 557 558 559
                  memcpy(data_ptr, data_buffer_ptr, data_size * sizeof(float));
                  ret = &feature_value;
                } else {
                  ret = itr.value_ptr();
                }
                int pull_data_idx = keys[i].second;
                pull_values[pull_data_idx] = (char*)ret;
              }
              return 0;
            });
  }
  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
Z
zhaocaibei123 已提交
560 561 562
  return 0;
}

563 564
int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
                                      const float* values,
Z
zhaocaibei123 已提交
565
                                      size_t num) {
566 567
  CostTimer timer("pserver_sparse_update_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);
Z
zhaocaibei123 已提交
568
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
569
      _real_local_shard_num);
Z
zhaocaibei123 已提交
570
  for (size_t i = 0; i < num; ++i) {
571
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
Z
zhaocaibei123 已提交
572 573 574
    task_keys[shard_id].push_back({keys[i], i});
  }

575 576 577 578
  const size_t value_col =
      _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_col =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
579
  size_t update_value_col =
580
      _value_accesor->GetAccessorInfo().update_size / sizeof(float);
Z
zhaocaibei123 已提交
581

582
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
583
    tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
584 585 586 587 588 589
        [this,
         shard_id,
         value_col,
         mf_value_col,
         update_value_col,
         values,
Z
zhaocaibei123 已提交
590 591
         &task_keys]() -> int {
          auto& keys = task_keys[shard_id];
592
          auto& local_shard = _local_shards[shard_id];
Z
zhaocaibei123 已提交
593 594
          float data_buffer[value_col];  // NOLINT
          float* data_buffer_ptr = data_buffer;
595
          for (size_t i = 0; i < keys.size(); ++i) {
Z
zhaocaibei123 已提交
596 597 598 599
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data =
                values + push_data_idx * update_value_col;
600 601 602
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              if (FLAGS_pserver_enable_create_feasign_randomly &&
603
                  !_value_accesor->CreateValue(1, update_data)) {
Z
zhaocaibei123 已提交
604 605 606
                continue;
              }
              auto value_size = value_col - mf_value_col;
607 608
              auto& feature_value = local_shard[key];
              feature_value.resize(value_size);
609
              _value_accesor->Create(&data_buffer_ptr, 1);
610 611
              memcpy(feature_value.data(),
                     data_buffer_ptr,
Z
zhaocaibei123 已提交
612
                     value_size * sizeof(float));
613
              itr = local_shard.find(key);
Z
zhaocaibei123 已提交
614 615
            }

616 617 618
            auto& feature_value = itr.value();
            float* value_data = feature_value.data();
            size_t value_size = feature_value.size();
Z
zhaocaibei123 已提交
619 620

            if (value_size == value_col) {  // 已拓展到最大size, 则就地update
621
              _value_accesor->Update(&value_data, &update_data, 1);
Z
zhaocaibei123 已提交
622 623 624
            } else {
              // 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了
              memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
625
              _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
Z
zhaocaibei123 已提交
626

627
              if (_value_accesor->NeedExtendMF(data_buffer)) {
628 629
                feature_value.resize(value_col);
                value_data = feature_value.data();
630
                _value_accesor->Create(&value_data, 1);
Z
zhaocaibei123 已提交
631 632 633 634 635 636 637 638 639 640 641 642 643 644
              }
              memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
            }
          }
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

Z
zhaocaibei123 已提交
645
int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
646 647
                                      const float** values,
                                      size_t num) {
648
  std::vector<std::future<int>> tasks(_real_local_shard_num);
Z
zhaocaibei123 已提交
649
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
650
      _real_local_shard_num);
Z
zhaocaibei123 已提交
651
  for (size_t i = 0; i < num; ++i) {
652
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
Z
zhaocaibei123 已提交
653 654 655
    task_keys[shard_id].push_back({keys[i], i});
  }

656 657 658
  size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_col =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
659
  size_t update_value_col =
660
      _value_accesor->GetAccessorInfo().update_size / sizeof(float);
Z
zhaocaibei123 已提交
661

662 663
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
664 665 666 667 668 669
        [this,
         shard_id,
         value_col,
         mf_value_col,
         update_value_col,
         values,
Z
zhaocaibei123 已提交
670 671
         &task_keys]() -> int {
          auto& keys = task_keys[shard_id];
672
          auto& local_shard = _local_shards[shard_id];
Z
zhaocaibei123 已提交
673 674
          float data_buffer[value_col];  // NOLINT
          float* data_buffer_ptr = data_buffer;
675
          for (size_t i = 0; i < keys.size(); ++i) {
Z
zhaocaibei123 已提交
676 677 678
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data = values[push_data_idx];
679 680 681
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              if (FLAGS_pserver_enable_create_feasign_randomly &&
682
                  !_value_accesor->CreateValue(1, update_data)) {
Z
zhaocaibei123 已提交
683 684 685
                continue;
              }
              auto value_size = value_col - mf_value_col;
686 687
              auto& feature_value = local_shard[key];
              feature_value.resize(value_size);
688
              _value_accesor->Create(&data_buffer_ptr, 1);
689 690
              memcpy(feature_value.data(),
                     data_buffer_ptr,
Z
zhaocaibei123 已提交
691
                     value_size * sizeof(float));
692
              itr = local_shard.find(key);
Z
zhaocaibei123 已提交
693
            }
694 695 696
            auto& feature_value = itr.value();
            float* value_data = feature_value.data();
            size_t value_size = feature_value.size();
Z
zhaocaibei123 已提交
697
            if (value_size == value_col) {  // 已拓展到最大size, 则就地update
698
              _value_accesor->Update(&value_data, &update_data, 1);
Z
zhaocaibei123 已提交
699 700 701
            } else {
              // 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了
              memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
702 703
              _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
              if (_value_accesor->NeedExtendMF(data_buffer)) {
704 705
                feature_value.resize(value_col);
                value_data = feature_value.data();
706
                _value_accesor->Create(&value_data, 1);
Z
zhaocaibei123 已提交
707 708 709 710 711 712 713 714 715 716 717 718 719 720
              }
              memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
            }
          }
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

Z
zhaocaibei123 已提交
721
// No-op for the in-memory table; always returns 0.
int32_t MemorySparseTable::Flush() {
  return 0;
}
Z
zhaocaibei123 已提交
722

Z
zhaocaibei123 已提交
723 724
int32_t MemorySparseTable::Shrink(const std::string& param) {
  VLOG(0) << "MemorySparseTable::Shrink";
Z
zhaocaibei123 已提交
725
  // TODO(zhaocaibei123): implement with multi-thread
726
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
Z
zhaocaibei123 已提交
727
    // Shrink
728 729
    auto& shard = _local_shards[shard_id];
    for (auto it = shard.begin(); it != shard.end();) {
730
      if (_value_accesor->Shrink(it.value().data())) {
731 732 733
        it = shard.erase(it);
      } else {
        ++it;
Z
zhaocaibei123 已提交
734 735 736 737 738 739
      }
    }
  }
  return 0;
}

Z
zhaocaibei123 已提交
740
// Not implemented yet: only logs a message and leaves the shards untouched.
void MemorySparseTable::Clear() {
  VLOG(0) << "clear coming soon";
}
Z
zhaocaibei123 已提交
741 742 743

}  // namespace distributed
}  // namespace paddle