memory_sparse_table.cc 27.3 KB
Newer Older
Z
zhaocaibei123 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"

#include <omp.h>

#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <future>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "boost/lexical_cast.hpp"
#include "glog/logging.h"
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/enforce.h"

// Diagnostic toggle for missed-key reporting on push; not read anywhere in
// this file (the missed-key counters below are commented out).
DEFINE_bool(pserver_print_missed_key_num_every_push,
            false,
            "pserver_print_missed_key_num_every_push");
// When true, PullSparse() does not insert missing keys (it returns a
// zero-filled value); when false, PullSparse() creates and inserts them.
DEFINE_bool(pserver_create_value_when_push,
            true,
            "pserver create value when push");
// When true, PushSparse() asks the accessor's CreateValue() whether an unseen
// key should be admitted before creating it.
DEFINE_bool(pserver_enable_create_feasign_randomly,
            false,
            "pserver_enable_create_feasign_randomly");
// Retry budget used by Load()/LoadLocalFS()/Save() before exit(-1).
DEFINE_int32(pserver_table_save_max_retry,
             3,
             "pserver_table_save_max_retry");

namespace paddle {
namespace distributed {

Z
zhaocaibei123 已提交
38
int32_t MemorySparseTable::Initialize() {
39
  _shards_task_pool.resize(_task_pool_size);
40
  for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
41
    _shards_task_pool[i].reset(new ::ThreadPool(1));
Z
zhaocaibei123 已提交
42
  }
43 44 45
  auto& profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_sparse_update_all");
  profiler.register_profiler("pserver_sparse_select_all");
Z
zhaocaibei123 已提交
46
  InitializeValue();
Z
zhaocaibei123 已提交
47 48 49 50
  VLOG(0) << "initalize MemorySparseTable succ";
  return 0;
}

Z
zhaocaibei123 已提交
51
// Computes how many of the _config.shard_num() global shards this server owns
// and allocates them.
//
// Every server nominally owns _avg_local_shard_num consecutive shards; a
// server whose nominal range runs past the global shard count keeps only the
// remainder (clamped at 0). Always returns 0.
int32_t MemorySparseTable::InitializeValue() {
  _sparse_table_shard_num = static_cast<int>(_config.shard_num());
  _avg_local_shard_num =
      sparse_local_shard_num(_sparse_table_shard_num, _shard_num);

  const bool range_overflows =
      static_cast<int>(_avg_local_shard_num * (_shard_idx + 1)) >
      _sparse_table_shard_num;
  if (!range_overflows) {
    _real_local_shard_num = _avg_local_shard_num;
  } else {
    _real_local_shard_num =
        _sparse_table_shard_num - _avg_local_shard_num * _shard_idx;
    if (_real_local_shard_num < 0) {
      _real_local_shard_num = 0;
    }
  }

  VLOG(1) << "memory sparse table _avg_local_shard_num: "
          << _avg_local_shard_num
          << " _real_local_shard_num: " << _real_local_shard_num;

  _local_shards.reset(new shard_type[_real_local_shard_num]);
  return 0;
}

int32_t MemorySparseTable::Load(const std::string& path,
Z
zhaocaibei123 已提交
73
                                const std::string& param) {
Z
zhaocaibei123 已提交
74
  std::string table_path = TableDir(path);
Z
zhaocaibei123 已提交
75 76 77 78
  auto file_list = _afs_client.list(table_path);

  std::sort(file_list.begin(), file_list.end());
  for (auto file : file_list) {
Z
zhaocaibei123 已提交
79
    VLOG(1) << "MemorySparseTable::Load() file list: " << file;
Z
zhaocaibei123 已提交
80 81 82
  }

  int load_param = atoi(param.c_str());
83
  size_t expect_shard_num = _sparse_table_shard_num;
Z
zhaocaibei123 已提交
84 85 86 87 88 89 90 91 92 93
  if (file_list.size() != expect_shard_num) {
    LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                 << " not equal to expect_shard_num:" << expect_shard_num;
    return -1;
  }
  if (file_list.size() == 0) {
    LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
    return -1;
  }

94
  size_t file_start_idx = _shard_idx * _avg_local_shard_num;
Z
zhaocaibei123 已提交
95

96
  size_t feature_value_size =
97
      _value_accesor->GetAccessorInfo().size / sizeof(float);
98 99 100 101

  int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
102
  for (int i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
103 104 105 106
    FsChannelConfig channel_config;
    channel_config.path = file_list[file_start_idx + i];
    VLOG(1) << "MemorySparseTable::load begin load " << channel_config.path
            << " into local shard " << i;
107
    channel_config.converter = _value_accesor->Converter(load_param).converter;
Z
zhaocaibei123 已提交
108
    channel_config.deconverter =
109
        _value_accesor->Converter(load_param).deconverter;
Z
zhaocaibei123 已提交
110 111 112 113 114 115 116 117 118 119

    bool is_read_failed = false;
    int retry_num = 0;
    int err_no = 0;
    do {
      is_read_failed = false;
      err_no = 0;
      std::string line_data;
      auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
      char* end = NULL;
120
      auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
121 122 123 124
      try {
        while (read_channel->read_line(line_data) == 0 &&
               line_data.size() > 1) {
          uint64_t key = std::strtoul(line_data.data(), &end, 10);
125 126
          auto& value = shard[key];
          value.resize(feature_value_size);
127
          int parse_size = _value_accesor->ParseFromString(++end, value.data());
128
          value.resize(parse_size);
Z
zhaocaibei123 已提交
129 130 131 132

          // for debug
          for (int ii = 0; ii < parse_size; ++ii) {
            VLOG(2) << "MemorySparseTable::load key: " << key << " value " << ii
133
                    << ": " << value.data()[ii] << " local_shard: " << i;
Z
zhaocaibei123 已提交
134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
          }
        }
        read_channel->close();
        if (err_no == -1) {
          ++retry_num;
          is_read_failed = true;
          LOG(ERROR)
              << "MemorySparseTable load failed after read, retry it! path:"
              << channel_config.path << " , retry_num=" << retry_num;
        }
      } catch (...) {
        ++retry_num;
        is_read_failed = true;
        LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
                   << channel_config.path << " , retry_num=" << retry_num;
      }
Z
zhaocaibei123 已提交
150
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
Z
zhaocaibei123 已提交
151 152 153 154 155 156 157
        LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
        exit(-1);
      }
    } while (is_read_failed);
  }
  LOG(INFO) << "MemorySparseTable load success, path from "
            << file_list[file_start_idx] << " to "
158
            << file_list[file_start_idx + _real_local_shard_num - 1];
Z
zhaocaibei123 已提交
159 160 161
  return 0;
}

Z
zhaocaibei123 已提交
162 163 164
int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
                                       const std::string& param) {
  std::string table_path = TableDir(path);
Z
zhaocaibei123 已提交
165
  auto file_list = paddle::framework::localfs_list(table_path);
166
  size_t expect_shard_num = _sparse_table_shard_num;
Z
zhaocaibei123 已提交
167 168 169 170 171 172 173 174 175 176
  if (file_list.size() != expect_shard_num) {
    LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                 << " not equal to expect_shard_num:" << expect_shard_num;
    return -1;
  }
  if (file_list.size() == 0) {
    LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
    return -1;
  }

177
  size_t file_start_idx = _shard_idx * _avg_local_shard_num;
Z
zhaocaibei123 已提交
178

179
  size_t feature_value_size =
180
      _value_accesor->GetAccessorInfo().size / sizeof(float);
Z
zhaocaibei123 已提交
181

182 183 184
  int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
185
  for (int i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
186 187 188 189 190 191 192 193 194
    bool is_read_failed = false;
    int retry_num = 0;
    int err_no = 0;
    do {
      is_read_failed = false;
      err_no = 0;
      std::string line_data;
      std::ifstream file(file_list[file_start_idx + i]);
      char* end = NULL;
195
      auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
196 197 198
      try {
        while (std::getline(file, line_data) && line_data.size() > 1) {
          uint64_t key = std::strtoul(line_data.data(), &end, 10);
199 200
          auto& value = shard[key];
          value.resize(feature_value_size);
201
          int parse_size = _value_accesor->ParseFromString(++end, value.data());
202
          value.resize(parse_size);
Z
zhaocaibei123 已提交
203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
        }
        file.close();
        if (err_no == -1) {
          ++retry_num;
          is_read_failed = true;
          LOG(ERROR)
              << "MemorySparseTable load failed after read, retry it! path:"
              << file_list[file_start_idx + i] << " , retry_num=" << retry_num;
        }
      } catch (...) {
        ++retry_num;
        is_read_failed = true;
        LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
                   << file_list[file_start_idx + i]
                   << " , retry_num=" << retry_num;
      }
Z
zhaocaibei123 已提交
219
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
Z
zhaocaibei123 已提交
220 221 222 223 224 225 226
        LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
        exit(-1);
      }
    } while (is_read_failed);
  }
  LOG(INFO) << "MemorySparseTable load success, path from "
            << file_list[file_start_idx] << " to "
227
            << file_list[file_start_idx + _real_local_shard_num - 1];
Z
zhaocaibei123 已提交
228 229 230
  return 0;
}

Z
zhaocaibei123 已提交
231
int32_t MemorySparseTable::Save(const std::string& dirname,
Z
zhaocaibei123 已提交
232 233 234 235
                                const std::string& param) {
  VLOG(0) << "MemorySparseTable::save dirname: " << dirname;
  int save_param =
      atoi(param.c_str());  // checkpoint:0  xbox delta:1  xbox base:2
Z
zhaocaibei123 已提交
236
  std::string table_path = TableDir(dirname);
Z
zhaocaibei123 已提交
237 238 239 240
  _afs_client.remove(paddle::string::format_string(
      "%s/part-%03d-*", table_path.c_str(), _shard_idx));
  std::atomic<uint32_t> feasign_size_all{0};

241
  size_t file_start_idx = _avg_local_shard_num * _shard_idx;
Z
zhaocaibei123 已提交
242

243 244 245
  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
246
  for (int i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
247 248 249 250 251 252 253 254 255 256
    FsChannelConfig channel_config;
    if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) {
      channel_config.path = paddle::string::format_string(
          "%s/part-%03d-%05d.gz", table_path.c_str(), _shard_idx,
          file_start_idx + i);
    } else {
      channel_config.path =
          paddle::string::format_string("%s/part-%03d-%05d", table_path.c_str(),
                                        _shard_idx, file_start_idx + i);
    }
257
    channel_config.converter = _value_accesor->Converter(save_param).converter;
Z
zhaocaibei123 已提交
258
    channel_config.deconverter =
259
        _value_accesor->Converter(save_param).deconverter;
Z
zhaocaibei123 已提交
260 261 262 263
    bool is_write_failed = false;
    int feasign_size = 0;
    int retry_num = 0;
    int err_no = 0;
264
    auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
265 266 267 268 269 270
    do {
      err_no = 0;
      feasign_size = 0;
      is_write_failed = false;
      auto write_channel =
          _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
271
      for (auto it = shard.begin(); it != shard.end(); ++it) {
272 273
        if (_value_accesor->Save(it.value().data(), save_param)) {
          std::string format_value = _value_accesor->ParseToString(
274
              it.value().data(), it.value().size());
275 276
          if (0 != write_channel->write_line(paddle::string::format_string(
                       "%lu %s", it.key(), format_value.c_str()))) {
277 278 279 280 281 282
            ++retry_num;
            is_write_failed = true;
            LOG(ERROR)
                << "MemorySparseTable save prefix failed, retry it! path:"
                << channel_config.path << " , retry_num=" << retry_num;
            break;
Z
zhaocaibei123 已提交
283
          }
284
          ++feasign_size;
Z
zhaocaibei123 已提交
285 286 287 288 289 290 291 292 293 294 295 296 297
        }
      }
      write_channel->close();
      if (err_no == -1) {
        ++retry_num;
        is_write_failed = true;
        LOG(ERROR)
            << "MemorySparseTable save prefix failed after write, retry it! "
            << "path:" << channel_config.path << " , retry_num=" << retry_num;
      }
      if (is_write_failed) {
        _afs_client.remove(channel_config.path);
      }
Z
zhaocaibei123 已提交
298
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
Z
zhaocaibei123 已提交
299 300 301 302 303
        LOG(ERROR) << "MemorySparseTable save prefix failed reach max limit!";
        exit(-1);
      }
    } while (is_write_failed);
    feasign_size_all += feasign_size;
304
    for (auto it = shard.begin(); it != shard.end(); ++it) {
305
      _value_accesor->UpdateStatAfterSave(it.value().data(), save_param);
Z
zhaocaibei123 已提交
306 307 308 309 310 311 312 313
    }
    LOG(INFO) << "MemorySparseTable save prefix success, path: "
              << channel_config.path;
  }
  // int32 may overflow need to change return value
  return 0;
}

Z
zhaocaibei123 已提交
314 315 316
int32_t MemorySparseTable::SaveLocalFS(const std::string& dirname,
                                       const std::string& param,
                                       const std::string& prefix) {
Z
zhaocaibei123 已提交
317 318
  int save_param =
      atoi(param.c_str());  // checkpoint:0  xbox delta:1  xbox base:2
Z
zhaocaibei123 已提交
319
  std::string table_path = TableDir(dirname);
Z
zhaocaibei123 已提交
320
  int feasign_cnt = 0;
321 322 323 324 325 326 327
  size_t file_start_idx = _avg_local_shard_num * _shard_idx;

  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  std::atomic<uint32_t> feasign_size_all{0};

  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
328
  for (int i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
329
    feasign_cnt = 0;
330
    auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
331 332 333 334 335
    std::string file_name = paddle::string::format_string(
        "%s/part-%s-%03d-%05d", table_path.c_str(), prefix.c_str(), _shard_idx,
        file_start_idx + i);
    std::ofstream os;
    os.open(file_name);
336
    for (auto it = shard.begin(); it != shard.end(); ++it) {
337 338 339
      if (_value_accesor->Save(it.value().data(), save_param)) {
        std::string format_value =
            _value_accesor->ParseToString(it.value().data(), it.value().size());
340 341 342 343 344
        std::string out_line = paddle::string::format_string(
            "%lu %s\n", it.key(), format_value.c_str());
        // VLOG(2) << out_line.c_str();
        os.write(out_line.c_str(), sizeof(char) * out_line.size());
        ++feasign_cnt;
Z
zhaocaibei123 已提交
345 346 347 348 349 350 351 352 353
      }
    }
    os.close();
    LOG(INFO) << "MemorySparseTable save prefix success, path:" << file_name
              << "feasign_cnt: " << feasign_cnt;
  }
  return 0;
}

Z
zhaocaibei123 已提交
354
int64_t MemorySparseTable::LocalSize() {
355
  int64_t local_size = 0;
356
  for (int i = 0; i < _real_local_shard_num; ++i) {
357 358 359 360
    local_size += _local_shards[i].size();
  }
  return local_size;
}
Z
zhaocaibei123 已提交
361

Z
zhaocaibei123 已提交
362
int64_t MemorySparseTable::LocalMFSize() {
363 364 365
  std::vector<int64_t> size_arr(_real_local_shard_num, 0);
  std::vector<std::future<int>> tasks(_real_local_shard_num);
  int64_t ret_size = 0;
366
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
367 368 369 370 371 372
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this, shard_id, &size_arr]() -> int {
              auto& local_shard = _local_shards[shard_id];
              for (auto it = local_shard.begin(); it != local_shard.end();
                   ++it) {
373
                if (_value_accesor->HasMF(it.value().size())) {
374 375 376 377 378 379
                  size_arr[shard_id] += 1;
                }
              }
              return 0;
            });
  }
380
  for (int i = 0; i < _real_local_shard_num; ++i) {
381 382 383 384
    tasks[i].wait();
  }
  for (auto x : size_arr) {
    ret_size += x;
Z
zhaocaibei123 已提交
385
  }
386 387
  return ret_size;
}
Z
zhaocaibei123 已提交
388

Z
zhaocaibei123 已提交
389 390 391
// Returns {LocalSize(), LocalMFSize()}: the local key count and the number of
// local values with an MF part. Nothing is printed here.
std::pair<int64_t, int64_t> MemorySparseTable::PrintTableStat() {
  const int64_t key_count = LocalSize();
  const int64_t mf_key_count = LocalMFSize();
  return std::make_pair(key_count, mf_key_count);
}

// Dispatches a sparse pull. With context.use_ptr it goes to PullSparsePtr(),
// which returns pointers to stored values. Otherwise it goes to PullSparse(),
// which copies the selected values into a float buffer.
int32_t MemorySparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  const auto& pull = context.pull_context;
  if (!context.use_ptr) {
    return PullSparse(pull.values, pull.pull_value);
  }
  return PullSparsePtr(pull.ptr_values, pull.keys, context.num);
}

// Dispatches a sparse push to the PushSparse() overload matching the update
// layout: one pointer per key when context.use_ptr is set, otherwise one
// contiguous array of update values.
int32_t MemorySparseTable::Push(TableContext& context) {
  CHECK(context.value_type == Sparse);
  const auto& push = context.push_context;
  if (context.use_ptr) {
    return PushSparse(push.keys, push.ptr_values, context.num);
  }
  return PushSparse(push.keys, push.values, context.num);
}

int32_t MemorySparseTable::PullSparse(float* pull_values,
                                      const PullSparseValue& pull_value) {
421 422
  CostTimer timer("pserver_sparse_select_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);
Z
zhaocaibei123 已提交
423

424 425 426 427
  const size_t value_size =
      _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_size =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
428
  size_t select_value_size =
429
      _value_accesor->GetAccessorInfo().select_size / sizeof(float);
Z
zhaocaibei123 已提交
430 431 432
  // std::atomic<uint32_t> missed_keys{0};

  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
433
      _real_local_shard_num);
Z
zhaocaibei123 已提交
434 435
  size_t num = pull_value.numel_;
  for (size_t i = 0; i < num; ++i) {
436 437
    int shard_id = (pull_value.feasigns_[i] % _sparse_table_shard_num) %
                   _avg_local_shard_num;
Z
zhaocaibei123 已提交
438 439
    task_keys[shard_id].push_back({pull_value.feasigns_[i], i});
  }
440
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
Z
zhaocaibei123 已提交
441
    tasks[shard_id] =
442
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
Z
zhaocaibei123 已提交
443 444
            [this, shard_id, &task_keys, value_size, pull_values, mf_value_size,
             select_value_size]() -> int {
445
              auto& local_shard = _local_shards[shard_id];
Z
zhaocaibei123 已提交
446 447 448 449 450 451
              float data_buffer[value_size];  // NOLINT
              float* data_buffer_ptr = data_buffer;

              auto& keys = task_keys[shard_id];
              for (size_t i = 0; i < keys.size(); i++) {
                uint64_t key = keys[i].first;
452
                auto itr = local_shard.find(key);
Z
zhaocaibei123 已提交
453
                size_t data_size = value_size - mf_value_size;
454
                if (itr == local_shard.end()) {
Z
zhaocaibei123 已提交
455
                  // ++missed_keys;
456
                  if (FLAGS_pserver_create_value_when_push) {
Z
zhaocaibei123 已提交
457 458
                    memset(data_buffer, 0, sizeof(float) * data_size);
                  } else {
459 460 461
                    auto& feature_value = local_shard[key];
                    feature_value.resize(data_size);
                    float* data_ptr = feature_value.data();
462
                    _value_accesor->Create(&data_buffer_ptr, 1);
Z
zhaocaibei123 已提交
463 464 465 466
                    memcpy(data_ptr, data_buffer_ptr,
                           data_size * sizeof(float));
                  }
                } else {
467 468
                  data_size = itr.value().size();
                  memcpy(data_buffer_ptr, itr.value().data(),
Z
zhaocaibei123 已提交
469 470
                         data_size * sizeof(float));
                }
471
                for (size_t mf_idx = data_size; mf_idx < value_size; ++mf_idx) {
Z
zhaocaibei123 已提交
472 473 474 475
                  data_buffer[mf_idx] = 0.0;
                }
                auto offset = keys[i].second;
                float* select_data = pull_values + select_value_size * offset;
476
                _value_accesor->Select(&select_data,
Z
zhaocaibei123 已提交
477 478 479 480 481 482 483 484 485 486 487 488 489
                                       (const float**)&data_buffer_ptr, 1);
              }

              return 0;
            });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

Z
zhaocaibei123 已提交
490 491
int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
                                         const uint64_t* keys, size_t num) {
492
  CostTimer timer("pscore_sparse_select_all");
493 494 495
  size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_size =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
496 497 498 499 500 501 502 503 504

  std::vector<std::future<int>> tasks(_real_local_shard_num);
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
    task_keys[shard_id].push_back({keys[i], i});
  }
  // std::atomic<uint32_t> missed_keys{0};
505
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
506 507 508 509 510 511 512 513
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this, shard_id, &task_keys, pull_values, value_size,
             mf_value_size]() -> int {
              auto& keys = task_keys[shard_id];
              auto& local_shard = _local_shards[shard_id];
              float data_buffer[value_size];
              float* data_buffer_ptr = data_buffer;
514
              for (size_t i = 0; i < keys.size(); ++i) {
515 516 517 518 519 520 521 522 523
                uint64_t key = keys[i].first;
                auto itr = local_shard.find(key);
                size_t data_size = value_size - mf_value_size;
                FixedFeatureValue* ret = NULL;
                if (itr == local_shard.end()) {
                  // ++missed_keys;
                  auto& feature_value = local_shard[key];
                  feature_value.resize(data_size);
                  float* data_ptr = feature_value.data();
524
                  _value_accesor->Create(&data_buffer_ptr, 1);
525 526 527 528 529 530 531 532 533 534 535 536 537 538
                  memcpy(data_ptr, data_buffer_ptr, data_size * sizeof(float));
                  ret = &feature_value;
                } else {
                  ret = itr.value_ptr();
                }
                int pull_data_idx = keys[i].second;
                pull_values[pull_data_idx] = (char*)ret;
              }
              return 0;
            });
  }
  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
Z
zhaocaibei123 已提交
539 540 541
  return 0;
}

Z
zhaocaibei123 已提交
542 543
int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values,
                                      size_t num) {
544 545
  CostTimer timer("pserver_sparse_update_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);
Z
zhaocaibei123 已提交
546
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
547
      _real_local_shard_num);
Z
zhaocaibei123 已提交
548
  for (size_t i = 0; i < num; ++i) {
549
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
Z
zhaocaibei123 已提交
550 551 552
    task_keys[shard_id].push_back({keys[i], i});
  }

553 554 555 556
  const size_t value_col =
      _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_col =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
557
  size_t update_value_col =
558
      _value_accesor->GetAccessorInfo().update_size / sizeof(float);
Z
zhaocaibei123 已提交
559

560
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
561
    tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
Z
zhaocaibei123 已提交
562 563 564
        [this, shard_id, value_col, mf_value_col, update_value_col, values,
         &task_keys]() -> int {
          auto& keys = task_keys[shard_id];
565
          auto& local_shard = _local_shards[shard_id];
Z
zhaocaibei123 已提交
566 567
          float data_buffer[value_col];  // NOLINT
          float* data_buffer_ptr = data_buffer;
568
          for (size_t i = 0; i < keys.size(); ++i) {
Z
zhaocaibei123 已提交
569 570 571 572
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data =
                values + push_data_idx * update_value_col;
573 574 575
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              if (FLAGS_pserver_enable_create_feasign_randomly &&
576
                  !_value_accesor->CreateValue(1, update_data)) {
Z
zhaocaibei123 已提交
577 578 579
                continue;
              }
              auto value_size = value_col - mf_value_col;
580 581
              auto& feature_value = local_shard[key];
              feature_value.resize(value_size);
582
              _value_accesor->Create(&data_buffer_ptr, 1);
583
              memcpy(feature_value.data(), data_buffer_ptr,
Z
zhaocaibei123 已提交
584
                     value_size * sizeof(float));
585
              itr = local_shard.find(key);
Z
zhaocaibei123 已提交
586 587
            }

588 589 590
            auto& feature_value = itr.value();
            float* value_data = feature_value.data();
            size_t value_size = feature_value.size();
Z
zhaocaibei123 已提交
591 592

            if (value_size == value_col) {  // 已拓展到最大size, 则就地update
593
              _value_accesor->Update(&value_data, &update_data, 1);
Z
zhaocaibei123 已提交
594 595 596
            } else {
              // 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了
              memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
597
              _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
Z
zhaocaibei123 已提交
598

599
              if (_value_accesor->NeedExtendMF(data_buffer)) {
600 601
                feature_value.resize(value_col);
                value_data = feature_value.data();
602
                _value_accesor->Create(&value_data, 1);
Z
zhaocaibei123 已提交
603 604 605 606 607 608 609 610 611 612 613 614 615 616
              }
              memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
            }
          }
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

Z
zhaocaibei123 已提交
617 618
int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
                                      const float** values, size_t num) {
619
  std::vector<std::future<int>> tasks(_real_local_shard_num);
Z
zhaocaibei123 已提交
620
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
621
      _real_local_shard_num);
Z
zhaocaibei123 已提交
622
  for (size_t i = 0; i < num; ++i) {
623
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
Z
zhaocaibei123 已提交
624 625 626
    task_keys[shard_id].push_back({keys[i], i});
  }

627 628 629
  size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_col =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
630
  size_t update_value_col =
631
      _value_accesor->GetAccessorInfo().update_size / sizeof(float);
Z
zhaocaibei123 已提交
632

633 634
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
Z
zhaocaibei123 已提交
635 636 637
        [this, shard_id, value_col, mf_value_col, update_value_col, values,
         &task_keys]() -> int {
          auto& keys = task_keys[shard_id];
638
          auto& local_shard = _local_shards[shard_id];
Z
zhaocaibei123 已提交
639 640
          float data_buffer[value_col];  // NOLINT
          float* data_buffer_ptr = data_buffer;
641
          for (size_t i = 0; i < keys.size(); ++i) {
Z
zhaocaibei123 已提交
642 643 644
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data = values[push_data_idx];
645 646 647
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              if (FLAGS_pserver_enable_create_feasign_randomly &&
648
                  !_value_accesor->CreateValue(1, update_data)) {
Z
zhaocaibei123 已提交
649 650 651
                continue;
              }
              auto value_size = value_col - mf_value_col;
652 653
              auto& feature_value = local_shard[key];
              feature_value.resize(value_size);
654
              _value_accesor->Create(&data_buffer_ptr, 1);
655
              memcpy(feature_value.data(), data_buffer_ptr,
Z
zhaocaibei123 已提交
656
                     value_size * sizeof(float));
657
              itr = local_shard.find(key);
Z
zhaocaibei123 已提交
658
            }
659 660 661
            auto& feature_value = itr.value();
            float* value_data = feature_value.data();
            size_t value_size = feature_value.size();
Z
zhaocaibei123 已提交
662
            if (value_size == value_col) {  // 已拓展到最大size, 则就地update
663
              _value_accesor->Update(&value_data, &update_data, 1);
Z
zhaocaibei123 已提交
664 665 666
            } else {
              // 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了
              memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
667 668
              _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
              if (_value_accesor->NeedExtendMF(data_buffer)) {
669 670
                feature_value.resize(value_col);
                value_data = feature_value.data();
671
                _value_accesor->Create(&value_data, 1);
Z
zhaocaibei123 已提交
672 673 674 675 676 677 678 679 680 681 682 683 684 685
              }
              memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
            }
          }
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

Z
zhaocaibei123 已提交
686
// No-op in this table: always reports success.
int32_t MemorySparseTable::Flush() {
  return 0;
}

int32_t MemorySparseTable::Shrink(const std::string& param) {
  VLOG(0) << "MemorySparseTable::Shrink";
Z
zhaocaibei123 已提交
690
  // TODO(zhaocaibei123): implement with multi-thread
691
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
Z
zhaocaibei123 已提交
692
    // Shrink
693 694
    auto& shard = _local_shards[shard_id];
    for (auto it = shard.begin(); it != shard.end();) {
695
      if (_value_accesor->Shrink(it.value().data())) {
696 697 698
        it = shard.erase(it);
      } else {
        ++it;
Z
zhaocaibei123 已提交
699 700 701 702 703 704
      }
    }
  }
  return 0;
}

Z
zhaocaibei123 已提交
705
// Placeholder: only logs a message; the local shards are left untouched.
void MemorySparseTable::Clear() {
  VLOG(0) << "clear coming soon";
}

}  // namespace distributed
}  // namespace paddle