// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <omp.h>

#include <algorithm>
#include <fstream>
#include <sstream>

#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
#include "paddle/fluid/framework/io/fs.h"

#include "boost/lexical_cast.hpp"
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace distributed {

// TODO(zhaocaibei123): configure
// Table-wide switches, defined here as plain globals until they are made
// configurable.

// true: pull_sparse answers unseen keys with zeros and leaves creation to
// push_sparse; false: pull_sparse creates and stores the missing value.
bool FLAGS_pserver_create_value_when_push{true};

// Retry budget for load/save of a single shard file; once exceeded the
// process exits.
int FLAGS_pserver_table_save_max_retry{3};

// true: push_sparse consults the accessor's CreateValue() before creating an
// unseen key, and skips the key when it declines.
bool FLAGS_pserver_enable_create_feasign_randomly{false};

int32_t MemorySparseTable::initialize() {
35 36 37
  _shards_task_pool.resize(_task_pool_size);
  for (int i = 0; i < _shards_task_pool.size(); ++i) {
    _shards_task_pool[i].reset(new ::ThreadPool(1));
Z
zhaocaibei123 已提交
38
  }
39 40 41
  auto& profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_sparse_update_all");
  profiler.register_profiler("pserver_sparse_select_all");
Z
zhaocaibei123 已提交
42 43 44 45 46 47
  initialize_value();
  VLOG(0) << "initalize MemorySparseTable succ";
  return 0;
}

int32_t MemorySparseTable::initialize_value() {
48 49 50 51 52 53 54 55 56
  _sparse_table_shard_num = static_cast<int>(_config.shard_num());
  _avg_local_shard_num =
      SparseTable::sparse_local_shard_num(_sparse_table_shard_num, _shard_num);
  _real_local_shard_num = _avg_local_shard_num;
  if (_real_local_shard_num * (_shard_idx + 1) > _sparse_table_shard_num) {
    _real_local_shard_num =
        _sparse_table_shard_num - _real_local_shard_num * _shard_idx;
    _real_local_shard_num =
        _real_local_shard_num < 0 ? 0 : _real_local_shard_num;
Z
zhaocaibei123 已提交
57
  }
58 59 60
  VLOG(1) << "memory sparse table _avg_local_shard_num: "
          << _avg_local_shard_num
          << " _real_local_shard_num: " << _real_local_shard_num;
Z
zhaocaibei123 已提交
61

62
  _local_shards.reset(new shard_type[_real_local_shard_num]);
Z
zhaocaibei123 已提交
63 64 65 66 67 68 69 70 71 72 73 74 75 76 77

  return 0;
}

int32_t MemorySparseTable::load(const std::string& path,
                                const std::string& param) {
  std::string table_path = table_dir(path);
  auto file_list = _afs_client.list(table_path);

  std::sort(file_list.begin(), file_list.end());
  for (auto file : file_list) {
    VLOG(1) << "MemorySparseTable::load() file list: " << file;
  }

  int load_param = atoi(param.c_str());
78
  auto expect_shard_num = _sparse_table_shard_num;
Z
zhaocaibei123 已提交
79 80 81 82 83 84 85 86 87 88
  if (file_list.size() != expect_shard_num) {
    LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                 << " not equal to expect_shard_num:" << expect_shard_num;
    return -1;
  }
  if (file_list.size() == 0) {
    LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
    return -1;
  }

89
  size_t file_start_idx = _shard_idx * _avg_local_shard_num;
Z
zhaocaibei123 已提交
90

91 92
  size_t feature_value_size =
      _value_accesor->GetTableInfo(SIZE) / sizeof(float);
93 94 95 96 97

  int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (size_t i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
98 99 100 101
    FsChannelConfig channel_config;
    channel_config.path = file_list[file_start_idx + i];
    VLOG(1) << "MemorySparseTable::load begin load " << channel_config.path
            << " into local shard " << i;
102
    channel_config.converter = _value_accesor->Converter(load_param).converter;
Z
zhaocaibei123 已提交
103
    channel_config.deconverter =
104
        _value_accesor->Converter(load_param).deconverter;
Z
zhaocaibei123 已提交
105 106 107 108 109 110 111 112 113 114

    bool is_read_failed = false;
    int retry_num = 0;
    int err_no = 0;
    do {
      is_read_failed = false;
      err_no = 0;
      std::string line_data;
      auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
      char* end = NULL;
115
      auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
116 117 118 119
      try {
        while (read_channel->read_line(line_data) == 0 &&
               line_data.size() > 1) {
          uint64_t key = std::strtoul(line_data.data(), &end, 10);
120 121
          auto& value = shard[key];
          value.resize(feature_value_size);
122
          int parse_size = _value_accesor->ParseFromString(++end, value.data());
123
          value.resize(parse_size);
Z
zhaocaibei123 已提交
124 125 126 127

          // for debug
          for (int ii = 0; ii < parse_size; ++ii) {
            VLOG(2) << "MemorySparseTable::load key: " << key << " value " << ii
128
                    << ": " << value.data()[ii] << " local_shard: " << i;
Z
zhaocaibei123 已提交
129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
          }
        }
        read_channel->close();
        if (err_no == -1) {
          ++retry_num;
          is_read_failed = true;
          LOG(ERROR)
              << "MemorySparseTable load failed after read, retry it! path:"
              << channel_config.path << " , retry_num=" << retry_num;
        }
      } catch (...) {
        ++retry_num;
        is_read_failed = true;
        LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
                   << channel_config.path << " , retry_num=" << retry_num;
      }
145
      if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) {
Z
zhaocaibei123 已提交
146 147 148 149 150 151 152
        LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
        exit(-1);
      }
    } while (is_read_failed);
  }
  LOG(INFO) << "MemorySparseTable load success, path from "
            << file_list[file_start_idx] << " to "
153
            << file_list[file_start_idx + _real_local_shard_num - 1];
Z
zhaocaibei123 已提交
154 155 156 157 158 159 160 161 162
  return 0;
}

int32_t MemorySparseTable::load_local_fs(const std::string& path,
                                         const std::string& param) {
  std::string table_path = table_dir(path);
  auto file_list = paddle::framework::localfs_list(table_path);

  int load_param = atoi(param.c_str());
163
  auto expect_shard_num = _sparse_table_shard_num;
Z
zhaocaibei123 已提交
164 165 166 167 168 169 170 171 172 173
  if (file_list.size() != expect_shard_num) {
    LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                 << " not equal to expect_shard_num:" << expect_shard_num;
    return -1;
  }
  if (file_list.size() == 0) {
    LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
    return -1;
  }

174
  size_t file_start_idx = _shard_idx * _avg_local_shard_num;
Z
zhaocaibei123 已提交
175

176 177
  size_t feature_value_size =
      _value_accesor->GetTableInfo(SIZE) / sizeof(float);
Z
zhaocaibei123 已提交
178

179 180 181 182
  int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (size_t i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
183 184 185 186 187 188 189 190 191
    bool is_read_failed = false;
    int retry_num = 0;
    int err_no = 0;
    do {
      is_read_failed = false;
      err_no = 0;
      std::string line_data;
      std::ifstream file(file_list[file_start_idx + i]);
      char* end = NULL;
192
      auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
193 194 195
      try {
        while (std::getline(file, line_data) && line_data.size() > 1) {
          uint64_t key = std::strtoul(line_data.data(), &end, 10);
196 197
          auto& value = shard[key];
          value.resize(feature_value_size);
198
          int parse_size = _value_accesor->ParseFromString(++end, value.data());
199
          value.resize(parse_size);
Z
zhaocaibei123 已提交
200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
        }
        file.close();
        if (err_no == -1) {
          ++retry_num;
          is_read_failed = true;
          LOG(ERROR)
              << "MemorySparseTable load failed after read, retry it! path:"
              << file_list[file_start_idx + i] << " , retry_num=" << retry_num;
        }
      } catch (...) {
        ++retry_num;
        is_read_failed = true;
        LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
                   << file_list[file_start_idx + i]
                   << " , retry_num=" << retry_num;
      }
216
      if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) {
Z
zhaocaibei123 已提交
217 218 219 220 221 222 223
        LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
        exit(-1);
      }
    } while (is_read_failed);
  }
  LOG(INFO) << "MemorySparseTable load success, path from "
            << file_list[file_start_idx] << " to "
224
            << file_list[file_start_idx + _real_local_shard_num - 1];
Z
zhaocaibei123 已提交
225 226 227 228 229 230 231 232 233 234 235 236 237
  return 0;
}

int32_t MemorySparseTable::save(const std::string& dirname,
                                const std::string& param) {
  VLOG(0) << "MemorySparseTable::save dirname: " << dirname;
  int save_param =
      atoi(param.c_str());  // checkpoint:0  xbox delta:1  xbox base:2
  std::string table_path = table_dir(dirname);
  _afs_client.remove(paddle::string::format_string(
      "%s/part-%03d-*", table_path.c_str(), _shard_idx));
  std::atomic<uint32_t> feasign_size_all{0};

238
  size_t file_start_idx = _avg_local_shard_num * _shard_idx;
Z
zhaocaibei123 已提交
239

240 241 242 243
  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (size_t i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
244 245 246 247 248 249 250 251 252 253
    FsChannelConfig channel_config;
    if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) {
      channel_config.path = paddle::string::format_string(
          "%s/part-%03d-%05d.gz", table_path.c_str(), _shard_idx,
          file_start_idx + i);
    } else {
      channel_config.path =
          paddle::string::format_string("%s/part-%03d-%05d", table_path.c_str(),
                                        _shard_idx, file_start_idx + i);
    }
254
    channel_config.converter = _value_accesor->Converter(save_param).converter;
Z
zhaocaibei123 已提交
255
    channel_config.deconverter =
256
        _value_accesor->Converter(save_param).deconverter;
Z
zhaocaibei123 已提交
257 258 259 260
    bool is_write_failed = false;
    int feasign_size = 0;
    int retry_num = 0;
    int err_no = 0;
261
    auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
262 263 264 265 266 267
    do {
      err_no = 0;
      feasign_size = 0;
      is_write_failed = false;
      auto write_channel =
          _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
268
      for (auto it = shard.begin(); it != shard.end(); ++it) {
269 270
        if (_value_accesor->Save(it.value().data(), save_param)) {
          std::string format_value = _value_accesor->ParseToString(
271 272 273 274 275 276 277 278 279 280
              it.value().data(), it.value().size());
          if (0 !=
              write_channel->write_line(paddle::string::format_string(
                  "%lu %s", it.key(), format_value.c_str()))) {
            ++retry_num;
            is_write_failed = true;
            LOG(ERROR)
                << "MemorySparseTable save prefix failed, retry it! path:"
                << channel_config.path << " , retry_num=" << retry_num;
            break;
Z
zhaocaibei123 已提交
281
          }
282
          ++feasign_size;
Z
zhaocaibei123 已提交
283 284 285 286 287 288 289 290 291 292 293 294 295
        }
      }
      write_channel->close();
      if (err_no == -1) {
        ++retry_num;
        is_write_failed = true;
        LOG(ERROR)
            << "MemorySparseTable save prefix failed after write, retry it! "
            << "path:" << channel_config.path << " , retry_num=" << retry_num;
      }
      if (is_write_failed) {
        _afs_client.remove(channel_config.path);
      }
296
      if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) {
Z
zhaocaibei123 已提交
297 298 299 300 301
        LOG(ERROR) << "MemorySparseTable save prefix failed reach max limit!";
        exit(-1);
      }
    } while (is_write_failed);
    feasign_size_all += feasign_size;
302
    for (auto it = shard.begin(); it != shard.end(); ++it) {
303
      _value_accesor->UpdateStatAfterSave(it.value().data(), save_param);
Z
zhaocaibei123 已提交
304 305 306 307 308 309 310 311 312 313 314 315 316 317 318
    }
    LOG(INFO) << "MemorySparseTable save prefix success, path: "
              << channel_config.path;
  }
  // int32 may overflow need to change return value
  return 0;
}

int32_t MemorySparseTable::save_local_fs(const std::string& dirname,
                                         const std::string& param,
                                         const std::string& prefix) {
  int save_param =
      atoi(param.c_str());  // checkpoint:0  xbox delta:1  xbox base:2
  std::string table_path = table_dir(dirname);
  int feasign_cnt = 0;
319 320 321 322 323 324 325 326
  size_t file_start_idx = _avg_local_shard_num * _shard_idx;

  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
  std::atomic<uint32_t> feasign_size_all{0};

  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (size_t i = 0; i < _real_local_shard_num; ++i) {
Z
zhaocaibei123 已提交
327
    feasign_cnt = 0;
328
    auto& shard = _local_shards[i];
Z
zhaocaibei123 已提交
329 330 331 332 333
    std::string file_name = paddle::string::format_string(
        "%s/part-%s-%03d-%05d", table_path.c_str(), prefix.c_str(), _shard_idx,
        file_start_idx + i);
    std::ofstream os;
    os.open(file_name);
334
    for (auto it = shard.begin(); it != shard.end(); ++it) {
335 336 337
      if (_value_accesor->Save(it.value().data(), save_param)) {
        std::string format_value =
            _value_accesor->ParseToString(it.value().data(), it.value().size());
338 339 340 341 342
        std::string out_line = paddle::string::format_string(
            "%lu %s\n", it.key(), format_value.c_str());
        // VLOG(2) << out_line.c_str();
        os.write(out_line.c_str(), sizeof(char) * out_line.size());
        ++feasign_cnt;
Z
zhaocaibei123 已提交
343 344 345 346 347 348 349 350 351
      }
    }
    os.close();
    LOG(INFO) << "MemorySparseTable save prefix success, path:" << file_name
              << "feasign_cnt: " << feasign_cnt;
  }
  return 0;
}

352 353 354 355 356 357 358
// Returns the total number of feasigns held across this server's local shards.
int64_t MemorySparseTable::local_size() {
  int64_t total = 0;
  size_t remaining = _real_local_shard_num;
  while (remaining > 0) {
    --remaining;
    total += _local_shards[remaining].size();
  }
  return total;
}

int64_t MemorySparseTable::local_mf_size() {
  std::vector<int64_t> size_arr(_real_local_shard_num, 0);
  std::vector<std::future<int>> tasks(_real_local_shard_num);
  int64_t ret_size = 0;
  for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this, shard_id, &size_arr]() -> int {
              auto& local_shard = _local_shards[shard_id];
              for (auto it = local_shard.begin(); it != local_shard.end();
                   ++it) {
371
                if (_value_accesor->HasMF(it.value().size())) {
372 373 374 375 376 377 378 379 380 381 382
                  size_arr[shard_id] += 1;
                }
              }
              return 0;
            });
  }
  for (size_t i = 0; i < _real_local_shard_num; ++i) {
    tasks[i].wait();
  }
  for (auto x : size_arr) {
    ret_size += x;
Z
zhaocaibei123 已提交
383
  }
384 385
  return ret_size;
}
Z
zhaocaibei123 已提交
386

387 388 389
// Returns {local feasign count, local feasign count with MF} for this server.
std::pair<int64_t, int64_t> MemorySparseTable::print_table_stat() {
  const int64_t feasign_count = local_size();
  const int64_t mf_count = local_mf_size();
  return std::pair<int64_t, int64_t>(feasign_count, mf_count);
}

// TableContext entry point for sparse lookups (sparse contexts only).
// With use_ptr set, pull_sparse_ptr hands back pointers to the stored values;
// otherwise pull_sparse copies the selected values into a float buffer.
int32_t MemorySparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  auto& pull_ctx = context.pull_context;
  if (!context.use_ptr) {
    return pull_sparse(pull_ctx.values, pull_ctx.pull_value);
  }
  return pull_sparse_ptr(pull_ctx.ptr_values, pull_ctx.keys, context.num);
}

// TableContext entry point for sparse updates (sparse contexts only):
// forwards the push keys and update values to push_sparse.
int32_t MemorySparseTable::Push(TableContext& context) {
  CHECK(context.value_type == Sparse);
  auto& push_ctx = context.push_context;
  const uint64_t* const push_keys = push_ctx.keys;
  return push_sparse(push_keys, push_ctx.values, context.num);
}

int32_t MemorySparseTable::pull_sparse(float* pull_values,
                                       const PullSparseValue& pull_value) {
415 416
  CostTimer timer("pserver_sparse_select_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);
Z
zhaocaibei123 已提交
417

418 419 420 421
  const size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
  size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
  size_t select_value_size =
      _value_accesor->GetTableInfo(SELECT_SIZE) / sizeof(float);
Z
zhaocaibei123 已提交
422 423 424
  // std::atomic<uint32_t> missed_keys{0};

  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
425
      _real_local_shard_num);
Z
zhaocaibei123 已提交
426 427
  size_t num = pull_value.numel_;
  for (size_t i = 0; i < num; ++i) {
428 429
    int shard_id = (pull_value.feasigns_[i] % _sparse_table_shard_num) %
                   _avg_local_shard_num;
Z
zhaocaibei123 已提交
430 431
    task_keys[shard_id].push_back({pull_value.feasigns_[i], i});
  }
432
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
Z
zhaocaibei123 已提交
433
    tasks[shard_id] =
434
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
Z
zhaocaibei123 已提交
435 436
            [this, shard_id, &task_keys, value_size, pull_values, mf_value_size,
             select_value_size]() -> int {
437
              auto& local_shard = _local_shards[shard_id];
Z
zhaocaibei123 已提交
438 439 440 441 442 443
              float data_buffer[value_size];  // NOLINT
              float* data_buffer_ptr = data_buffer;

              auto& keys = task_keys[shard_id];
              for (size_t i = 0; i < keys.size(); i++) {
                uint64_t key = keys[i].first;
444
                auto itr = local_shard.find(key);
Z
zhaocaibei123 已提交
445
                size_t data_size = value_size - mf_value_size;
446
                if (itr == local_shard.end()) {
Z
zhaocaibei123 已提交
447
                  // ++missed_keys;
448
                  if (FLAGS_pserver_create_value_when_push) {
Z
zhaocaibei123 已提交
449 450
                    memset(data_buffer, 0, sizeof(float) * data_size);
                  } else {
451 452 453
                    auto& feature_value = local_shard[key];
                    feature_value.resize(data_size);
                    float* data_ptr = feature_value.data();
454
                    _value_accesor->Create(&data_buffer_ptr, 1);
Z
zhaocaibei123 已提交
455 456 457 458
                    memcpy(data_ptr, data_buffer_ptr,
                           data_size * sizeof(float));
                  }
                } else {
459 460
                  data_size = itr.value().size();
                  memcpy(data_buffer_ptr, itr.value().data(),
Z
zhaocaibei123 已提交
461 462 463 464 465 466 467
                         data_size * sizeof(float));
                }
                for (int mf_idx = data_size; mf_idx < value_size; ++mf_idx) {
                  data_buffer[mf_idx] = 0.0;
                }
                auto offset = keys[i].second;
                float* select_data = pull_values + select_value_size * offset;
468
                _value_accesor->Select(&select_data,
Z
zhaocaibei123 已提交
469 470 471 472 473 474 475 476 477 478 479 480 481 482 483
                                       (const float**)&data_buffer_ptr, 1);
              }

              return 0;
            });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values,
                                           const uint64_t* keys, size_t num) {
484
  CostTimer timer("pscore_sparse_select_all");
485 486
  size_t value_size = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
  size_t mf_value_size = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514

  std::vector<std::future<int>> tasks(_real_local_shard_num);
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
    task_keys[shard_id].push_back({keys[i], i});
  }
  // std::atomic<uint32_t> missed_keys{0};
  for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this, shard_id, &task_keys, pull_values, value_size,
             mf_value_size]() -> int {
              auto& keys = task_keys[shard_id];
              auto& local_shard = _local_shards[shard_id];
              float data_buffer[value_size];
              float* data_buffer_ptr = data_buffer;
              for (int i = 0; i < keys.size(); ++i) {
                uint64_t key = keys[i].first;
                auto itr = local_shard.find(key);
                size_t data_size = value_size - mf_value_size;
                FixedFeatureValue* ret = NULL;
                if (itr == local_shard.end()) {
                  // ++missed_keys;
                  auto& feature_value = local_shard[key];
                  feature_value.resize(data_size);
                  float* data_ptr = feature_value.data();
515
                  _value_accesor->Create(&data_buffer_ptr, 1);
516 517 518 519 520 521 522 523 524 525 526 527 528 529
                  memcpy(data_ptr, data_buffer_ptr, data_size * sizeof(float));
                  ret = &feature_value;
                } else {
                  ret = itr.value_ptr();
                }
                int pull_data_idx = keys[i].second;
                pull_values[pull_data_idx] = (char*)ret;
              }
              return 0;
            });
  }
  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
Z
zhaocaibei123 已提交
530 531 532 533 534
  return 0;
}

int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
                                       const float* values, size_t num) {
535 536
  CostTimer timer("pserver_sparse_update_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);
Z
zhaocaibei123 已提交
537
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
538
      _real_local_shard_num);
Z
zhaocaibei123 已提交
539
  for (size_t i = 0; i < num; ++i) {
540
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
Z
zhaocaibei123 已提交
541 542 543
    task_keys[shard_id].push_back({keys[i], i});
  }

544 545 546 547
  const size_t value_col = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
  size_t mf_value_col = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
  size_t update_value_col =
      _value_accesor->GetTableInfo(UPDATE_SIZE) / sizeof(float);
Z
zhaocaibei123 已提交
548

549 550
  for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
Z
zhaocaibei123 已提交
551 552 553
        [this, shard_id, value_col, mf_value_col, update_value_col, values,
         &task_keys]() -> int {
          auto& keys = task_keys[shard_id];
554
          auto& local_shard = _local_shards[shard_id];
Z
zhaocaibei123 已提交
555 556 557 558 559 560 561
          float data_buffer[value_col];  // NOLINT
          float* data_buffer_ptr = data_buffer;
          for (int i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data =
                values + push_data_idx * update_value_col;
562 563 564
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              if (FLAGS_pserver_enable_create_feasign_randomly &&
565
                  !_value_accesor->CreateValue(1, update_data)) {
Z
zhaocaibei123 已提交
566 567 568
                continue;
              }
              auto value_size = value_col - mf_value_col;
569 570
              auto& feature_value = local_shard[key];
              feature_value.resize(value_size);
571
              _value_accesor->Create(&data_buffer_ptr, 1);
572
              memcpy(feature_value.data(), data_buffer_ptr,
Z
zhaocaibei123 已提交
573
                     value_size * sizeof(float));
574
              itr = local_shard.find(key);
Z
zhaocaibei123 已提交
575 576
            }

577 578 579
            auto& feature_value = itr.value();
            float* value_data = feature_value.data();
            size_t value_size = feature_value.size();
Z
zhaocaibei123 已提交
580 581

            if (value_size == value_col) {  // 已拓展到最大size, 则就地update
582
              _value_accesor->Update(&value_data, &update_data, 1);
Z
zhaocaibei123 已提交
583 584 585
            } else {
              // 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了
              memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
586
              _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
Z
zhaocaibei123 已提交
587

588
              if (_value_accesor->NeedExtendMF(data_buffer)) {
589 590
                feature_value.resize(value_col);
                value_data = feature_value.data();
591
                _value_accesor->Create(&value_data, 1);
Z
zhaocaibei123 已提交
592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613
              }
              memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
            }
          }
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

// Pointer-per-key variant of push_sparse: values[i] points to the update row
// for keys[i]. Returns the status reported by _push_sparse instead of
// discarding it.
int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
                                       const float** values, size_t num) {
  return _push_sparse(keys, values, num);
}

int32_t MemorySparseTable::_push_sparse(const uint64_t* keys,
                                        const float** values, size_t num) {
614
  std::vector<std::future<int>> tasks(_real_local_shard_num);
Z
zhaocaibei123 已提交
615
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
616
      _real_local_shard_num);
Z
zhaocaibei123 已提交
617
  for (size_t i = 0; i < num; ++i) {
618
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
Z
zhaocaibei123 已提交
619 620 621
    task_keys[shard_id].push_back({keys[i], i});
  }

622 623 624 625
  size_t value_col = _value_accesor->GetTableInfo(SIZE) / sizeof(float);
  size_t mf_value_col = _value_accesor->GetTableInfo(MF_SIZE) / sizeof(float);
  size_t update_value_col =
      _value_accesor->GetTableInfo(UPDATE_SIZE) / sizeof(float);
Z
zhaocaibei123 已提交
626

627 628
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
Z
zhaocaibei123 已提交
629 630 631
        [this, shard_id, value_col, mf_value_col, update_value_col, values,
         &task_keys]() -> int {
          auto& keys = task_keys[shard_id];
632
          auto& local_shard = _local_shards[shard_id];
Z
zhaocaibei123 已提交
633 634 635 636 637 638
          float data_buffer[value_col];  // NOLINT
          float* data_buffer_ptr = data_buffer;
          for (int i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data = values[push_data_idx];
639 640 641
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              if (FLAGS_pserver_enable_create_feasign_randomly &&
642
                  !_value_accesor->CreateValue(1, update_data)) {
Z
zhaocaibei123 已提交
643 644 645
                continue;
              }
              auto value_size = value_col - mf_value_col;
646 647
              auto& feature_value = local_shard[key];
              feature_value.resize(value_size);
648
              _value_accesor->Create(&data_buffer_ptr, 1);
649
              memcpy(feature_value.data(), data_buffer_ptr,
Z
zhaocaibei123 已提交
650
                     value_size * sizeof(float));
651
              itr = local_shard.find(key);
Z
zhaocaibei123 已提交
652
            }
653 654 655
            auto& feature_value = itr.value();
            float* value_data = feature_value.data();
            size_t value_size = feature_value.size();
Z
zhaocaibei123 已提交
656
            if (value_size == value_col) {  // 已拓展到最大size, 则就地update
657
              _value_accesor->Update(&value_data, &update_data, 1);
Z
zhaocaibei123 已提交
658 659 660
            } else {
              // 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了
              memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
661 662
              _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
              if (_value_accesor->NeedExtendMF(data_buffer)) {
663 664
                feature_value.resize(value_col);
                value_data = feature_value.data();
665
                _value_accesor->Create(&value_data, 1);
Z
zhaocaibei123 已提交
666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684
              }
              memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
            }
          }
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

// No-op in this implementation; always reports success.
int32_t MemorySparseTable::flush() {
  const int32_t status = 0;
  return status;
}

int32_t MemorySparseTable::shrink(const std::string& param) {
  VLOG(0) << "MemorySparseTable::shrink";
  // TODO(zhaocaibei123): implement with multi-thread
685
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
Z
zhaocaibei123 已提交
686
    // shrink
687 688
    auto& shard = _local_shards[shard_id];
    for (auto it = shard.begin(); it != shard.end();) {
689
      if (_value_accesor->Shrink(it.value().data())) {
690 691 692
        it = shard.erase(it);
      } else {
        ++it;
Z
zhaocaibei123 已提交
693 694 695 696 697 698 699 700 701 702
      }
    }
  }
  return 0;
}

// Not implemented yet: only logs a notice and leaves the local shards as-is.
void MemorySparseTable::clear() {
  VLOG(0) << "clear coming soon";
}

}  // namespace distributed
}  // namespace paddle