// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <omp.h>
#include <sstream>

#include "glog/logging.h"
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/common/local_random.h"
#include "paddle/fluid/distributed/common/topk_calculator.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"

// #include "boost/lexical_cast.hpp"
#include "paddle/fluid/platform/enforce.h"

DEFINE_bool(pserver_print_missed_key_num_every_push,
            false,
            "pserver_print_missed_key_num_every_push");
DEFINE_bool(pserver_create_value_when_push,
            true,
            "pserver create value when push");
DEFINE_bool(pserver_enable_create_feasign_randomly,
            false,
            "pserver_enable_create_feasign_randomly");
DEFINE_int32(pserver_table_save_max_retry, 3, "pserver_table_save_max_retry");

namespace paddle {
namespace distributed {

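// Registers the select/update cost profilers, initializes the shard values and
// creates one single-threaded ThreadPool per task-pool slot, so work for a
// given shard is always queued to the same thread.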
int32_t MemorySparseTable::Initialize() {
  auto &profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_sparse_update_all");
  profiler.register_profiler("pserver_sparse_select_all");
  InitializeValue();
  _shards_task_pool.resize(_task_pool_size);
  for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
    _shards_task_pool[i].reset(new ::ThreadPool(1));
  }
  VLOG(0) << "initialize MemorySparseTable succ";
  return 0;
}

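// Computes how many of the global shards live on this server:
// _avg_local_shard_num is the per-server quota and _real_local_shard_num is
// what this _shard_idx actually owns (the last server may own fewer). Also
// allocates the merged ("revert") shards when enable_revert is set.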
int32_t MemorySparseTable::InitializeValue() {
  _sparse_table_shard_num = static_cast<int>(_config.shard_num());
  _avg_local_shard_num =
      sparse_local_shard_num(_sparse_table_shard_num, _shard_num);
  _real_local_shard_num = _avg_local_shard_num;
  if (static_cast<int>(_real_local_shard_num * (_shard_idx + 1)) >
      _sparse_table_shard_num) {
    _real_local_shard_num =
        _sparse_table_shard_num - _real_local_shard_num * _shard_idx;
    _real_local_shard_num =
        _real_local_shard_num < 0 ? 0 : _real_local_shard_num;
  }
#ifdef PADDLE_WITH_HETERPS
  _task_pool_size = _sparse_table_shard_num;
#endif
  VLOG(1) << "memory sparse table _avg_local_shard_num: "
          << _avg_local_shard_num
          << " _real_local_shard_num: " << _real_local_shard_num
          << " _task_pool_size:" << _task_pool_size;

  _local_shards.reset(new shard_type[_real_local_shard_num]);

  if (_config.enable_revert()) {
    // calculate merged shard number based on config param;
    _shard_merge_rate = _config.has_shard_merge_rate()
                            ? _config.shard_merge_rate()
                            : _shard_merge_rate;
    CHECK((_m_avg_local_shard_num = static_cast<int>(
               std::ceil(_avg_local_shard_num * _shard_merge_rate)),
           _m_avg_local_shard_num <= _avg_local_shard_num));
    CHECK((_m_real_local_shard_num = static_cast<int>(
               std::ceil(_real_local_shard_num * _shard_merge_rate)),
           _m_real_local_shard_num <= _real_local_shard_num));

    uint32_t avg_shard_server_num =
        _sparse_table_shard_num / _avg_local_shard_num;
    uint32_t last_server_shard_num =
        _sparse_table_shard_num - avg_shard_server_num * _avg_local_shard_num;
    _m_sparse_table_shard_num =
        avg_shard_server_num * _m_avg_local_shard_num +
        std::ceil(last_server_shard_num * _shard_merge_rate);
    LOG(INFO) << "merged shard info: [" << _m_sparse_table_shard_num << "|"
              << _m_avg_local_shard_num << "|" << _m_real_local_shard_num
              << "]";
    _local_shards_new.reset(new shard_type[_real_local_shard_num]);
  }
  return 0;
}

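// Loads one part file per global shard assigned to this server. Each line is
// "key <serialized value>"; parsing is delegated to the value accessor, and a
// failed read is retried up to FLAGS_pserver_table_save_max_retry times.
// load_param == 5 is routed to LoadPatch.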
int32_t MemorySparseTable::Load(const std::string &path,
                                const std::string &param) {
  std::string table_path = TableDir(path);
  auto file_list = _afs_client.list(table_path);

  std::sort(file_list.begin(), file_list.end());
  for (auto file : file_list) {
    VLOG(1) << "MemorySparseTable::Load() file list: " << file;
  }

  int load_param = atoi(param.c_str());
  size_t expect_shard_num = _sparse_table_shard_num;
  if (file_list.size() != expect_shard_num) {
    LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                 << " not equal to expect_shard_num:" << expect_shard_num;
    return -1;
  }
  if (file_list.size() == 0) {
    LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
    return -1;
  }

  if (load_param == 5) {
    return LoadPatch(file_list, load_param);
  }

  size_t file_start_idx = _shard_idx * _avg_local_shard_num;

  if (file_start_idx >= file_list.size()) {
    return 0;
  }

  size_t feature_value_size =
      _value_accesor->GetAccessorInfo().size / sizeof(float);

  int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _real_local_shard_num; ++i) {
    FsChannelConfig channel_config;
    channel_config.path = file_list[file_start_idx + i];
    VLOG(1) << "MemorySparseTable::load begin load " << channel_config.path
            << " into local shard " << i;
    channel_config.converter = _value_accesor->Converter(load_param).converter;
    channel_config.deconverter =
        _value_accesor->Converter(load_param).deconverter;

    bool is_read_failed = false;
    int retry_num = 0;
    int err_no = 0;
    do {
      is_read_failed = false;
      err_no = 0;
      std::string line_data;
      auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
      char *end = NULL;
      auto &shard = _local_shards[i];
      try {
        while (read_channel->read_line(line_data) == 0 &&
               line_data.size() > 1) {
          uint64_t key = std::strtoul(line_data.data(), &end, 10);
          auto &value = shard[key];
          value.resize(feature_value_size);
          int parse_size = _value_accesor->ParseFromString(++end, value.data());
          value.resize(parse_size);

          // for debug
          for (int ii = 0; ii < parse_size; ++ii) {
            VLOG(2) << "MemorySparseTable::load key: " << key << " value " << ii
                    << ": " << value.data()[ii] << " local_shard: " << i;
          }
        }
        read_channel->close();
        if (err_no == -1) {
          ++retry_num;
          is_read_failed = true;
          LOG(ERROR)
              << "MemorySparseTable load failed after read, retry it! path:"
              << channel_config.path << " , retry_num=" << retry_num;
        }
      } catch (...) {
        ++retry_num;
        is_read_failed = true;
        LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
                   << channel_config.path << " , retry_num=" << retry_num;
      }
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
        LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
        exit(-1);
      }
    } while (is_read_failed);
  }
  LOG(INFO) << "MemorySparseTable load success, path from "
            << file_list[file_start_idx] << " to "
            << file_list[file_start_idx + _real_local_shard_num - 1];
  return 0;
}

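// Loads patch-model files written by SavePatch. Patch files follow the merged
// shard layout, so each key is mapped back to its original local shard before
// being inserted.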
int32_t MemorySparseTable::LoadPatch(const std::vector<std::string> &file_list,
                                     int load_param) {
  if (!_config.enable_revert()) {
    LOG(INFO) << "MemorySparseTable should have revert enabled.";
    return 0;
  }
  // index range of the merged (patch) shard files
  int start_idx = _shard_idx * _m_avg_local_shard_num;
  int end_idx = start_idx + _m_real_local_shard_num;
  // index range of the original shard data
  int o_start_idx = _shard_idx * _avg_local_shard_num;
  int o_end_idx = o_start_idx + _real_local_shard_num;

  if (start_idx >= static_cast<int>(file_list.size())) {
    return 0;
  }
  size_t feature_value_size =
      _value_accesor->GetAccessorInfo().size / sizeof(float);
  end_idx =
      end_idx < _m_sparse_table_shard_num ? end_idx : _m_sparse_table_shard_num;
  int thread_num = (end_idx - start_idx) < 15 ? (end_idx - start_idx) : 15;

  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = start_idx; i < end_idx; ++i) {
    FsChannelConfig channel_config;
    channel_config.path = file_list[i];
    channel_config.converter = _value_accesor->Converter(load_param).converter;
    channel_config.deconverter =
        _value_accesor->Converter(load_param).deconverter;

    bool is_read_failed = false;
    int retry_num = 0;
    int err_no = 0;
    do {
      is_read_failed = false;
      err_no = 0;
      std::string line_data;
      auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
      char *end = NULL;
      int m_local_shard_id = i % _m_avg_local_shard_num;
      std::unordered_set<size_t> global_shard_idx;
      std::string global_shard_idx_str;
      for (int j = o_start_idx; j < o_end_idx; ++j) {
        if ((j % _avg_local_shard_num) % _m_real_local_shard_num ==
            m_local_shard_id) {
          global_shard_idx.insert(j);
          global_shard_idx_str.append(std::to_string(j)).append(",");
        }
      }
      try {
        while (read_channel->read_line(line_data) == 0 &&
               line_data.size() > 1) {
          uint64_t key = std::strtoul(line_data.data(), &end, 10);

          auto index_iter =
              global_shard_idx.find(key % _sparse_table_shard_num);
          if (index_iter == global_shard_idx.end()) {
            LOG(WARNING) << "MemorySparseTable key:" << key
                         << " not match shard,"
                         << " file_idx:" << i
                         << " global_shard_idx:" << global_shard_idx_str
                         << " shard num:" << _sparse_table_shard_num
                         << " file:" << channel_config.path;
            continue;
          }
          size_t local_shard_idx = *index_iter % _avg_local_shard_num;
          auto &shard = _local_shards[local_shard_idx];

          auto &value = shard[key];
          value.resize(feature_value_size);
          int parse_size = _value_accesor->ParseFromString(++end, value.data());
          value.resize(parse_size);
        }
        read_channel->close();
        if (err_no == -1) {
          ++retry_num;
          is_read_failed = true;
          LOG(ERROR)
              << "MemorySparseTable load failed after read, retry it! path:"
              << channel_config.path << " , retry_num=" << retry_num;
        }
      } catch (...) {
        ++retry_num;
        is_read_failed = true;
        LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
                   << channel_config.path << " , retry_num=" << retry_num;
      }
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
        LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
        exit(-1);
      }
    } while (is_read_failed);
  }
  LOG(INFO) << "MemorySparseTable load success, path from "
            << file_list[start_idx] << " to " << file_list[end_idx - 1];
  return 0;
}

void MemorySparseTable::Revert() {
  for (int i = 0; i < _real_local_shard_num; ++i) {
    _local_shards_new[i].clear();
  }
}

void MemorySparseTable::CheckSavePrePatchDone() {
  _save_patch_model_thread.join();
}

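// Writes one "part-%03d-%05d" file per local shard; save_param selects the
// format (see the comment below). save_param == 5 snapshots the delta shards
// and writes them asynchronously via SavePatch. While saving, top-k shows are
// tracked to produce _local_show_threshold for the cache model.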
int32_t MemorySparseTable::Save(const std::string &dirname,
                                const std::string &param) {
  if (_real_local_shard_num == 0) {
    _local_show_threshold = -1;
    return 0;
  }

  VLOG(0) << "MemorySparseTable::save dirname: " << dirname;
  int save_param =
      atoi(param.c_str());  // checkpoint:0  xbox delta:1  xbox base:2

  // patch model
  if (save_param == 5) {
    _local_shards_patch_model.reset(_local_shards_new.release());
    _local_shards_new.reset(new shard_type[_real_local_shard_num]);
    _save_patch_model_thread = std::thread(std::bind(
        &MemorySparseTable::SavePatch, this, std::string(dirname), save_param));
    return 0;
  }

  // cache model
  int64_t tk_size = LocalSize() * _config.sparse_table_cache_rate();
  TopkCalculator tk(_real_local_shard_num, tk_size);

  std::string table_path = TableDir(dirname);
  _afs_client.remove(paddle::string::format_string(
      "%s/part-%03d-*", table_path.c_str(), _shard_idx));
  std::atomic<uint32_t> feasign_size_all{0};

  size_t file_start_idx = _avg_local_shard_num * _shard_idx;

#ifdef PADDLE_WITH_GPU_GRAPH
  int thread_num = _real_local_shard_num;
#else
  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
#endif
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _real_local_shard_num; ++i) {
    FsChannelConfig channel_config;
    if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) {
      channel_config.path =
          paddle::string::format_string("%s/part-%03d-%05d.gz",
                                        table_path.c_str(),
                                        _shard_idx,
                                        file_start_idx + i);
    } else {
      channel_config.path = paddle::string::format_string("%s/part-%03d-%05d",
                                                          table_path.c_str(),
                                                          _shard_idx,
                                                          file_start_idx + i);
    }
    channel_config.converter = _value_accesor->Converter(save_param).converter;
    channel_config.deconverter =
        _value_accesor->Converter(save_param).deconverter;
    bool is_write_failed = false;
    int feasign_size = 0;
    int retry_num = 0;
    int err_no = 0;
    auto &shard = _local_shards[i];
    do {
      err_no = 0;
      feasign_size = 0;
      is_write_failed = false;
      auto write_channel =
          _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
      for (auto it = shard.begin(); it != shard.end(); ++it) {
        if (_config.enable_sparse_table_cache() &&
            (save_param == 1 || save_param == 2) &&
            _value_accesor->Save(it.value().data(), 4)) {
          CostTimer timer10("sparse table top push");
          tk.push(i, _value_accesor->GetField(it.value().data(), "show"));
        }

        if (_value_accesor->Save(it.value().data(), save_param)) {
          std::string format_value = _value_accesor->ParseToString(
              it.value().data(), it.value().size());
          if (0 != write_channel->write_line(paddle::string::format_string(
                       "%lu %s", it.key(), format_value.c_str()))) {
            ++retry_num;
            is_write_failed = true;
            LOG(ERROR)
                << "MemorySparseTable save prefix failed, retry it! path:"
                << channel_config.path << " , retry_num=" << retry_num;
            break;
          }
          ++feasign_size;
        }
      }
      write_channel->close();
      if (err_no == -1) {
        ++retry_num;
        is_write_failed = true;
        LOG(ERROR)
            << "MemorySparseTable save prefix failed after write, retry it! "
            << "path:" << channel_config.path << " , retry_num=" << retry_num;
      }
      if (is_write_failed) {
        _afs_client.remove(channel_config.path);
      }
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
        LOG(ERROR) << "MemorySparseTable save prefix failed reach max limit!";
        exit(-1);
      }
    } while (is_write_failed);
    feasign_size_all += feasign_size;
    for (auto it = shard.begin(); it != shard.end(); ++it) {
      _value_accesor->UpdateStatAfterSave(it.value().data(), save_param);
    }
    LOG(INFO) << "MemorySparseTable save prefix success, path: "
              << channel_config.path << " feasign_size: " << feasign_size;
  }
  _local_show_threshold = tk.top();
  // int32 may overflow; the return value type needs to be changed
  return 0;
}

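// Writes the snapshotted delta shards (_local_shards_patch_model) using the
// merged shard layout: local shard j is written into patch file
// j % _m_real_local_shard_num.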
int32_t MemorySparseTable::SavePatch(const std::string &path, int save_param) {
  if (!_config.enable_revert()) {
    LOG(INFO) << "MemorySparseTable should have revert enabled.";
    return 0;
  }
  size_t file_start_idx = _m_avg_local_shard_num * _shard_idx;
  std::string table_path = TableDir(path);
  _afs_client.remove(paddle::string::format_string(
      "%s/part-%03d-*", table_path.c_str(), _shard_idx));
  int thread_num = _m_real_local_shard_num < 20 ? _m_real_local_shard_num : 20;

  std::atomic<uint32_t> feasign_size_all{0};

  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _m_real_local_shard_num; ++i) {
    FsChannelConfig channel_config;
    channel_config.path = paddle::string::format_string("%s/part-%03d-%05d",
                                                        table_path.c_str(),
                                                        _shard_idx,
                                                        file_start_idx + i);

    channel_config.converter = _value_accesor->Converter(save_param).converter;
    channel_config.deconverter =
        _value_accesor->Converter(save_param).deconverter;

    bool is_write_failed = false;
    int feasign_size = 0;
    int retry_num = 0;
    int err_no = 0;
    do {
      err_no = 0;
      feasign_size = 0;
      is_write_failed = false;
      auto write_channel =
          _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);

      for (int j = 0; j < _real_local_shard_num; ++j) {
        if (j % _m_real_local_shard_num == i) {
          auto &shard = _local_shards_patch_model[j];
          for (auto it = shard.begin(); it != shard.end(); ++it) {
            if (_value_accesor->Save(it.value().data(), save_param)) {
              std::string format_value = _value_accesor->ParseToString(
                  it.value().data(), it.value().size());
              if (0 != write_channel->write_line(paddle::string::format_string(
                           "%lu %s", it.key(), format_value.c_str()))) {
                ++retry_num;
                is_write_failed = true;
                LOG(ERROR) << "MemorySparseTable save failed, retry it! path:"
                           << channel_config.path
                           << " , retry_num=" << retry_num;
                break;
              }
              ++feasign_size;
            }
          }
        }
        if (is_write_failed) break;
      }
      write_channel->close();
      if (err_no == -1) {
        ++retry_num;
        is_write_failed = true;
        LOG(ERROR)
            << "MemorySparseTable save patch failed after write, retry it! "
            << "path:" << channel_config.path << " , retry_num=" << retry_num;
      }
      if (is_write_failed) {
        _afs_client.remove(channel_config.path);
      }
      if (retry_num > FLAGS_pserver_table_save_max_retry) {
        LOG(ERROR) << "MemorySparseTable save patch failed reach max limit!";
        exit(-1);
      }
    } while (is_write_failed);
    feasign_size_all += feasign_size;
  }
  LOG(INFO) << "MemorySparseTable save patch success, path:"
            << paddle::string::format_string("%s/%03d/part-%03d-",
                                             path.c_str(),
                                             _config.table_id(),
                                             _shard_idx)
            << " from " << file_start_idx << " to "
            << file_start_idx + _m_real_local_shard_num - 1
            << ", feasign size: " << feasign_size_all;
  return 0;
}

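// Selects cache-worthy features (accessor SaveCache with cache_threshold) and
// shuffles the resulting "key value" pairs across shuffle_node_num pservers
// via send_msg_func; pairs that belong to this server go to shuffled_channel.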
int64_t MemorySparseTable::CacheShuffle(
    const std::string &path,
    const std::string &param,
    double cache_threshold,
    std::function<std::future<int32_t>(
        int msg_type, int to_pserver_id, std::string &msg)> send_msg_func,
    paddle::framework::Channel<std::pair<uint64_t, std::string>>
        &shuffled_channel,
    const std::vector<Table *> &table_ptrs) {
  LOG(INFO) << "cache shuffle with cache threshold: " << cache_threshold;
  int save_param = atoi(param.c_str());  // batch_model:0  xbox:1
  if (!_config.enable_sparse_table_cache() || cache_threshold < 0) {
    LOG(WARNING)
        << "cache shuffle failed: table cache not enabled or cache threshold < 0 "
        << _config.enable_sparse_table_cache() << " or " << cache_threshold;
    // return -1;
  }
  int shuffle_node_num = _config.sparse_table_cache_file_num();
  LOG(INFO) << "Table>> shuffle node num is: " << shuffle_node_num;
  // TODO(zhaocaibei123): check shuffle_node_num <= server_node_num
  int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;

  std::vector<
      paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>>
      writers(_real_local_shard_num);
  std::vector<std::vector<std::pair<uint64_t, std::string>>> datas(
      _real_local_shard_num);

  int feasign_size = 0;
  std::vector<paddle::framework::Channel<std::pair<uint64_t, std::string>>>
      tmp_channels;
  for (int i = 0; i < _real_local_shard_num; ++i) {
    tmp_channels.push_back(
        paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>());
  }

  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < _real_local_shard_num; ++i) {
    paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>> &writer =
        writers[i];
    writer.Reset(tmp_channels[i].get());

    for (size_t idx = 0; idx < table_ptrs.size(); idx++) {
      Table *table_ptr = table_ptrs[idx];
      auto value_accesor = table_ptr->ValueAccesor();
      shard_type *shard_ptr = static_cast<shard_type *>(table_ptr->GetShard(i));

      for (auto it = shard_ptr->begin(); it != shard_ptr->end(); ++it) {
        if (value_accesor->SaveCache(
                it.value().data(), save_param, cache_threshold)) {
          std::string format_value = value_accesor->ParseToString(
              it.value().data(), it.value().size());
          std::pair<uint64_t, std::string> pkv(it.key(), format_value.c_str());
          writer << pkv;
          ++feasign_size;
        }
      }
    }
    writer.Flush();
    writer.channel()->Close();
  }
  // LOG(INFO) << "MemorySparseTable cache KV save success to Channel feasigh
  // size: " << feasign_size << " and start sparse cache data shuffle real local
  // shard num: " << _real_local_shard_num;
  std::vector<std::pair<uint64_t, std::string>> local_datas;
  for (int idx_shard = 0; idx_shard < _real_local_shard_num; ++idx_shard) {
    paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>> &writer =
        writers[idx_shard];
    auto channel = writer.channel();
    std::vector<std::pair<uint64_t, std::string>> &data = datas[idx_shard];
    std::vector<paddle::framework::BinaryArchive> ars(shuffle_node_num);
    while (channel->Read(data)) {
      for (auto &t : data) {
        auto pserver_id =
            paddle::distributed::local_random_engine()() % shuffle_node_num;
        if (pserver_id != _shard_idx) {
          ars[pserver_id] << t;
        } else {
          local_datas.emplace_back(std::move(t));
        }
      }
      std::vector<std::future<int32_t>> total_status;
      std::vector<uint32_t> send_data_size(shuffle_node_num, 0);
      std::vector<int> send_index(shuffle_node_num);
      for (int i = 0; i < shuffle_node_num; ++i) {
        send_index[i] = i;
      }
      std::random_shuffle(send_index.begin(), send_index.end());
      for (int index = 0; index < shuffle_node_num; ++index) {
        int i = send_index[index];
        if (i == static_cast<int>(_shard_idx)) {
          continue;
        }
        if (ars[i].Length() == 0) {
          continue;
        }
        std::string msg(ars[i].Buffer(), ars[i].Length());
        auto ret = send_msg_func(101, i, msg);
        total_status.push_back(std::move(ret));
        send_data_size[i] += ars[i].Length();
      }
      for (auto &t : total_status) {
        t.wait();
      }
      ars.clear();
      ars = std::vector<paddle::framework::BinaryArchive>(shuffle_node_num);
      data = std::vector<std::pair<uint64_t, std::string>>();
    }
  }
  shuffled_channel->Write(std::move(local_datas));
  return 0;
}

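// Drains shuffled_channel and writes the cached "key value" lines into a
// single uncompressed part file for this shard index; returns the number of
// feasigns written.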
int32_t MemorySparseTable::SaveCache(
    const std::string &path,
    const std::string &param,
    paddle::framework::Channel<std::pair<uint64_t, std::string>>
        &shuffled_channel) {
  if (_shard_idx >= _config.sparse_table_cache_file_num()) {
    return 0;
  }
  int save_param = atoi(param.c_str());  // batch_model:0  xbox:1
  std::string table_path = paddle::string::format_string(
      "%s/%03d_cache/", path.c_str(), _config.table_id());
  _afs_client.remove(paddle::string::format_string(
      "%s/part-%03d", table_path.c_str(), _shard_idx));
  uint32_t feasign_size = 0;
  FsChannelConfig channel_config;
  // not compress cache model
  channel_config.path = paddle::string::format_string(
      "%s/part-%03d", table_path.c_str(), _shard_idx);
  channel_config.converter = _value_accesor->Converter(save_param).converter;
  channel_config.deconverter =
      _value_accesor->Converter(save_param).deconverter;
  auto write_channel = _afs_client.open_w(channel_config, 1024 * 1024 * 40);
  std::vector<std::pair<uint64_t, std::string>> data;
  bool is_write_failed = false;
  shuffled_channel->Close();
  while (shuffled_channel->Read(data)) {
    for (auto &t : data) {
      ++feasign_size;
      if (0 != write_channel->write_line(paddle::string::format_string(
                   "%lu %s", t.first, t.second.c_str()))) {
        LOG(ERROR) << "Cache Table save failed, "
                      "path:"
                   << channel_config.path << ", retry it!";
        is_write_failed = true;
        break;
      }
    }
    data = std::vector<std::pair<uint64_t, std::string>>();
  }
  if (is_write_failed) {
    _afs_client.remove(channel_config.path);
  }
  write_channel->close();
  LOG(INFO) << "MemorySparseTable cache save success, feasign: " << feasign_size
            << ", path: " << channel_config.path;
  shuffled_channel->Open();
  return feasign_size;
}

int64_t MemorySparseTable::LocalSize() {
  int64_t local_size = 0;
  for (int i = 0; i < _real_local_shard_num; ++i) {
    local_size += _local_shards[i].size();
  }
  return local_size;
}

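// Counts, in parallel over the shard task pools, how many local features have
// already been extended with the mf (embedding) part.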
int64_t MemorySparseTable::LocalMFSize() {
  std::vector<int64_t> size_arr(_real_local_shard_num, 0);
  std::vector<std::future<int>> tasks(_real_local_shard_num);
  int64_t ret_size = 0;
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this, shard_id, &size_arr]() -> int {
              auto &local_shard = _local_shards[shard_id];
              for (auto it = local_shard.begin(); it != local_shard.end();
                   ++it) {
                if (_value_accesor->HasMF(it.value().size())) {
                  size_arr[shard_id] += 1;
                }
              }
              return 0;
            });
  }
  for (int i = 0; i < _real_local_shard_num; ++i) {
    tasks[i].wait();
  }
  for (auto x : size_arr) {
    ret_size += x;
  }
  return ret_size;
}

std::pair<int64_t, int64_t> MemorySparseTable::PrintTableStat() {
  int64_t feasign_size = LocalSize();
  int64_t mf_size = LocalMFSize();
  return {feasign_size, mf_size};
}

int32_t MemorySparseTable::Pull(TableContext &context) {
  CHECK(context.value_type == Sparse);
  if (context.use_ptr) {
    char **pull_values = context.pull_context.ptr_values;
    const uint64_t *keys = context.pull_context.keys;
    return PullSparsePtr(pull_values, keys, context.num);
  } else {
    float *pull_values = context.pull_context.values;
    const PullSparseValue &pull_value = context.pull_context.pull_value;
    return PullSparse(pull_values, pull_value);
  }
}

int32_t MemorySparseTable::Push(TableContext &context) {
  CHECK(context.value_type == Sparse);
  if (!context.use_ptr) {
    return PushSparse(
        context.push_context.keys, context.push_context.values, context.num);
  } else {
    return PushSparse(context.push_context.keys,
                      context.push_context.ptr_values,
                      context.num);
  }
}

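// Dense pull: keys are bucketed by local shard and looked up in parallel.
// Missing keys are returned as zeros (default) or inserted immediately,
// depending on FLAGS_pserver_create_value_when_push; the accessor then copies
// the selected fields into pull_values.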
int32_t MemorySparseTable::PullSparse(float *pull_values,
                                      const PullSparseValue &pull_value) {
  CostTimer timer("pserver_sparse_select_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);

  const size_t value_size =
      _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_size =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
  size_t select_value_size =
      _value_accesor->GetAccessorInfo().select_size / sizeof(float);
  // std::atomic<uint32_t> missed_keys{0};

  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  size_t num = pull_value.numel_;
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (pull_value.feasigns_[i] % _sparse_table_shard_num) %
                   _avg_local_shard_num;
    task_keys[shard_id].push_back({pull_value.feasigns_[i], i});
  }
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this,
             shard_id,
             &task_keys,
             value_size,
             pull_values,
             mf_value_size,
             select_value_size]() -> int {
              auto &local_shard = _local_shards[shard_id];
              float data_buffer[value_size];  // NOLINT
              float *data_buffer_ptr = data_buffer;

              auto &keys = task_keys[shard_id];
              for (size_t i = 0; i < keys.size(); i++) {
                uint64_t key = keys[i].first;
                auto itr = local_shard.find(key);
                size_t data_size = value_size - mf_value_size;
                if (itr == local_shard.end()) {
                  // ++missed_keys;
                  if (FLAGS_pserver_create_value_when_push) {
                    memset(data_buffer, 0, sizeof(float) * data_size);
                  } else {
                    auto &feature_value = local_shard[key];
                    feature_value.resize(data_size);
                    float *data_ptr = feature_value.data();
                    _value_accesor->Create(&data_buffer_ptr, 1);
                    memcpy(
                        data_ptr, data_buffer_ptr, data_size * sizeof(float));
                  }
                } else {
                  data_size = itr.value().size();
                  memcpy(data_buffer_ptr,
                         itr.value().data(),
                         data_size * sizeof(float));
                }
                for (size_t mf_idx = data_size; mf_idx < value_size; ++mf_idx) {
                  data_buffer[mf_idx] = 0.0;
                }
                auto offset = keys[i].second;
                float *select_data = pull_values + select_value_size * offset;
                _value_accesor->Select(
                    &select_data, (const float **)&data_buffer_ptr, 1);
              }

              return 0;
            });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

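// Pointer pull: for every key, returns a raw pointer to its FixedFeatureValue
// inside the shard, creating the entry first if it does not exist yet.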
int32_t MemorySparseTable::PullSparsePtr(char **pull_values,
                                         const uint64_t *keys,
                                         size_t num) {
  CostTimer timer("pscore_sparse_select_all");
  size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_size =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);

  std::vector<std::future<int>> tasks(_real_local_shard_num);
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
    task_keys[shard_id].push_back({keys[i], i});
  }
  // std::atomic<uint32_t> missed_keys{0};
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this,
             shard_id,
             &task_keys,
             pull_values,
             value_size,
             mf_value_size]() -> int {
              auto &keys = task_keys[shard_id];
              auto &local_shard = _local_shards[shard_id];
              float data_buffer[value_size];  // NOLINT
              float *data_buffer_ptr = data_buffer;
              for (size_t i = 0; i < keys.size(); ++i) {
                uint64_t key = keys[i].first;
                auto itr = local_shard.find(key);
                size_t data_size = value_size - mf_value_size;
                FixedFeatureValue *ret = NULL;
                if (itr == local_shard.end()) {
                  // ++missed_keys;
                  auto &feature_value = local_shard[key];
                  feature_value.resize(data_size);
                  float *data_ptr = feature_value.data();
                  _value_accesor->Create(&data_buffer_ptr, 1);
                  memcpy(data_ptr, data_buffer_ptr, data_size * sizeof(float));
                  ret = &feature_value;
                } else {
                  ret = itr.value_ptr();
                }
                int pull_data_idx = keys[i].second;
                pull_values[pull_data_idx] = reinterpret_cast<char *>(ret);
              }
              return 0;
            });
  }
  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

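// Dense push: per-shard parallel update. New keys may be created (optionally
// gated by FLAGS_pserver_enable_create_feasign_randomly). Values that have not
// been extended with mf are updated in a stack buffer and only grown when the
// accessor reports NeedExtendMF; with enable_revert, the updated value is also
// mirrored into _local_shards_new.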
int32_t MemorySparseTable::PushSparse(const uint64_t *keys,
                                      const float *values,
                                      size_t num) {
  CostTimer timer("pserver_sparse_update_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
    task_keys[shard_id].push_back({keys[i], i});
  }

  const size_t value_col =
      _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_col =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
  size_t update_value_col =
      _value_accesor->GetAccessorInfo().update_size / sizeof(float);

  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
        [this,
         shard_id,
         value_col,
         mf_value_col,
         update_value_col,
         values,
         &task_keys]() -> int {
          auto &keys = task_keys[shard_id];
          auto &local_shard = _local_shards[shard_id];
          auto &local_shard_new = _local_shards_new[shard_id];
          float data_buffer[value_col];  // NOLINT
          float *data_buffer_ptr = data_buffer;
          for (size_t i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float *update_data =
                values + push_data_idx * update_value_col;
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              if (FLAGS_pserver_enable_create_feasign_randomly &&
                  !_value_accesor->CreateValue(1, update_data)) {
                continue;
              }
              auto value_size = value_col - mf_value_col;
              auto &feature_value = local_shard[key];
              feature_value.resize(value_size);
              _value_accesor->Create(&data_buffer_ptr, 1);
              memcpy(feature_value.data(),
                     data_buffer_ptr,
                     value_size * sizeof(float));
              itr = local_shard.find(key);
            }

            auto &feature_value = itr.value();
            float *value_data = feature_value.data();
            size_t value_size = feature_value.size();

            if (value_size == value_col) {  // already extended to max size, update in place
              _value_accesor->Update(&value_data, &update_data, 1);
            } else {
              // update in a buffer and copy back; unneeded mf fields are dropped on the copy back
              memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
              _value_accesor->Update(&data_buffer_ptr, &update_data, 1);

              if (_value_accesor->NeedExtendMF(data_buffer)) {
                feature_value.resize(value_col);
                value_data = feature_value.data();
                _value_accesor->Create(&value_data, 1);
              }
              memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
            }
            if (_config.enable_revert()) {
              FixedFeatureValue *feature_value_new = &(local_shard_new[key]);
              auto new_size = feature_value.size();
              feature_value_new->resize(new_size);
              memcpy(feature_value_new->data(),
                     value_data,
                     new_size * sizeof(float));
            }
          }
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

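// Same as the dense push above, except that the per-key update data arrives as
// an array of float pointers and no revert copy is kept.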
int32_t MemorySparseTable::PushSparse(const uint64_t *keys,
                                      const float **values,
                                      size_t num) {
  std::vector<std::future<int>> tasks(_real_local_shard_num);
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
      _real_local_shard_num);
  for (size_t i = 0; i < num; ++i) {
    int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
    task_keys[shard_id].push_back({keys[i], i});
  }

  size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float);
  size_t mf_value_col =
      _value_accesor->GetAccessorInfo().mf_size / sizeof(float);
  size_t update_value_col =
      _value_accesor->GetAccessorInfo().update_size / sizeof(float);

  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
        [this,
         shard_id,
         value_col,
         mf_value_col,
         update_value_col,
         values,
         &task_keys]() -> int {
          auto &keys = task_keys[shard_id];
          auto &local_shard = _local_shards[shard_id];
          float data_buffer[value_col];  // NOLINT
          float *data_buffer_ptr = data_buffer;
          for (size_t i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float *update_data = values[push_data_idx];
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              if (FLAGS_pserver_enable_create_feasign_randomly &&
                  !_value_accesor->CreateValue(1, update_data)) {
                continue;
              }
              auto value_size = value_col - mf_value_col;
              auto &feature_value = local_shard[key];
              feature_value.resize(value_size);
              _value_accesor->Create(&data_buffer_ptr, 1);
              memcpy(feature_value.data(),
                     data_buffer_ptr,
                     value_size * sizeof(float));
              itr = local_shard.find(key);
            }
            auto &feature_value = itr.value();
            float *value_data = feature_value.data();
            size_t value_size = feature_value.size();
            if (value_size == value_col) {  // already extended to max size, update in place
              _value_accesor->Update(&value_data, &update_data, 1);
            } else {
              // update in a buffer and copy back; unneeded mf fields are dropped on the copy back
              memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
              _value_accesor->Update(&data_buffer_ptr, &update_data, 1);
              if (_value_accesor->NeedExtendMF(data_buffer)) {
                feature_value.resize(value_col);
                value_data = feature_value.data();
                _value_accesor->Create(&value_data, 1);
              }
              memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
            }
          }
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

int32_t MemorySparseTable::Flush() { return 0; }

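// Walks every local shard and erases the features that the accessor's Shrink()
// marks for removal.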
int32_t MemorySparseTable::Shrink(const std::string &param) {
  VLOG(0) << "MemorySparseTable::Shrink";
  // TODO(zhaocaibei123): implement with multi-thread
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    // Shrink
    auto &shard = _local_shards[shard_id];
    for (auto it = shard.begin(); it != shard.end();) {
      if (_value_accesor->Shrink(it.value().data())) {
        it = shard.erase(it);
      } else {
        ++it;
      }
    }
  }
  return 0;
}

void MemorySparseTable::Clear() { VLOG(0) << "clear coming soon"; }

}  // namespace distributed
}  // namespace paddle