memory_sparse_table.cc 27.3 KB
Newer Older
Z
zhaocaibei123 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"

17
#include <omp.h>
Z
zhaocaibei123 已提交
18

19
#include <sstream>
Z
zhaocaibei123 已提交
20 21 22

#include "boost/lexical_cast.hpp"
#include "glog/logging.h"
23 24
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/framework/io/fs.h"
Z
zhaocaibei123 已提交
25 26
#include "paddle/fluid/platform/enforce.h"

Z
zhaocaibei123 已提交
27 28 29 30 31 32 33 34
// Runtime flags for the in-memory sparse table.

// Not referenced in the visible portion of this file.
DEFINE_bool(pserver_print_missed_key_num_every_push, false,
            "pserver_print_missed_key_num_every_push");
// Read by PullSparse(): when true, a pull for a missing key is answered with
// zeros and the key is not inserted; when false, the key is created on pull.
DEFINE_bool(pserver_create_value_when_push, true,
            "pserver create value when push");
// Read by PushSparse(): when true, a missing key is only inserted if the
// accessor's CreateValue() accepts it; otherwise the push for that key is
// skipped.
DEFINE_bool(pserver_enable_create_feasign_randomly, false,
            "pserver_enable_create_feasign_randomly");
// Retry limit shared by the load and save paths; once a shard exceeds it the
// process exits.
DEFINE_int32(pserver_table_save_max_retry, 3, "pserver_table_save_max_retry");

Z
zhaocaibei123 已提交
35 36 37
namespace paddle {
namespace distributed {

Z
zhaocaibei123 已提交
38
int32_t MemorySparseTable::Initialize() {
39
  _shards_task_pool.resize(_task_pool_size);
40
  for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
41
    _shards_task_pool[i].reset(new ::ThreadPool(1));
Z
zhaocaibei123 已提交
42
  }
43 44 45
  auto& profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_sparse_update_all");
  profiler.register_profiler("pserver_sparse_select_all");
Z
zhaocaibei123 已提交
46
  InitializeValue();
Z
zhaocaibei123 已提交
47 48 49 50
  VLOG(0) << "initalize MemorySparseTable succ";
  return 0;
}

Z
zhaocaibei123 已提交
51
int32_t MemorySparseTable::InitializeValue() {
52 53
  _sparse_table_shard_num = static_cast<int>(_config.shard_num());
  _avg_local_shard_num =
54
      sparse_local_shard_num(_sparse_table_shard_num, _shard_num);
55 56 57 58 59 60
  _real_local_shard_num = _avg_local_shard_num;
  if (_real_local_shard_num * (_shard_idx + 1) > _sparse_table_shard_num) {
    _real_local_shard_num =
        _sparse_table_shard_num - _real_local_shard_num * _shard_idx;
    _real_local_shard_num =
        _real_local_shard_num < 0 ? 0 : _real_local_shard_num;
Z
zhaocaibei123 已提交
61
  }
62 63 64
  VLOG(1) << "memory sparse table _avg_local_shard_num: "
          << _avg_local_shard_num
          << " _real_local_shard_num: " << _real_local_shard_num;
Z
zhaocaibei123 已提交
65

66
  _local_shards.reset(new shard_type[_real_local_shard_num]);
Z
zhaocaibei123 已提交
67 68 69 70

  return 0;
}

Z
zhaocaibei123 已提交
71
int32_t MemorySparseTable::Load(const std::string& path,
Z
zhaocaibei123 已提交
72
                                const std::string& param) {
Z
zhaocaibei123 已提交
73
  std::string table_path = TableDir(path);
Z
zhaocaibei123 已提交
74 75 76 77
  auto file_list = _afs_client.list(table_path);

  std::sort(file_list.begin(), file_list.end());
  for (auto file : file_list) {
Z
zhaocaibei123 已提交
78
    VLOG(1) << "MemorySparseTable::Load() file list: " << file;
Z
zhaocaibei123 已提交
79 80 81
  }

  int load_param = atoi(param.c_str());
82
  size_t expect_shard_num = _sparse_table_shard_num;
Z
zhaocaibei123 已提交
83 84 85 86 87 88 89 90 91 92
  if (file_list.size() != expect_shard_num) {
    LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                 << " not equal to expect_shard_num:" << expect_shard_num;
    return -1;
  }
  if (file_list.size() == 0) {
    LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
    return -1;
  }

93
  size_t file_start_idx = _shard_idx * _avg_local_shard_num;
Z
zhaocaibei123 已提交
94

95
  size_t feature_value_size =
96
      _value_accesor->GetAccessorInfo().size / sizeof(float);
97 98 99 100

  int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
101
  for (int i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
102 103 104 105
    FsChannelConfig channel_config;
    channel_config.path = file_list[file_start_idx + i];
    VLOG(1) << "MemorySparseTable::load begin load " << channel_config.path
            << " into local shard " << i;
106
    channel_config.converter = _value_accesor->Converter(load_param).converter;
Z
zhaocaibei123 已提交
107
    channel_config.deconverter =
108
        _value_accesor->Converter(load_param).deconverter;
Z
zhaocaibei123 已提交
109 110 111 112 113 114 115 116 117 118

    bool is_read_failed = false;
    int retry_num = 0;
    int err_no = 0;
    do {
      is_read_failed = false;
      err_no = 0;
      std::string line_data;
      auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
      char* end = NULL;
119
      auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
120 121 122 123
      try {
        while (read_channel->read_line(line_data) == 0 &&
               line_data.size() > 1) {
          uint64_t key = std::strtoul(line_data.data(), &end, 10);
124 125
          auto& value = shard[key];
          value.resize(feature_value_size);
126
          int parse_size = _value_accesor->ParseFromString(++end, value.data());
127
          value.resize(parse_size);
Z
zhaocaibei123 已提交
128 129 130 131

          // for debug
          for (int ii = 0; ii < parse_size; ++ii) {
            VLOG(2) << "MemorySparseTable::load key: " << key << " value " << ii
132
                    << ": " << value.data()[ii] << " local_shard: " << i;
Z
zhaocaibei123 已提交
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
          }
        }
        read_channel->close();
        if (err_no == -1) {
          ++retry_num;
          is_read_failed = true;
          LOG(ERROR)
              << "MemorySparseTable load failed after read, retry it! path:"
              << channel_config.path << " , retry_num=" << retry_num;
        }
      } catch (...) {
        ++retry_num;
        is_read_failed = true;
        LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
                   << channel_config.path << " , retry_num=" << retry_num;
      }
Z
zhaocaibei123 已提交
149
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
Z
zhaocaibei123 已提交
150 151 152 153 154 155 156
        LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
        exit(-1);
      }
    } while (is_read_failed);
  }
  LOG(INFO) << "MemorySparseTable load success, path from "
            << file_list[file_start_idx] << " to "
157
            << file_list[file_start_idx + _real_local_shard_num - 1];
Z
zhaocaibei123 已提交
158 159 160
  return 0;
}

Z
zhaocaibei123 已提交
161 162 163
int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
                                       const std::string& param) {
  std::string table_path = TableDir(path);
Z
zhaocaibei123 已提交
164 165 166
  auto file_list = paddle::framework::localfs_list(table_path);

  int load_param = atoi(param.c_str());
167
  size_t expect_shard_num = _sparse_table_shard_num;
Z
zhaocaibei123 已提交
168 169 170 171 172 173 174 175 176 177
  if (file_list.size() != expect_shard_num) {
    LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                 << " not equal to expect_shard_num:" << expect_shard_num;
    return -1;
  }
  if (file_list.size() == 0) {
    LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
    return -1;
  }

178
  size_t file_start_idx = _shard_idx * _avg_local_shard_num;
Z
zhaocaibei123 已提交
179

180
  size_t feature_value_size =
181
      _value_accesor->GetAccessorInfo().size / sizeof(float);
Z
zhaocaibei123 已提交
182

183 184 185
  int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
186
  for (int i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
187 188 189 190 191 192 193 194 195
    bool is_read_failed = false;
    int retry_num = 0;
    int err_no = 0;
    do {
      is_read_failed = false;
      err_no = 0;
      std::string line_data;
      std::ifstream file(file_list[file_start_idx + i]);
      char* end = NULL;
196
      auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
197 198 199
      try {
        while (std::getline(file, line_data) && line_data.size() > 1) {
          uint64_t key = std::strtoul(line_data.data(), &end, 10);
200 201
          auto& value = shard[key];
          value.resize(feature_value_size);
202
          int parse_size = _value_accesor->ParseFromString(++end, value.data());
203
          value.resize(parse_size);
Z
zhaocaibei123 已提交
204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
        }
        file.close();
        if (err_no == -1) {
          ++retry_num;
          is_read_failed = true;
          LOG(ERROR)
              << "MemorySparseTable load failed after read, retry it! path:"
              << file_list[file_start_idx + i] << " , retry_num=" << retry_num;
        }
      } catch (...) {
        ++retry_num;
        is_read_failed = true;
        LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
                   << file_list[file_start_idx + i]
                   << " , retry_num=" << retry_num;
      }
Z
zhaocaibei123 已提交
220
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
Z
zhaocaibei123 已提交
221 222 223 224 225 226 227
        LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
        exit(-1);
      }
    } while (is_read_failed);
  }
  LOG(INFO) << "MemorySparseTable load success, path from "
            << file_list[file_start_idx] << " to "
228
            << file_list[file_start_idx + _real_local_shard_num - 1];
Z
zhaocaibei123 已提交
229 230 231
  return 0;
}

Z
zhaocaibei123 已提交
232
int32_t MemorySparseTable::Save(const std::string& dirname,
Z
zhaocaibei123 已提交
233 234 235 236
                                const std::string& param) {
  VLOG(0) << "MemorySparseTable::save dirname: " << dirname;
  int save_param =
      atoi(param.c_str());  // checkpoint:0  xbox delta:1  xbox base:2
Z
zhaocaibei123 已提交
237
  std::string table_path = TableDir(dirname);
Z
zhaocaibei123 已提交
238 239 240 241
  _afs_client.remove(paddle::string::format_string(
      "%s/part-%03d-*", table_path.c_str(), _shard_idx));
  std::atomic<uint32_t> feasign_size_all{0};

242
  size_t file_start_idx = _avg_local_shard_num * _shard_idx;
Z
zhaocaibei123 已提交
243

244 245 246
  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
247
  for (int i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
248 249 250 251 252 253 254 255 256 257
    FsChannelConfig channel_config;
    if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) {
      channel_config.path = paddle::string::format_string(
          "%s/part-%03d-%05d.gz", table_path.c_str(), _shard_idx,
          file_start_idx + i);
    } else {
      channel_config.path =
          paddle::string::format_string("%s/part-%03d-%05d", table_path.c_str(),
                                        _shard_idx, file_start_idx + i);
    }
258
    channel_config.converter = _value_accesor->Converter(save_param).converter;
Z
zhaocaibei123 已提交
259
    channel_config.deconverter =
260
        _value_accesor->Converter(save_param).deconverter;
Z
zhaocaibei123 已提交
261 262 263 264
    bool is_write_failed = false;
    int feasign_size = 0;
    int retry_num = 0;
    int err_no = 0;
265
    auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
266 267 268 269 270 271
    do {
      err_no = 0;
      feasign_size = 0;
      is_write_failed = false;
      auto write_channel =
          _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
272
      for (auto it = shard.begin(); it != shard.end(); ++it) {
273 274
        if (_value_accesor->Save(it.value().data(), save_param)) {
          std::string format_value = _value_accesor->ParseToString(
275
              it.value().data(), it.value().size());
276 277
          if (0 != write_channel->write_line(paddle::string::format_string(
                       "%lu %s", it.key(), format_value.c_str()))) {
278 279 280 281 282 283
            ++retry_num;
            is_write_failed = true;
            LOG(ERROR)
                << "MemorySparseTable save prefix failed, retry it! path:"
                << channel_config.path << " , retry_num=" << retry_num;
            break;
Z
zhaocaibei123 已提交
284
          }
285
          ++feasign_size;
Z
zhaocaibei123 已提交
286 287 288 289 290 291 292 293 294 295 296 297 298
        }
      }
      write_channel->close();
      if (err_no == -1) {
        ++retry_num;
        is_write_failed = true;
        LOG(ERROR)
            << "MemorySparseTable save prefix failed after write, retry it! "
            << "path:" << channel_config.path << " , retry_num=" << retry_num;
      }
      if (is_write_failed) {
        _afs_client.remove(channel_config.path);
      }
Z
zhaocaibei123 已提交
299
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
Z
zhaocaibei123 已提交
300 301 302 303 304
        LOG(ERROR) << "MemorySparseTable save prefix failed reach max limit!";
        exit(-1);
      }
    } while (is_write_failed);
    feasign_size_all += feasign_size;
305
    for (auto it = shard.begin(); it != shard.end(); ++it) {
306
      _value_accesor->UpdateStatAfterSave(it.value().data(), save_param);
Z
zhaocaibei123 已提交
307 308 309 310 311 312 313 314
    }
    LOG(INFO) << "MemorySparseTable save prefix success, path: "
              << channel_config.path;
  }
  // int32 may overflow need to change return value
  return 0;
}

Z
zhaocaibei123 已提交
315 316 317
int32_t MemorySparseTable::SaveLocalFS(const std::string& dirname,
                                       const std::string& param,
                                       const std::string& prefix) {
Z
zhaocaibei123 已提交
318 319
  int save_param =
      atoi(param.c_str());  // checkpoint:0  xbox delta:1  xbox base:2
Z
zhaocaibei123 已提交
320
  std::string table_path = TableDir(dirname);
Z
zhaocaibei123 已提交
321
  int feasign_cnt = 0;
322 323 324 325 326 327 328
  size_t file_start_idx = _avg_local_shard_num * _shard_idx;

  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  std::atomic<uint32_t> feasign_size_all{0};

  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
329
  for (int i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
330
    feasign_cnt = 0;
331
    auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
332 333 334 335 336
    std::string file_name = paddle::string::format_string(
        "%s/part-%s-%03d-%05d", table_path.c_str(), prefix.c_str(), _shard_idx,
        file_start_idx + i);
    std::ofstream os;
    os.open(file_name);
337
    for (auto it = shard.begin(); it != shard.end(); ++it) {
338 339 340
      if (_value_accesor->Save(it.value().data(), save_param)) {
        std::string format_value =
            _value_accesor->ParseToString(it.value().data(), it.value().size());
341 342 343 344 345
        std::string out_line = paddle::string::format_string(
            "%lu %s\n", it.key(), format_value.c_str());
        // VLOG(2) << out_line.c_str();
        os.write(out_line.c_str(), sizeof(char) * out_line.size());
        ++feasign_cnt;
Z
zhaocaibei123 已提交
346 347 348 349 350 351 352 353 354
      }
    }
    os.close();
    LOG(INFO) << "MemorySparseTable save prefix success, path:" << file_name
              << "feasign_cnt: " << feasign_cnt;
  }
  return 0;
}

Z
zhaocaibei123 已提交
355
int64_t MemorySparseTable::LocalSize() {
356
  int64_t local_size = 0;
357
  for (int i = 0; i < _real_local_shard_num; ++i) {
358 359 360 361
    local_size += _local_shards[i].size();
  }
  return local_size;
}
Z
zhaocaibei123 已提交
362

Z
zhaocaibei123 已提交
363
int64_t MemorySparseTable::LocalMFSize() {
364 365 366
  std::vector<int64_t> size_arr(_real_local_shard_num, 0);
  std::vector<std::future<int>> tasks(_real_local_shard_num);
  int64_t ret_size = 0;
367
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
368 369 370 371 372 373
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this, shard_id, &size_arr]() -> int {
              auto& local_shard = _local_shards[shard_id];
              for (auto it = local_shard.begin(); it != local_shard.end();
                   ++it) {
374
                if (_value_accesor->HasMF(it.value().size())) {
375 376 377 378 379 380
                  size_arr[shard_id] += 1;
                }
              }
              return 0;
            });
  }
381
  for (int i = 0; i < _real_local_shard_num; ++i) {
382 383 384 385
    tasks[i].wait();
  }
  for (auto x : size_arr) {
    ret_size += x;
Z
zhaocaibei123 已提交
386
  }
387 388
  return ret_size;
}
Z
zhaocaibei123 已提交
389

Z
zhaocaibei123 已提交
390 391 392
// Returns {LocalSize(), LocalMFSize()}: the local entry count followed by the
// count of entries with an mf part.
std::pair<int64_t, int64_t> MemorySparseTable::PrintTableStat() {
  const int64_t entry_count = LocalSize();
  const int64_t mf_entry_count = LocalMFSize();
  return std::pair<int64_t, int64_t>(entry_count, mf_entry_count);
}

Y
yaoxuefeng 已提交
396 397 398 399 400
// Dispatches a pull request described by `context`. The context must be of
// Sparse value type. With `use_ptr` set the request goes to PullSparsePtr(),
// otherwise to PullSparse(); the callee's result is returned unchanged.
int32_t MemorySparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);

  if (!context.use_ptr) {
    float* out_values = context.pull_context.values;
    const PullSparseValue& request = context.pull_context.pull_value;
    return PullSparse(out_values, request);
  }

  char** out_ptrs = context.pull_context.ptr_values;
  const uint64_t* key_list = context.pull_context.keys;
  return PullSparsePtr(out_ptrs, key_list, context.num);
}

// Dispatches a push request described by `context`. The context must be of
// Sparse value type. With `use_ptr` set the per-key pointer overload of
// PushSparse() is used, otherwise the flat-array overload; the callee's result
// is returned unchanged.
int32_t MemorySparseTable::Push(TableContext& context) {
  CHECK(context.value_type == Sparse);

  if (context.use_ptr) {
    return PushSparse(context.push_context.keys,
                      context.push_context.ptr_values, context.num);
  }
  return PushSparse(context.push_context.keys, context.push_context.values,
                    context.num);
}

Z
zhaocaibei123 已提交
420 421
int32_t MemorySparseTable::PullSparse(float* pull_values,
                                      const PullSparseValue& pull_value) {
422 423
  CostTimer timer("pserver_sparse_select_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);
Z
zhaocaibei123 已提交
424

425 426 427 428
  const size_t value_size =
      _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_size =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
429
  size_t select_value_size =
430
      _value_accesor->GetAccessorInfo().select_size / sizeof(float);
Z
zhaocaibei123 已提交
431 432 433
  // std::atomic<uint32_t> missed_keys{0};

  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
434
      _real_local_shard_num);
Z
zhaocaibei123 已提交
435 436
  size_t num = pull_value.numel_;
  for (size_t i = 0; i < num; ++i) {
437 438
    int shard_id = (pull_value.feasigns_[i] % _sparse_table_shard_num) %
                   _avg_local_shard_num;
Z
zhaocaibei123 已提交
439 440
    task_keys[shard_id].push_back({pull_value.feasigns_[i], i});
  }
441
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
Z
zhaocaibei123 已提交
442
    tasks[shard_id] =
443
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
Z
zhaocaibei123 已提交
444 445
            [this, shard_id, &task_keys, value_size, pull_values, mf_value_size,
             select_value_size]() -> int {
446
              auto& local_shard = _local_shards[shard_id];
Z
zhaocaibei123 已提交
447 448 449 450 451 452
              float data_buffer[value_size];  // NOLINT
              float* data_buffer_ptr = data_buffer;

              auto& keys = task_keys[shard_id];
              for (size_t i = 0; i < keys.size(); i++) {
                uint64_t key = keys[i].first;
453
                auto itr = local_shard.find(key);
Z
zhaocaibei123 已提交
454
                size_t data_size = value_size - mf_value_size;
455
                if (itr == local_shard.end()) {
Z
zhaocaibei123 已提交
456
                  // ++missed_keys;
457
                  if (FLAGS_pserver_create_value_when_push) {
Z
zhaocaibei123 已提交
458 459
                    memset(data_buffer, 0, sizeof(float) * data_size);
                  } else {
460 461 462
                    auto& feature_value = local_shard[key];
                    feature_value.resize(data_size);
                    float* data_ptr = feature_value.data();
463
                    _value_accesor->Create(&data_buffer_ptr, 1);
Z
zhaocaibei123 已提交
464 465 466 467
                    memcpy(data_ptr, data_buffer_ptr,
                           data_size * sizeof(float));
                  }
                } else {
468 469
                  data_size = itr.value().size();
                  memcpy(data_buffer_ptr, itr.value().data(),
Z
zhaocaibei123 已提交
470 471
                         data_size * sizeof(float));
                }
472
                for (size_t mf_idx = data_size; mf_idx < value_size; ++mf_idx) {
Z
zhaocaibei123 已提交
473 474 475 476
                  data_buffer[mf_idx] = 0.0;
                }
                auto offset = keys[i].second;
                float* select_data = pull_values + select_value_size * offset;
477
                _value_accesor->Select(&select_data,
Z
zhaocaibei123 已提交
478 479 480 481 482 483 484 485 486 487 488 489 490
                                       (const float**)&data_buffer_ptr, 1);
              }

              return 0;
            });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

Z
zhaocaibei123 已提交
491 492
int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
                                         const uint64_t* keys, size_t num) {
493
  CostTimer timer("pscore_sparse_select_all");
494 495 496
  size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_size =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
497 498 499 500 501 502 503 504 505

  std::vector<std::future<int>> tasks(_real_local_shard_num);
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
    task_keys[shard_id].push_back({keys[i], i});
  }
  // std::atomic<uint32_t> missed_keys{0};
506
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
507 508 509 510 511 512 513 514
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this, shard_id, &task_keys, pull_values, value_size,
             mf_value_size]() -> int {
              auto& keys = task_keys[shard_id];
              auto& local_shard = _local_shards[shard_id];
              float data_buffer[value_size];
              float* data_buffer_ptr = data_buffer;
515
              for (size_t i = 0; i < keys.size(); ++i) {
516 517 518 519 520 521 522 523 524
                uint64_t key = keys[i].first;
                auto itr = local_shard.find(key);
                size_t data_size = value_size - mf_value_size;
                FixedFeatureValue* ret = NULL;
                if (itr == local_shard.end()) {
                  // ++missed_keys;
                  auto& feature_value = local_shard[key];
                  feature_value.resize(data_size);
                  float* data_ptr = feature_value.data();
525
                  _value_accesor->Create(&data_buffer_ptr, 1);
526 527 528 529 530 531 532 533 534 535 536 537 538 539
                  memcpy(data_ptr, data_buffer_ptr, data_size * sizeof(float));
                  ret = &feature_value;
                } else {
                  ret = itr.value_ptr();
                }
                int pull_data_idx = keys[i].second;
                pull_values[pull_data_idx] = (char*)ret;
              }
              return 0;
            });
  }
  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
Z
zhaocaibei123 已提交
540 541 542
  return 0;
}

Z
zhaocaibei123 已提交
543 544
int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values,
                                      size_t num) {
545 546
  CostTimer timer("pserver_sparse_update_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);
Z
zhaocaibei123 已提交
547
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
548
      _real_local_shard_num);
Z
zhaocaibei123 已提交
549
  for (size_t i = 0; i < num; ++i) {
550
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
Z
zhaocaibei123 已提交
551 552 553
    task_keys[shard_id].push_back({keys[i], i});
  }

554 555 556 557
  const size_t value_col =
      _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_col =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
558
  size_t update_value_col =
559
      _value_accesor->GetAccessorInfo().update_size / sizeof(float);
Z
zhaocaibei123 已提交
560

561
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
562
    tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
Z
zhaocaibei123 已提交
563 564 565
        [this, shard_id, value_col, mf_value_col, update_value_col, values,
         &task_keys]() -> int {
          auto& keys = task_keys[shard_id];
566
          auto& local_shard = _local_shards[shard_id];
Z
zhaocaibei123 已提交
567 568
          float data_buffer[value_col];  // NOLINT
          float* data_buffer_ptr = data_buffer;
569
          for (size_t i = 0; i < keys.size(); ++i) {
Z
zhaocaibei123 已提交
570 571 572 573
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data =
                values + push_data_idx * update_value_col;
574 575 576
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              if (FLAGS_pserver_enable_create_feasign_randomly &&
577
                  !_value_accesor->CreateValue(1, update_data)) {
Z
zhaocaibei123 已提交
578 579 580
                continue;
              }
              auto value_size = value_col - mf_value_col;
581 582
              auto& feature_value = local_shard[key];
              feature_value.resize(value_size);
583
              _value_accesor->Create(&data_buffer_ptr, 1);
584
              memcpy(feature_value.data(), data_buffer_ptr,
Z
zhaocaibei123 已提交
585
                     value_size * sizeof(float));
586
              itr = local_shard.find(key);
Z
zhaocaibei123 已提交
587 588
            }

589 590 591
            auto& feature_value = itr.value();
            float* value_data = feature_value.data();
            size_t value_size = feature_value.size();
Z
zhaocaibei123 已提交
592 593

            if (value_size == value_col) {  // 已拓展到最大size, 则就地update
594
              _value_accesor->Update(&value_data, &update_data, 1);
Z
zhaocaibei123 已提交
595 596 597
            } else {
              // 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了
              memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
598
              _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
Z
zhaocaibei123 已提交
599

600
              if (_value_accesor->NeedExtendMF(data_buffer)) {
601 602
                feature_value.resize(value_col);
                value_data = feature_value.data();
603
                _value_accesor->Create(&value_data, 1);
Z
zhaocaibei123 已提交
604 605 606 607 608 609 610 611 612 613 614 615 616 617
              }
              memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
            }
          }
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

Z
zhaocaibei123 已提交
618 619
int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
                                      const float** values, size_t num) {
620
  std::vector<std::future<int>> tasks(_real_local_shard_num);
Z
zhaocaibei123 已提交
621
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
622
      _real_local_shard_num);
Z
zhaocaibei123 已提交
623
  for (size_t i = 0; i < num; ++i) {
624
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
Z
zhaocaibei123 已提交
625 626 627
    task_keys[shard_id].push_back({keys[i], i});
  }

628 629 630
  size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_col =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
631
  size_t update_value_col =
632
      _value_accesor->GetAccessorInfo().update_size / sizeof(float);
Z
zhaocaibei123 已提交
633

634 635
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
Z
zhaocaibei123 已提交
636 637 638
        [this, shard_id, value_col, mf_value_col, update_value_col, values,
         &task_keys]() -> int {
          auto& keys = task_keys[shard_id];
639
          auto& local_shard = _local_shards[shard_id];
Z
zhaocaibei123 已提交
640 641
          float data_buffer[value_col];  // NOLINT
          float* data_buffer_ptr = data_buffer;
642
          for (size_t i = 0; i < keys.size(); ++i) {
Z
zhaocaibei123 已提交
643 644 645
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data = values[push_data_idx];
646 647 648
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              if (FLAGS_pserver_enable_create_feasign_randomly &&
649
                  !_value_accesor->CreateValue(1, update_data)) {
Z
zhaocaibei123 已提交
650 651 652
                continue;
              }
              auto value_size = value_col - mf_value_col;
653 654
              auto& feature_value = local_shard[key];
              feature_value.resize(value_size);
655
              _value_accesor->Create(&data_buffer_ptr, 1);
656
              memcpy(feature_value.data(), data_buffer_ptr,
Z
zhaocaibei123 已提交
657
                     value_size * sizeof(float));
658
              itr = local_shard.find(key);
Z
zhaocaibei123 已提交
659
            }
660 661 662
            auto& feature_value = itr.value();
            float* value_data = feature_value.data();
            size_t value_size = feature_value.size();
Z
zhaocaibei123 已提交
663
            if (value_size == value_col) {  // 已拓展到最大size, 则就地update
664
              _value_accesor->Update(&value_data, &update_data, 1);
Z
zhaocaibei123 已提交
665 666 667
            } else {
              // 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了
              memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
668 669
              _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
              if (_value_accesor->NeedExtendMF(data_buffer)) {
670 671
                feature_value.resize(value_col);
                value_data = feature_value.data();
672
                _value_accesor->Create(&value_data, 1);
Z
zhaocaibei123 已提交
673 674 675 676 677 678 679 680 681 682 683 684 685 686
              }
              memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
            }
          }
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

Z
zhaocaibei123 已提交
687
// Does no work in this implementation; always returns 0.
int32_t MemorySparseTable::Flush() {
  return 0;
}
Z
zhaocaibei123 已提交
688

Z
zhaocaibei123 已提交
689 690
int32_t MemorySparseTable::Shrink(const std::string& param) {
  VLOG(0) << "MemorySparseTable::Shrink";
Z
zhaocaibei123 已提交
691
  // TODO(zhaocaibei123): implement with multi-thread
692
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
Z
zhaocaibei123 已提交
693
    // Shrink
694 695
    auto& shard = _local_shards[shard_id];
    for (auto it = shard.begin(); it != shard.end();) {
696
      if (_value_accesor->Shrink(it.value().data())) {
697 698 699
        it = shard.erase(it);
      } else {
        ++it;
Z
zhaocaibei123 已提交
700 701 702 703 704 705
      }
    }
  }
  return 0;
}

Z
zhaocaibei123 已提交
706
// Not implemented yet: only emits a log line and leaves the table untouched.
void MemorySparseTable::Clear() {
  VLOG(0) << "clear coming soon";
}
Z
zhaocaibei123 已提交
707 708 709

}  // namespace distributed
}  // namespace paddle