// memory_sparse_table.cc
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"

#include <omp.h>

#include <fstream>
#include <sstream>

#include "glog/logging.h"
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/enforce.h"

DEFINE_bool(pserver_print_missed_key_num_every_push,
            false,
            "pserver_print_missed_key_num_every_push");
DEFINE_bool(pserver_create_value_when_push,
            true,
            "pserver create value when push");
DEFINE_bool(pserver_enable_create_feasign_randomly,
            false,
            "pserver_enable_create_feasign_randomly");
DEFINE_int32(pserver_table_save_max_retry, 3, "pserver_table_save_max_retry");
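// These gflags follow standard gflags conventions, so they can be overridden
// on the pserver command line, e.g. --pserver_create_value_when_push=false
// --pserver_table_save_max_retry=5.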

namespace paddle {
namespace distributed {

int32_t MemorySparseTable::Initialize() {
  // One single-thread pool per task slot: work for a given shard always runs
  // on the same thread, serializing access to that shard's data.
  _shards_task_pool.resize(_task_pool_size);
  for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
    _shards_task_pool[i].reset(new ::ThreadPool(1));
  }
  auto& profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_sparse_update_all");
  profiler.register_profiler("pserver_sparse_select_all");
  InitializeValue();
  VLOG(0) << "initialize MemorySparseTable succ";
  return 0;
}

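// Shard layout: the table is split globally into _sparse_table_shard_num
// shards; each pserver node owns a contiguous slice of _avg_local_shard_num
// of them. The last node's slice may be shorter (_real_local_shard_num) when
// the division is uneven.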
int32_t MemorySparseTable::InitializeValue() {
  _sparse_table_shard_num = static_cast<int>(_config.shard_num());
  _avg_local_shard_num =
      sparse_local_shard_num(_sparse_table_shard_num, _shard_num);
  _real_local_shard_num = _avg_local_shard_num;
  if (static_cast<int>(_real_local_shard_num * (_shard_idx + 1)) >
      _sparse_table_shard_num) {
    _real_local_shard_num =
        _sparse_table_shard_num - _real_local_shard_num * _shard_idx;
    _real_local_shard_num =
        _real_local_shard_num < 0 ? 0 : _real_local_shard_num;
  }
  VLOG(1) << "memory sparse table _avg_local_shard_num: "
          << _avg_local_shard_num
          << " _real_local_shard_num: " << _real_local_shard_num;

  _local_shards.reset(new shard_type[_real_local_shard_num]);

  return 0;
}

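// On-disk layout: one "part-<node>-<shard>" text file per global shard, one
// feature per line ("<key> <serialized floats>"). Load() requires exactly
// _sparse_table_shard_num files and reads this node's slice
// [file_start_idx, file_start_idx + _real_local_shard_num).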
int32_t MemorySparseTable::Load(const std::string& path,
                                const std::string& param) {
  std::string table_path = TableDir(path);
  auto file_list = _afs_client.list(table_path);

  std::sort(file_list.begin(), file_list.end());
  for (auto file : file_list) {
    VLOG(1) << "MemorySparseTable::Load() file list: " << file;
  }

  int load_param = atoi(param.c_str());
  size_t expect_shard_num = _sparse_table_shard_num;
  if (file_list.size() != expect_shard_num) {
    LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                 << " not equal to expect_shard_num:" << expect_shard_num;
    return -1;
  }
  if (file_list.size() == 0) {
    LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
    return -1;
  }

  size_t file_start_idx = _shard_idx * _avg_local_shard_num;

  size_t feature_value_size =
      _value_accesor->GetAccessorInfo().size / sizeof(float);

  int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _real_local_shard_num; ++i) {
    FsChannelConfig channel_config;
    channel_config.path = file_list[file_start_idx + i];
    VLOG(1) << "MemorySparseTable::load begin load " << channel_config.path
            << " into local shard " << i;
    channel_config.converter = _value_accesor->Converter(load_param).converter;
    channel_config.deconverter =
        _value_accesor->Converter(load_param).deconverter;

    bool is_read_failed = false;
    int retry_num = 0;
    int err_no = 0;
    do {
      is_read_failed = false;
      err_no = 0;
      std::string line_data;
      auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
      char* end = NULL;
      auto& shard = _local_shards[i];
      try {
        while (read_channel->read_line(line_data) == 0 &&
               line_data.size() > 1) {
          uint64_t key = std::strtoul(line_data.data(), &end, 10);
          auto& value = shard[key];
          value.resize(feature_value_size);
          int parse_size = _value_accesor->ParseFromString(++end, value.data());
          value.resize(parse_size);

          // for debug
          for (int ii = 0; ii < parse_size; ++ii) {
            VLOG(2) << "MemorySparseTable::load key: " << key << " value " << ii
                    << ": " << value.data()[ii] << " local_shard: " << i;
          }
        }
        read_channel->close();
        if (err_no == -1) {
          ++retry_num;
          is_read_failed = true;
          LOG(ERROR)
              << "MemorySparseTable load failed after read, retry it! path:"
              << channel_config.path << " , retry_num=" << retry_num;
        }
      } catch (...) {
        ++retry_num;
        is_read_failed = true;
        LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
                   << channel_config.path << " , retry_num=" << retry_num;
      }
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
        LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
        exit(-1);
      }
    } while (is_read_failed);
  }
  LOG(INFO) << "MemorySparseTable load success, path from "
            << file_list[file_start_idx] << " to "
            << file_list[file_start_idx + _real_local_shard_num - 1];
  return 0;
}

int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
                                       const std::string& param) {
  std::string table_path = TableDir(path);
  auto file_list = paddle::framework::localfs_list(table_path);
  size_t expect_shard_num = _sparse_table_shard_num;
  if (file_list.size() != expect_shard_num) {
    LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                 << " not equal to expect_shard_num:" << expect_shard_num;
    return -1;
  }
  if (file_list.size() == 0) {
    LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
    return -1;
  }

  size_t file_start_idx = _shard_idx * _avg_local_shard_num;

  size_t feature_value_size =
      _value_accesor->GetAccessorInfo().size / sizeof(float);

  int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _real_local_shard_num; ++i) {
    bool is_read_failed = false;
    int retry_num = 0;
    int err_no = 0;
    do {
      is_read_failed = false;
      err_no = 0;
      std::string line_data;
      std::ifstream file(file_list[file_start_idx + i]);
      char* end = NULL;
      auto& shard = _local_shards[i];
      try {
        while (std::getline(file, line_data) && line_data.size() > 1) {
          uint64_t key = std::strtoul(line_data.data(), &end, 10);
          auto& value = shard[key];
          value.resize(feature_value_size);
          int parse_size = _value_accesor->ParseFromString(++end, value.data());
          value.resize(parse_size);
        }
        file.close();
        // Note: err_no is never set on the local-fs path; this check mirrors
        // the AFS channel handling in Load() above.
        if (err_no == -1) {
          ++retry_num;
          is_read_failed = true;
          LOG(ERROR)
              << "MemorySparseTable load failed after read, retry it! path:"
              << file_list[file_start_idx + i] << " , retry_num=" << retry_num;
        }
      } catch (...) {
        ++retry_num;
        is_read_failed = true;
        LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
                   << file_list[file_start_idx + i]
                   << " , retry_num=" << retry_num;
      }
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
        LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
        exit(-1);
      }
    } while (is_read_failed);
  }
  LOG(INFO) << "MemorySparseTable load success, path from "
            << file_list[file_start_idx] << " to "
            << file_list[file_start_idx + _real_local_shard_num - 1];
  return 0;
}

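// save_param selects the output flavor (checkpoint:0, xbox delta:1,
// xbox base:2). Checkpoint-style parts (save_param 0 or 3) get a .gz suffix
// when compress_in_save is on; stale part files for this node are removed
// first, and every shard write is retried up to
// FLAGS_pserver_table_save_max_retry times.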
int32_t MemorySparseTable::Save(const std::string& dirname,
                                const std::string& param) {
  VLOG(0) << "MemorySparseTable::save dirname: " << dirname;
  int save_param =
      atoi(param.c_str());  // checkpoint:0  xbox delta:1  xbox base:2
  std::string table_path = TableDir(dirname);
  _afs_client.remove(paddle::string::format_string(
      "%s/part-%03d-*", table_path.c_str(), _shard_idx));
  std::atomic<uint32_t> feasign_size_all{0};

  size_t file_start_idx = _avg_local_shard_num * _shard_idx;

  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _real_local_shard_num; ++i) {
    FsChannelConfig channel_config;
    if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) {
      channel_config.path =
          paddle::string::format_string("%s/part-%03d-%05d.gz",
                                        table_path.c_str(),
                                        _shard_idx,
                                        file_start_idx + i);
    } else {
      channel_config.path = paddle::string::format_string("%s/part-%03d-%05d",
                                                          table_path.c_str(),
                                                          _shard_idx,
                                                          file_start_idx + i);
    }
    channel_config.converter = _value_accesor->Converter(save_param).converter;
    channel_config.deconverter =
        _value_accesor->Converter(save_param).deconverter;
    bool is_write_failed = false;
    int feasign_size = 0;
    int retry_num = 0;
    int err_no = 0;
    auto& shard = _local_shards[i];
    do {
      err_no = 0;
      feasign_size = 0;
      is_write_failed = false;
      auto write_channel =
          _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
      for (auto it = shard.begin(); it != shard.end(); ++it) {
        if (_value_accesor->Save(it.value().data(), save_param)) {
          std::string format_value = _value_accesor->ParseToString(
              it.value().data(), it.value().size());
          if (0 != write_channel->write_line(paddle::string::format_string(
                       "%lu %s", it.key(), format_value.c_str()))) {
            ++retry_num;
            is_write_failed = true;
            LOG(ERROR)
                << "MemorySparseTable save prefix failed, retry it! path:"
                << channel_config.path << " , retry_num=" << retry_num;
            break;
          }
          ++feasign_size;
        }
      }
      write_channel->close();
      if (err_no == -1) {
        ++retry_num;
        is_write_failed = true;
        LOG(ERROR)
            << "MemorySparseTable save prefix failed after write, retry it! "
            << "path:" << channel_config.path << " , retry_num=" << retry_num;
      }
      if (is_write_failed) {
        _afs_client.remove(channel_config.path);
      }
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
        LOG(ERROR) << "MemorySparseTable save prefix failed reach max limit!";
        exit(-1);
      }
    } while (is_write_failed);
    feasign_size_all += feasign_size;
    for (auto it = shard.begin(); it != shard.end(); ++it) {
      _value_accesor->UpdateStatAfterSave(it.value().data(), save_param);
    }
    LOG(INFO) << "MemorySparseTable save prefix success, path: "
              << channel_config.path;
  }
  // TODO: the int32 return value may overflow for large feasign counts
  return 0;
}

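// Both savers emit one feature per line. A line looks roughly like
//   "18446744073709 0.5 0.01 ..."
// i.e. the feasign key, a space, then the accessor's serialized float fields
// (the exact fields depend on the configured accessor).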
int32_t MemorySparseTable::SaveLocalFS(const std::string& dirname,
                                       const std::string& param,
                                       const std::string& prefix) {
  int save_param =
      atoi(param.c_str());  // checkpoint:0  xbox delta:1  xbox base:2
  std::string table_path = TableDir(dirname);
  size_t file_start_idx = _avg_local_shard_num * _shard_idx;

  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  std::atomic<uint32_t> feasign_size_all{0};

  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _real_local_shard_num; ++i) {
    // Keep the counter loop-local: a single counter shared across the
    // parallel workers would race.
    int feasign_cnt = 0;
    auto& shard = _local_shards[i];
    std::string file_name =
        paddle::string::format_string("%s/part-%s-%03d-%05d",
                                      table_path.c_str(),
                                      prefix.c_str(),
                                      _shard_idx,
                                      file_start_idx + i);
    std::ofstream os;
    os.open(file_name);
    for (auto it = shard.begin(); it != shard.end(); ++it) {
      if (_value_accesor->Save(it.value().data(), save_param)) {
        std::string format_value =
            _value_accesor->ParseToString(it.value().data(), it.value().size());
        std::string out_line = paddle::string::format_string(
            "%lu %s\n", it.key(), format_value.c_str());
        // VLOG(2) << out_line.c_str();
        os.write(out_line.c_str(), sizeof(char) * out_line.size());
        ++feasign_cnt;
      }
    }
    os.close();
    LOG(INFO) << "MemorySparseTable save prefix success, path:" << file_name
              << " feasign_cnt: " << feasign_cnt;
  }
  return 0;
}

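// Per-node statistics: LocalSize counts all resident features, LocalMFSize
// counts only those whose value has grown to include the optional mf
// (extended embedding) block.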
int64_t MemorySparseTable::LocalSize() {
  int64_t local_size = 0;
  for (int i = 0; i < _real_local_shard_num; ++i) {
    local_size += _local_shards[i].size();
  }
  return local_size;
}

int64_t MemorySparseTable::LocalMFSize() {
  std::vector<int64_t> size_arr(_real_local_shard_num, 0);
  std::vector<std::future<int>> tasks(_real_local_shard_num);
  int64_t ret_size = 0;
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this, shard_id, &size_arr]() -> int {
              auto& local_shard = _local_shards[shard_id];
              for (auto it = local_shard.begin(); it != local_shard.end();
                   ++it) {
                if (_value_accesor->HasMF(it.value().size())) {
                  size_arr[shard_id] += 1;
                }
              }
              return 0;
            });
  }
  for (int i = 0; i < _real_local_shard_num; ++i) {
    tasks[i].wait();
  }
  for (auto x : size_arr) {
    ret_size += x;
  }
  return ret_size;
}

std::pair<int64_t, int64_t> MemorySparseTable::PrintTableStat() {
  int64_t feasign_size = LocalSize();
  int64_t mf_size = LocalMFSize();
  return {feasign_size, mf_size};
}

int32_t MemorySparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  if (context.use_ptr) {
    char** pull_values = context.pull_context.ptr_values;
    const uint64_t* keys = context.pull_context.keys;
    return PullSparsePtr(pull_values, keys, context.num);
  } else {
    float* pull_values = context.pull_context.values;
    const PullSparseValue& pull_value = context.pull_context.pull_value;
    return PullSparse(pull_values, pull_value);
  }
}

int32_t MemorySparseTable::Push(TableContext& context) {
  CHECK(context.value_type == Sparse);
  if (!context.use_ptr) {
    return PushSparse(
        context.push_context.keys, context.push_context.values, context.num);
  } else {
    return PushSparse(context.push_context.keys,
                      context.push_context.ptr_values,
                      context.num);
  }
}

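// Key routing used by both pull and push paths:
//   global shard = key % _sparse_table_shard_num
//   local slot   = global shard % _avg_local_shard_num
// Requests are bucketed per local slot and executed on that slot's
// single-thread pool, so shard state needs no locking.
//
// A minimal caller sketch (hypothetical variable names, assuming an accessor
// is already configured):
//   TableContext ctx;
//   ctx.value_type = Sparse;
//   ctx.use_ptr = false;
//   ctx.num = pull_value.numel_;
//   ctx.pull_context.pull_value = pull_value;  // feasign keys + numel
//   ctx.pull_context.values = out_buffer;      // numel * select_size floats
//   table->Pull(ctx);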
int32_t MemorySparseTable::PullSparse(float* pull_values,
                                      const PullSparseValue& pull_value) {
  CostTimer timer("pserver_sparse_select_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);

  const size_t value_size =
      _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_size =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
  size_t select_value_size =
      _value_accesor->GetAccessorInfo().select_size / sizeof(float);
  // std::atomic<uint32_t> missed_keys{0};

  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  size_t num = pull_value.numel_;
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (pull_value.feasigns_[i] % _sparse_table_shard_num) %
                   _avg_local_shard_num;
    task_keys[shard_id].push_back({pull_value.feasigns_[i], i});
  }
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this,
             shard_id,
             &task_keys,
             value_size,
             pull_values,
             mf_value_size,
             select_value_size]() -> int {
              auto& local_shard = _local_shards[shard_id];
              float data_buffer[value_size];  // NOLINT
              float* data_buffer_ptr = data_buffer;

              auto& keys = task_keys[shard_id];
              for (size_t i = 0; i < keys.size(); i++) {
                uint64_t key = keys[i].first;
                auto itr = local_shard.find(key);
                size_t data_size = value_size - mf_value_size;
                if (itr == local_shard.end()) {
                  // ++missed_keys;
                  if (FLAGS_pserver_create_value_when_push) {
                    memset(data_buffer, 0, sizeof(float) * data_size);
                  } else {
                    auto& feature_value = local_shard[key];
                    feature_value.resize(data_size);
                    float* data_ptr = feature_value.data();
                    _value_accesor->Create(&data_buffer_ptr, 1);
                    memcpy(
                        data_ptr, data_buffer_ptr, data_size * sizeof(float));
                  }
                } else {
                  data_size = itr.value().size();
                  memcpy(data_buffer_ptr,
                         itr.value().data(),
                         data_size * sizeof(float));
                }
                for (size_t mf_idx = data_size; mf_idx < value_size; ++mf_idx) {
                  data_buffer[mf_idx] = 0.0;
                }
                auto offset = keys[i].second;
                float* select_data = pull_values + select_value_size * offset;
                _value_accesor->Select(
                    &select_data, (const float**)&data_buffer_ptr, 1);
              }

              return 0;
            });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

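// PullSparsePtr differs from PullSparse in two ways: it always creates
// missing keys, and it returns raw pointers into shard storage
// (FixedFeatureValue*) instead of copying values out.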
int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
                                         const uint64_t* keys,
                                         size_t num) {
  CostTimer timer("pscore_sparse_select_all");
  size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_size =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);

  std::vector<std::future<int>> tasks(_real_local_shard_num);
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
    task_keys[shard_id].push_back({keys[i], i});
  }
  // std::atomic<uint32_t> missed_keys{0};
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this,
             shard_id,
             &task_keys,
             pull_values,
             value_size,
             mf_value_size]() -> int {
              auto& keys = task_keys[shard_id];
              auto& local_shard = _local_shards[shard_id];
              float data_buffer[value_size];  // NOLINT
              float* data_buffer_ptr = data_buffer;
              for (size_t i = 0; i < keys.size(); ++i) {
                uint64_t key = keys[i].first;
                auto itr = local_shard.find(key);
                size_t data_size = value_size - mf_value_size;
                FixedFeatureValue* ret = NULL;
                if (itr == local_shard.end()) {
                  // ++missed_keys;
                  auto& feature_value = local_shard[key];
                  feature_value.resize(data_size);
                  float* data_ptr = feature_value.data();
                  _value_accesor->Create(&data_buffer_ptr, 1);
                  memcpy(data_ptr, data_buffer_ptr, data_size * sizeof(float));
                  ret = &feature_value;
                } else {
                  ret = itr.value_ptr();
                }
                int pull_data_idx = keys[i].second;
                pull_values[pull_data_idx] = (char*)ret;  // NOLINT
              }
              return 0;
            });
  }
  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

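// Gradient push: each key carries update_value_col floats. A value that has
// not yet grown its mf block is updated in a stack buffer and written back,
// extending the stored value only when the accessor reports the mf block is
// now needed (NeedExtendMF).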
int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
                                      const float* values,
                                      size_t num) {
  CostTimer timer("pserver_sparse_update_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
    task_keys[shard_id].push_back({keys[i], i});
  }

  const size_t value_col =
      _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_col =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
  size_t update_value_col =
      _value_accesor->GetAccessorInfo().update_size / sizeof(float);

  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
        [this,
         shard_id,
         value_col,
         mf_value_col,
         update_value_col,
         values,
         &task_keys]() -> int {
          auto& keys = task_keys[shard_id];
          auto& local_shard = _local_shards[shard_id];
          float data_buffer[value_col];  // NOLINT
          float* data_buffer_ptr = data_buffer;
          for (size_t i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data =
                values + push_data_idx * update_value_col;
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              if (FLAGS_pserver_enable_create_feasign_randomly &&
                  !_value_accesor->CreateValue(1, update_data)) {
                continue;
              }
              auto value_size = value_col - mf_value_col;
              auto& feature_value = local_shard[key];
              feature_value.resize(value_size);
              _value_accesor->Create(&data_buffer_ptr, 1);
              memcpy(feature_value.data(),
                     data_buffer_ptr,
                     value_size * sizeof(float));
              itr = local_shard.find(key);
            }

            auto& feature_value = itr.value();
            float* value_data = feature_value.data();
            size_t value_size = feature_value.size();

            if (value_size == value_col) {
              // already extended to the max size: update in place
              _value_accesor->Update(&value_data, &update_data, 1);
            } else {
              // copy into the buffer, update there, then write back; mf
              // fields that are not yet needed are dropped on write-back
              memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
              _value_accesor->Update(&data_buffer_ptr, &update_data, 1);

              if (_value_accesor->NeedExtendMF(data_buffer)) {
                feature_value.resize(value_col);
                value_data = feature_value.data();
                _value_accesor->Create(&value_data, 1);
              }
              memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
            }
          }
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
                                      const float** values,
                                      size_t num) {
  std::vector<std::future<int>> tasks(_real_local_shard_num);
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
    task_keys[shard_id].push_back({keys[i], i});
  }

  size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_col =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
  size_t update_value_col =
      _value_accesor->GetAccessorInfo().update_size / sizeof(float);

  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
        [this,
         shard_id,
         value_col,
         mf_value_col,
         update_value_col,
         values,
         &task_keys]() -> int {
          auto& keys = task_keys[shard_id];
          auto& local_shard = _local_shards[shard_id];
          float data_buffer[value_col];  // NOLINT
          float* data_buffer_ptr = data_buffer;
          for (size_t i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data = values[push_data_idx];
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              if (FLAGS_pserver_enable_create_feasign_randomly &&
                  !_value_accesor->CreateValue(1, update_data)) {
                continue;
              }
              auto value_size = value_col - mf_value_col;
              auto& feature_value = local_shard[key];
              feature_value.resize(value_size);
              _value_accesor->Create(&data_buffer_ptr, 1);
              memcpy(feature_value.data(),
                     data_buffer_ptr,
                     value_size * sizeof(float));
              itr = local_shard.find(key);
            }
            auto& feature_value = itr.value();
            float* value_data = feature_value.data();
            size_t value_size = feature_value.size();
            if (value_size == value_col) {
              // already extended to the max size: update in place
              _value_accesor->Update(&value_data, &update_data, 1);
            } else {
              // copy into the buffer, update there, then write back; mf
              // fields that are not yet needed are dropped on write-back
              memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
              _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
              if (_value_accesor->NeedExtendMF(data_buffer)) {
                feature_value.resize(value_col);
                value_data = feature_value.data();
                _value_accesor->Create(&value_data, 1);
              }
              memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
            }
          }
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

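// Flush is a no-op for the in-memory table. Shrink walks every local shard
// and erases features the accessor marks as evictable (e.g. stale or
// low-frequency entries, depending on the accessor's shrink rule).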
int32_t MemorySparseTable::Flush() { return 0; }

int32_t MemorySparseTable::Shrink(const std::string& param) {
  VLOG(0) << "MemorySparseTable::Shrink";
  // TODO(zhaocaibei123): implement with multi-thread
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    // Shrink
    auto& shard = _local_shards[shard_id];
    for (auto it = shard.begin(); it != shard.end();) {
      if (_value_accesor->Shrink(it.value().data())) {
        it = shard.erase(it);
      } else {
        ++it;
      }
    }
  }
  return 0;
}

void MemorySparseTable::Clear() { VLOG(0) << "clear coming soon"; }

}  // namespace distributed
}  // namespace paddle