common_sparse_table.cc
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/ps/table/common_sparse_table.h"

#include <sstream>

#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace distributed {
class ValueBlock;
}  // namespace distributed
}  // namespace paddle

namespace paddle {
namespace distributed {

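// Parse one line of a saved text file: the last tab-separated column holds
// all parameter values joined by commas; they are sliced into one float
// vector per field described by meta, falling back to 0.0 on malformed
// numbers.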
void CommonSparseTable::ProcessALine(const std::vector<std::string>& columns,
                                     const Meta& meta, const int64_t id,
                                     std::vector<std::vector<float>>* values) {
  auto column_size = columns.size();
  auto load_values =
      paddle::string::split_string<std::string>(columns[column_size - 1], ",");
  values->reserve(meta.names.size());

  int offset = 0;
  for (int x = 0; x < meta.names.size(); ++x) {
    std::vector<float> val;
    auto start = load_values.begin() + offset;
    auto end = load_values.begin() + offset + meta.dims[x];
    PADDLE_ENFORCE_LE(offset + meta.dims[x], load_values.size(),
                      paddle::platform::errors::InvalidArgument(
                          "The data format in the text file does not match "
                          "the field requirements defined in meta"));

    std::transform(start, end, std::back_inserter(val), [id](std::string va) {
      float v = 0.0;

      try {
        v = std::stof(va);
      } catch (std::invalid_argument& e) {
        VLOG(0) << "id: " << id << " got unexpected value: " << va
                << ", reset to 0.0";
      } catch (std::out_of_range& e) {
        VLOG(0) << "id: " << id << " got unexpected value: " << va
                << ", reset to 0.0";
      }
      return v;
    });

    values->push_back(val);
    offset += meta.dims[x];
  }
}

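// Write the shard meta file: table name, shard id, comma-joined row names
// and dims, and the total number of saved ids.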
void CommonSparseTable::SaveMetaToText(std::ostream* os,
                                       const CommonAccessorParameter& common,
                                       const size_t shard_idx,
                                       const int64_t total) {
  // save meta
  std::stringstream stream;
  stream << "param=" << common.table_name() << "\n";
  stream << "shard_id=" << shard_idx << "\n";
  stream << "row_names=" << paddle::string::join_strings(common.params(), ',')
         << "\n";
  stream << "row_dims=" << paddle::string::join_strings(common.dims(), ',')
         << "\n";
  stream << "count=" << total << "\n";
  os->write(stream.str().c_str(), sizeof(char) * stream.str().size());
}

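// Serialize each value in the block as one line: id, count, unseen_days and
// is_entry separated by tabs, then the comma-joined float data. In delta
// mode, values not marked need_save_ are skipped; base/delta saves reset the
// flag after writing. Returns the number of ids written.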
int64_t CommonSparseTable::SaveValueToText(std::ostream* os,
                                           std::shared_ptr<ValueBlock> block,
                                           std::shared_ptr<::ThreadPool> pool,
                                           const int mode, int shard_id) {
  int64_t save_num = 0;
  for (auto& table : block->values_) {
    for (auto& value : table) {
      if (mode == SaveMode::delta && !value.second->need_save_) {
        continue;
      }

      ++save_num;

      std::stringstream ss;
      auto* vs = value.second->data_.data();

      auto id = value.first;

      ss << id << "\t" << value.second->count_ << "\t"
         << value.second->unseen_days_ << "\t" << value.second->is_entry_
         << "\t";

      for (int i = 0; i < block->value_length_ - 1; i++) {
        ss << std::to_string(vs[i]) << ",";
      }

      ss << std::to_string(vs[block->value_length_ - 1]);
      ss << "\n";

      os->write(ss.str().c_str(), sizeof(char) * ss.str().size());

      if (mode == SaveMode::base || mode == SaveMode::delta) {
        value.second->need_save_ = false;
      }
    }
  }

  return save_num;
}

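// Load a value/meta file pair written by Save. Ids that do not belong to
// this pserver (id % pserver_num != pserver_id) are skipped; the rest are
// routed to a local block by id % local_shard_num, parsed with ProcessALine
// and copied into the block's buffers.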
int64_t CommonSparseTable::LoadFromText(
    const std::string& valuepath, const std::string& metapath,
    const int pserver_id, const int pserver_num, const int local_shard_num,
    std::vector<std::shared_ptr<ValueBlock>>* blocks) {
  Meta meta = Meta(metapath);

  int num_lines = 0;
  std::ifstream file(valuepath);
  std::string line;

  while (std::getline(file, line)) {
    auto values = paddle::string::split_string<std::string>(line, "\t");
    auto id = std::stoull(values[0]);

    if (id % pserver_num != pserver_id) {
      VLOG(3) << "will not load " << values[0] << " from " << valuepath
              << ", please check id distribution";
      continue;
    }

    auto shard_id = id % local_shard_num;
    auto block = blocks->at(shard_id);

    std::vector<std::vector<float>> kvalues;
    ProcessALine(values, meta, id, &kvalues);

    block->Init(id, false);

    VALUE* value_instant = block->GetValue(id);

    if (values.size() == 5) {
      value_instant->count_ = std::stoi(values[1]);
      value_instant->unseen_days_ = std::stoi(values[2]);
      value_instant->is_entry_ = static_cast<bool>(std::stoi(values[3]));
    }

    std::vector<float*> block_values = block->Get(id, meta.names, meta.dims);
    auto blas = GetBlas<float>();
    for (int x = 0; x < meta.names.size(); ++x) {
      blas.VCOPY(meta.dims[x], kvalues[x].data(), block_values[x]);
    }
  }

  return 0;
}

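// Set up the table: one single-threaded pool per shard, the global learning
// rate, and the flat value layout (name, dim and offset per parameter),
// then initialize values, optimizer and recorder.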
int32_t CommonSparseTable::Initialize() {
  _shards_task_pool.resize(task_pool_size_);
  for (int i = 0; i < _shards_task_pool.size(); ++i) {
    _shards_task_pool[i].reset(new ::ThreadPool(1));
  }

  sync = _config.common().sync();
  VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync;

  _global_lr = new float(1.0);

  auto common = _config.common();
  int size = static_cast<int>(common.params().size());

  size_t offset = 0;
  for (int x = 0; x < size; ++x) {
    auto& varname = common.params()[x];
    auto& dim = common.dims()[x];

    value_idx_[varname] = x;
    value_names_.push_back(varname);
    value_dims_.push_back(dim);
    value_offsets_.push_back(offset);
    initializer_attrs_.push_back(common.initializers()[x]);

    if (varname == "Param") {
      param_dim_ = dim;
      param_offset_ = offset;
    }

    offset += dim;
  }

  InitializeValue();
  InitializeOptimizer();
  InitializeRecorder();
  return 0;
}

int32_t CommonSparseTable::InitializeRecorder() { return 0; }

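// Create one ValueBlock per task-pool shard, each sharing the same value
// layout, initializers and entry configuration.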
int32_t CommonSparseTable::InitializeValue() {
  auto common = _config.common();
  shard_values_.reserve(task_pool_size_);

  for (int x = 0; x < task_pool_size_; ++x) {
    auto shard = std::make_shared<ValueBlock>(
        value_names_, value_dims_, value_offsets_, value_idx_,
        initializer_attrs_, common.entry());

    shard_values_.emplace_back(shard);
  }

  return 0;
}

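// Choose the sparse optimizer by its configured name: SSGD ("sgd"), SAdam
// ("adam") or SSUM ("sum"); sgd and adam also track the global learning
// rate.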
int32_t CommonSparseTable::InitializeOptimizer() {
  auto common = _config.common();
  auto name = common.name();

  if (name == "sgd") {
    optimizer_ = std::make_shared<SSGD>(value_names_, value_dims_,
                                        value_offsets_, value_idx_);
    optimizer_->SetGlobalLR(_global_lr);
  } else if (name == "adam") {
    optimizer_ = std::make_shared<SAdam>(value_names_, value_dims_,
                                         value_offsets_, value_idx_);
    optimizer_->SetGlobalLR(_global_lr);
  } else if (name == "sum") {
    optimizer_ = std::make_shared<SSUM>(value_names_, value_dims_,
                                        value_offsets_, value_idx_);
  } else {
    VLOG(3) << "init optimizer failed, unsupported optimizer: " << name;
  }

  VLOG(3) << "init optimizer " << name << " done";
  return 0;
}

int32_t CommonSparseTable::SetGlobalLR(float* lr) {
  _global_lr = lr;
  optimizer_->SetGlobalLR(_global_lr);
  return 0;
}

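// Load this pserver shard's value (.txt) and meta (.meta) files, written by
// Save, while holding the write lock.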
int32_t CommonSparseTable::Load(const std::string& dirname,
                                const std::string& param) {
  auto begin = GetCurrentUS();
  rwlock_->WRLock();
  auto varname = _config.common().table_name();
  std::string var_store =
      string::Sprintf("%s/%s%s", dirname, varname, PSERVER_SAVE_SUFFIX);
  std::string shard_var_pre =
      string::Sprintf("%s.block%d", varname, _shard_idx);
  std::string value_ = string::Sprintf("%s/%s.txt", var_store, shard_var_pre);
  std::string meta_ = string::Sprintf("%s/%s.meta", var_store, shard_var_pre);

  LoadFromText(value_, meta_, _shard_idx, _shard_num, task_pool_size_,
               &shard_values_);
  rwlock_->UNLock();
  auto end = GetCurrentUS();

  VLOG(0) << "load " << varname << " with value: " << value_
          << " , meta: " << meta_
274 275
          << " using: " << std::to_string((end - begin) / 1e+6) << " seconds";

  return 0;
}

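// Save all task-pool shards into a single value file plus a meta file under
// <dirname>/<table_name><PSERVER_SAVE_SUFFIX>; param carries the numeric
// save mode (see SaveMode).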
int32_t CommonSparseTable::Save(const std::string& dirname,
                                const std::string& param) {
  auto begin = GetCurrentUS();
  rwlock_->WRLock();
  int mode = std::stoi(param);
  VLOG(3) << "sparse table save: " << dirname << " mode: " << mode;

  auto varname = _config.common().table_name();
  std::string var_store =
      string::Sprintf("%s/%s%s", dirname, varname, PSERVER_SAVE_SUFFIX);
  MkDirRecursively(var_store.c_str());

  VLOG(3) << "save " << varname << " in dir: " << var_store << " begin";
  std::vector<std::string> params(_config.common().params().begin(),
                                  _config.common().params().end());

  std::string shard_var_pre =
      string::Sprintf("%s.block%d", varname, _shard_idx);

  std::string value_ = string::Sprintf("%s/%s.txt", var_store, shard_var_pre);

  std::unique_ptr<std::ofstream> vs(new std::ofstream(value_));

  int64_t total_ins = 0;
  for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
    // save values
    auto shard_save_num =
        SaveValueToText(vs.get(), shard_values_[shard_id],
                        _shards_task_pool[shard_id], mode, shard_id);
    total_ins += shard_save_num;
  }
  vs->close();

  std::string meta_ = string::Sprintf("%s/%s.meta", var_store, shard_var_pre);
  std::unique_ptr<std::ofstream> ms(new std::ofstream(meta_));
  SaveMetaToText(ms.get(), _config.common(), _shard_idx, total_ins);
  ms->close();

  auto end = GetCurrentUS();
  rwlock_->UNLock();
  VLOG(0) << "save " << varname << " with path: " << value_
          << " using: " << std::to_string((end - begin) / 1e+6) << " seconds";

  return 0;
}

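// Count the ids stored across all shards; mf_size is always 0 for this
// table.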
std::pair<int64_t, int64_t> CommonSparseTable::PrintTableStat() {
  int64_t feasign_size = 0;
  int64_t mf_size = 0;

  for (auto& shard : shard_values_) {
    for (auto& table : shard->values_) {
      feasign_size += table.size();
    }
  }

  return {feasign_size, mf_size};
}

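// Drain the sync-mode reservoir: average each id's accumulated gradients,
// apply them via _PushSparse, then clear the reservoir.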
int32_t CommonSparseTable::Pour() {
  std::vector<float> values;
  std::vector<uint64_t> keys;

  keys.reserve(pull_reservoir_.size());
  values.reserve(pull_reservoir_.size() * param_dim_);

  for (auto& val : pull_reservoir_) {
    keys.push_back(val.first);
    auto& reservoir = val.second;
    reservoir.avg();
    std::copy(reservoir.values.begin(), reservoir.values.end(),
              std::back_inserter(values));
  }
  _PushSparse(keys.data(), values.data(), pull_reservoir_.size());

  pull_reservoir_.clear();
  return 0;
}

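// TableContext entry points: dispatch sparse pulls/pushes to the pointer- or
// value-based implementation depending on how the context is populated.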
int32_t CommonSparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  if (context.use_ptr) {
    char** pull_values = context.pull_context.ptr_values;
    const uint64_t* keys = context.pull_context.keys;
    return PullSparsePtr(pull_values, keys, context.num);
  } else {
    float* pull_values = context.pull_context.values;
    const PullSparseValue& pull_value = context.pull_context.pull_value;
    return PullSparse(pull_values, pull_value);
  }
}

int32_t CommonSparseTable::Push(TableContext& context) {
  CHECK(context.value_type == Sparse);
  if (context.push_context.values != nullptr) {
    const float* values = context.push_context.values;
    const uint64_t* keys = context.push_context.keys;
    return PushSparse(keys, values, context.num);
  } else {
    const float** values = context.push_context.ptr_values;
    const uint64_t* keys = context.push_context.keys;
    return PushSparse(keys, values, context.num);
  }
}

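// Pull embedding values: each shard thread takes its slice of the request
// via Fission, initializes missing ids (with the pulled frequency when
// training) and copies param_dim_ floats per id into pull_values.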
int32_t CommonSparseTable::PullSparse(float* pull_values,
                                      const PullSparseValue& pull_value) {
  auto shard_num = task_pool_size_;
  std::vector<std::future<int>> tasks(shard_num);

  for (int shard_id = 0; shard_id < shard_num; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
        [this, shard_id, shard_num, &pull_value, &pull_values]() -> int {
          auto& block = shard_values_[shard_id];

          std::vector<int> offsets;
          pull_value.Fission(shard_id, shard_num, &offsets);

          if (pull_value.is_training_) {
            for (auto& offset : offsets) {
              auto feasign = pull_value.feasigns_[offset];
              auto frequency = pull_value.frequencies_[offset];
              auto* value = block->Init(feasign, true, frequency);
              std::copy_n(value + param_offset_, param_dim_,
                          pull_values + param_dim_ * offset);
            }
          } else {
            for (auto& offset : offsets) {
              auto feasign = pull_value.feasigns_[offset];
              auto* value = block->Init(feasign, false);
              std::copy_n(value + param_offset_, param_dim_,
                          pull_values + param_dim_ * offset);
            }
          }

          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

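// Pointer-based pull: bucket keys by shard, then return raw VALUE pointers
// instead of copying the float data out.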
int32_t CommonSparseTable::PullSparsePtr(char** pull_values,
                                         const uint64_t* keys, size_t num) {
  std::vector<std::vector<uint64_t>> offset_bucket;
  offset_bucket.resize(task_pool_size_);

  for (int x = 0; x < num; ++x) {
    auto y = keys[x] % task_pool_size_;
    offset_bucket[y].push_back(x);
  }

  std::vector<std::future<int>> tasks(task_pool_size_);

  for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
        [this, shard_id, &keys, &offset_bucket, &pull_values]() -> int {
          auto& block = shard_values_[shard_id];
          auto& offsets = offset_bucket[shard_id];

          for (int i = 0; i < offsets.size(); ++i) {
            auto offset = offsets[i];
            auto id = keys[offset];
            auto* value = block->InitGet(id);
            // std::copy_n(value + param_offset_, param_dim_,
            //            pull_values + param_dim_ * offset);
            pull_values[offset] = reinterpret_cast<char*>(value);
          }

          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

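// Apply a dense gradient buffer: bucket keys by shard and let each shard
// thread run the optimizer update over its offsets.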
int32_t CommonSparseTable::_PushSparse(const uint64_t* keys,
                                       const float* values, size_t num) {
  std::vector<std::vector<uint64_t>> offset_bucket;
  offset_bucket.resize(task_pool_size_);

  for (int x = 0; x < num; ++x) {
    auto y = keys[x] % task_pool_size_;
    offset_bucket[y].push_back(x);
  }

  std::vector<std::future<int>> tasks(task_pool_size_);

  for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
        [this, shard_id, &keys, &values, num, &offset_bucket]() -> int {
          auto& offsets = offset_bucket[shard_id];
          optimizer_->Update(keys, values, num, offsets,
                             shard_values_[shard_id].get());
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

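// In sync mode, gradients are accumulated into the pull reservoir and
// averaged later by Pour; in async mode they are applied immediately.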
int32_t CommonSparseTable::PushSparse(const uint64_t* keys, const float* values,
                                      size_t num) {
  if (sync) {
    std::future<int> task =
        _shards_task_pool[0]->enqueue([this, &keys, &values, num]() -> int {
          for (int x = 0; x < num; ++x) {
            auto id = keys[x];
            auto has = pull_reservoir_.find(id);

            if (has == pull_reservoir_.end()) {
              pull_reservoir_[id] = ReservoirValue<float>(param_dim_);
            }

            auto& reservoir = pull_reservoir_[id];
            reservoir.add(values + x * param_dim_, param_dim_);
          }
          return 0;
        });
    task.wait();
  } else {
    _PushSparse(keys, values, num);
  }

  return 0;
}

int32_t CommonSparseTable::PushSparse(const uint64_t* keys,
                                      const float** values, size_t num) {
  _PushSparse(keys, values, num);
  return 0;
}

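// Pointer-variant push: every key carries its own gradient buffer, so the
// optimizer is invoked per id with a single zero offset.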
int32_t CommonSparseTable::_PushSparse(const uint64_t* keys,
                                       const float** values, size_t num) {
  std::vector<std::vector<uint64_t>> offset_bucket;
  offset_bucket.resize(task_pool_size_);

  for (int x = 0; x < num; ++x) {
    auto y = keys[x] % task_pool_size_;
    offset_bucket[y].push_back(x);
  }

  std::vector<std::future<int>> tasks(task_pool_size_);

  for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
        [this, shard_id, &keys, &values, num, &offset_bucket]() -> int {
          auto& offsets = offset_bucket[shard_id];
          for (size_t i = 0; i < offsets.size(); ++i) {
            std::vector<uint64_t> tmp_off = {0};
            optimizer_->Update(keys + offsets[i], values[offsets[i]], num,
                               tmp_off, shard_values_[shard_id].get());
          }
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

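// Overwrite the Param segment of each id's value with the pushed data and
// mark the id as an entry.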
int32_t CommonSparseTable::PushSparseParam(const uint64_t* keys,
                                           const float* values, size_t num) {
  std::vector<std::vector<uint64_t>> offset_bucket;
  offset_bucket.resize(task_pool_size_);

  for (int x = 0; x < num; ++x) {
    auto y = keys[x] % task_pool_size_;
    offset_bucket[y].push_back(x);
  }

  std::vector<std::future<int>> tasks(task_pool_size_);

  for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
    tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
        [this, shard_id, &keys, &offset_bucket, &values]() -> int {
          auto& block = shard_values_[shard_id];
          auto& offsets = offset_bucket[shard_id];

          for (int i = 0; i < offsets.size(); ++i) {
            auto offset = offsets[i];
            auto id = keys[offset];
            auto* value = block->Init(id, false);
            std::copy_n(values + param_dim_ * offset, param_dim_,
                        value + param_offset_);
            block->SetEntry(id, true);
          }
          return 0;
        });
  }

  for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
    tasks[shard_id].wait();
  }
  return 0;
}

int32_t CommonSparseTable::Flush() { return 0; }

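// Evict stale ids from every shard; param is the integer shrink threshold.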
int32_t CommonSparseTable::Shrink(const std::string& param) {
  int threshold = std::stoi(param);
  VLOG(3) << "sparse table Shrink: " << threshold;

  for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
    // Shrink
    VLOG(4) << shard_id << " " << task_pool_size_ << " begin Shrink";
    shard_values_[shard_id]->Shrink(threshold);
  }
  return 0;
}

void CommonSparseTable::Clear() { VLOG(0) << "clear coming soon"; }

}  // namespace distributed
}  // namespace paddle