/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14 15

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h"

#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif

namespace paddle {
namespace framework {
void DownpourWorker::Initialize(const TrainerDesc& desc) {
28
  param_ = desc.downpour_param();
D
dongdaxiang 已提交
29
  for (int i = 0; i < param_.sparse_table_size(); ++i) {
30 31 32 33
    uint64_t table_id =
        static_cast<uint64_t>(param_.sparse_table(i).table_id());
    TableParameter table = param_.sparse_table(i);
    sparse_key_names_[table_id].resize(table.sparse_key_name_size());
D
dongdaxiang 已提交
34
    for (int j = 0; j < table.sparse_key_name_size(); ++j) {
35 36 37
      sparse_key_names_[table_id][j] = table.sparse_key_name(j);
    }
    sparse_value_names_[table_id].resize(table.sparse_value_name_size());
D
dongdaxiang 已提交
38
    for (int j = 0; j < table.sparse_value_name_size(); ++j) {
39 40 41
      sparse_value_names_[table_id][j] = table.sparse_value_name(j);
    }
    sparse_grad_names_[table_id].resize(table.sparse_grad_name_size());
D
dongdaxiang 已提交
42
    for (int j = 0; j < table.sparse_grad_name_size(); ++j) {
43 44
      sparse_grad_names_[table_id][j] = table.sparse_grad_name(j);
    }
45
    label_var_name_[table_id] = table.label_var_name();
46 47
  }

D
dongdaxiang 已提交
48
  for (int i = 0; i < param_.dense_table_size(); ++i) {
49 50 51
    uint64_t table_id = static_cast<uint64_t>(param_.dense_table(i).table_id());
    auto table = param_.dense_table(i);
    dense_value_names_[table_id].resize(table.dense_value_name_size());
D
dongdaxiang 已提交
52
    for (int j = 0; j < table.dense_value_name_size(); ++j) {
53 54 55
      dense_value_names_[table_id][j] = table.dense_value_name(j);
    }
    dense_grad_names_[table_id].resize(table.dense_grad_name_size());
D
dongdaxiang 已提交
56
    for (int j = 0; j < table.dense_grad_name_size(); ++j) {
57 58 59 60 61
      dense_grad_names_[table_id][j] = table.dense_grad_name(j);
    }
  }

  skip_ops_.resize(param_.skip_ops_size());
D
dongdaxiang 已提交
62
  for (int i = 0; i < param_.skip_ops_size(); ++i) {
63 64
    skip_ops_[i] = param_.skip_ops(i);
  }
65

66 67 68
  need_to_push_sparse_ = param_.push_sparse();
  need_to_push_dense_ = param_.push_dense();

69
  fleet_ptr_ = FleetWrapper::GetInstance();
D
dongdaxiang 已提交
70
  fetch_config_ = desc.fetch_config();
71
  use_cvm_ = desc.use_cvm();
72
  scale_datanorm_ = desc.scale_datanorm();
T
Thunderbrook 已提交
73
  dump_slot_ = desc.dump_slot();
74
  adjust_ins_weight_config_ = desc.adjust_ins_weight_config();
75 76
}

void DownpourWorker::CollectLabelInfo(size_t table_idx) {
H
heqiaozhi 已提交
78
  uint64_t table_id = static_cast<uint64_t>(
79
      param_.program_config(0).pull_sparse_table_id(table_idx));
80

H
heqiaozhi 已提交
81 82 83 84 85 86 87
  TableParameter table;
  for (auto i : param_.sparse_table()) {
    if (i.table_id() == table_id) {
      table = i;
      break;
    }
  }
88 89 90
  auto& feature = features_[table_id];
  auto& feature_label = feature_labels_[table_id];
  feature_label.resize(feature.size());
91
  Variable* var = thread_scope_->FindVar(label_var_name_[table_id]);
92 93 94
  LoDTensor* tensor = var->GetMutable<LoDTensor>();
  int64_t* label_ptr = tensor->data<int64_t>();

D
dongdaxiang 已提交
95
  size_t global_index = 0;
96
  for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) {
97 98
    VLOG(3) << "sparse_key_names_[" << i
            << "]: " << sparse_key_names_[table_id][i];
99
    Variable* fea_var = thread_scope_->FindVar(sparse_key_names_[table_id][i]);
100 101 102
    if (fea_var == nullptr) {
      continue;
    }
103
    LoDTensor* tensor = fea_var->GetMutable<LoDTensor>();
104 105
    CHECK(tensor != nullptr) << "tensor of var "
                             << sparse_key_names_[table_id][i] << " is null";
106
    int64_t* ids = tensor->data<int64_t>();
D
dongdaxiang 已提交
107
    size_t fea_idx = 0;
108
    // tensor->lod()[0].size() == batch_size + 1
109 110
    for (auto lod_idx = 1u; lod_idx < tensor->lod()[0].size(); ++lod_idx) {
      for (; fea_idx < tensor->lod()[0][lod_idx]; ++fea_idx) {
111 112 113 114
        // should be skipped feasign defined in protobuf
        if (ids[fea_idx] == 0u) {
          continue;
        }
115 116
        feature_label[global_index++] =
            static_cast<float>(label_ptr[lod_idx - 1]);
117 118 119 120 121 122 123 124
      }
    }
  }
  CHECK(global_index == feature.size())
      << "expect fea info size:" << feature.size() << " real:" << global_index;
}

void DownpourWorker::FillSparseValue(size_t table_idx) {
H
heqiaozhi 已提交
125
  uint64_t table_id = static_cast<uint64_t>(
126
      param_.program_config(0).pull_sparse_table_id(table_idx));
H
heqiaozhi 已提交
127 128 129 130 131 132 133 134

  TableParameter table;
  for (auto i : param_.sparse_table()) {
    if (i.table_id() == table_id) {
      table = i;
      break;
    }
  }
135 136 137 138

  auto& fea_value = feature_values_[table_id];
  auto fea_idx = 0u;

X
xjqbest 已提交
139
  std::vector<float> init_value(table.fea_dim());
140 141 142 143
  for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) {
    std::string slot_name = sparse_key_names_[table_id][i];
    std::string emb_slot_name = sparse_value_names_[table_id][i];
    Variable* var = thread_scope_->FindVar(slot_name);
144 145 146
    if (var == nullptr) {
      continue;
    }
147
    LoDTensor* tensor = var->GetMutable<LoDTensor>();
148
    CHECK(tensor != nullptr) << "tensor of var " << slot_name << " is null";
149 150 151 152 153 154 155 156 157 158
    int64_t* ids = tensor->data<int64_t>();
    int len = tensor->numel();
    Variable* var_emb = thread_scope_->FindVar(emb_slot_name);
    LoDTensor* tensor_emb = var_emb->GetMutable<LoDTensor>();
    float* ptr = tensor_emb->mutable_data<float>({len, table.emb_dim()},
                                                 platform::CPUPlace());
    memset(ptr, 0, sizeof(float) * len * table.emb_dim());
    auto& tensor_lod = tensor->lod()[0];
    LoD data_lod{tensor_lod};
    tensor_emb->set_lod(data_lod);
159 160 161 162 163 164 165 166

    bool is_nid = (adjust_ins_weight_config_.need_adjust() &&
                   adjust_ins_weight_config_.nid_slot() == emb_slot_name);
    if (is_nid) {
      nid_show_.clear();
    }
    int nid_ins_index = 0;

D
dongdaxiang 已提交
167
    for (int index = 0; index < len; ++index) {
168 169 170 171
      if (use_cvm_) {
        if (ids[index] == 0u) {
          memcpy(ptr + table.emb_dim() * index, init_value.data(),
                 sizeof(float) * table.emb_dim());
172 173 174 175
          if (is_nid) {
            nid_show_.push_back(-1);
            ++nid_ins_index;
          }
176 177 178 179
          continue;
        }
        memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data(),
               sizeof(float) * table.emb_dim());
180 181 182 183
        if (is_nid && index == tensor->lod()[0][nid_ins_index]) {
          nid_show_.push_back(fea_value[fea_idx][0]);
          ++nid_ins_index;
        }
184 185 186 187 188
        fea_idx++;
      } else {
        if (ids[index] == 0u) {
          memcpy(ptr + table.emb_dim() * index, init_value.data() + 2,
                 sizeof(float) * table.emb_dim());
189 190 191 192
          if (is_nid) {
            nid_show_.push_back(-1);
            ++nid_ins_index;
          }
193 194 195
          continue;
        }
        memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data() + 2,
196
               sizeof(float) * table.emb_dim());
197 198 199 200
        if (is_nid && index == tensor->lod()[0][nid_ins_index]) {
          nid_show_.push_back(fea_value[fea_idx][0]);
          ++nid_ins_index;
        }
201
        fea_idx++;
202 203 204 205 206
      }
    }
  }
}

// Boosts the per-instance weight of low-show-count nid instances:
// for nid_show in [0, threshold) the weight becomes
// log(e + (threshold - nid_show) / threshold * ratio), applied only when it
// exceeds the existing weight. Requires nid_show_ to have been filled by
// FillSparseValue for the current batch.
void DownpourWorker::AdjustInsWeight() {
#ifdef _LINUX
  // check var and tensor not null
  if (!adjust_ins_weight_config_.need_adjust()) {
    VLOG(0) << "need_adjust=false, skip adjust ins weight";
    return;
  }
  Variable* nid_var =
      thread_scope_->FindVar(adjust_ins_weight_config_.nid_slot());
  if (nid_var == nullptr) {
    VLOG(0) << "nid slot var " << adjust_ins_weight_config_.nid_slot()
            << " is nullptr, skip adjust ins weight";
    return;
  }
  LoDTensor* nid_tensor = nid_var->GetMutable<LoDTensor>();
  if (nid_tensor == nullptr) {
    VLOG(0) << "tensor of nid slot var " << adjust_ins_weight_config_.nid_slot()
            << " is nullptr, skip adjust ins weight";
    return;
  }
  Variable* ins_weight_var =
      thread_scope_->FindVar(adjust_ins_weight_config_.ins_weight_slot());
  if (ins_weight_var == nullptr) {
    VLOG(0) << "ins weight var " << adjust_ins_weight_config_.ins_weight_slot()
            << " is nullptr, skip adjust ins weight";
    return;
  }
  LoDTensor* ins_weight_tensor = ins_weight_var->GetMutable<LoDTensor>();
  if (ins_weight_tensor == nullptr) {
    VLOG(0) << "tensor of ins weight tensor "
            << adjust_ins_weight_config_.ins_weight_slot()
            << " is nullptr, skip adjust ins weight";
    return;
  }

  float* ins_weights = ins_weight_tensor->data<float>();
  size_t len = ins_weight_tensor->numel();  // len = batch size
  // here we assume nid_show slot only has one feasign in each instance
  CHECK(len == nid_show_.size()) << "ins_weight size should be equal to "
                                 << "nid_show size, " << len << " vs "
                                 << nid_show_.size();
  float nid_adjw_threshold = adjust_ins_weight_config_.nid_adjw_threshold();
  float nid_adjw_ratio = adjust_ins_weight_config_.nid_adjw_ratio();
  int64_t nid_adjw_num = 0;
  double nid_adjw_weight = 0.0;
  size_t ins_index = 0;
  // size_t index to match len (was int, a signed/unsigned comparison).
  for (size_t i = 0; i < len; ++i) {
    float nid_show = nid_show_[i];
    VLOG(3) << "nid_show " << nid_show;
    if (nid_show < 0) {
      // -1 marks a padding instance recorded by FillSparseValue.
      VLOG(3) << "nid_show < 0, continue";
      continue;
    }
    float ins_weight = 1.0;
    if (nid_show >= 0 && nid_show < nid_adjw_threshold) {
      ins_weight = log(M_E +
                       (nid_adjw_threshold - nid_show) / nid_adjw_threshold *
                           nid_adjw_ratio);
      // count nid adjw insnum and weight
      ++nid_adjw_num;
      nid_adjw_weight += ins_weight;
      // choose large ins weight
      VLOG(3) << "ins weight new " << ins_weight << ", ins weight origin "
              << ins_weights[ins_index];
      if (ins_weight > ins_weights[ins_index]) {
        VLOG(3) << "ins " << ins_index << " weight changes to " << ins_weight;
        ins_weights[ins_index] = ins_weight;
      }
      // NOTE(review): ins_index only advances inside this branch, so
      // instances with nid_show >= threshold do not move it — confirm this
      // matches the intended ins_weight layout.
      ++ins_index;
    }
  }
  VLOG(3) << "nid adjw info: total_adjw_num: " << nid_adjw_num
          << ", avg_adjw_weight: " << nid_adjw_weight;
#endif
}

void DownpourWorker::TrainFilesWithProfiler() {
  VLOG(3) << "Begin to train files with profiler";
  platform::SetNumThreads(1);
286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310
  device_reader_->Start();
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
  for (auto& op : ops_) {
    bool need_skip = false;
    for (auto t = 0u; t < skip_ops_.size(); ++t) {
      if (op->Type().find(skip_ops_[t]) != std::string::npos) {
        need_skip = true;
        break;
      }
    }
    if (!need_skip) {
      op_name.push_back(op->Type());
    }
  }

  VLOG(3) << "op name size: " << op_name.size();
  op_total_time.resize(op_name.size());
  for (size_t i = 0; i < op_total_time.size(); ++i) {
    op_total_time[i] = 0.0;
  }
  platform::Timer timeline;
  double total_time = 0.0;
  double read_time = 0.0;
  double pull_sparse_time = 0.0;
311
  double adjust_ins_weight_time = 0.0;
312 313 314 315 316 317
  double collect_label_time = 0.0;
  double fill_sparse_time = 0.0;
  double push_sparse_time = 0.0;
  double push_dense_time = 0.0;
  int cur_batch;
  int batch_cnt = 0;
D
dongdaxiang 已提交
318
  uint64_t total_inst = 0;
319 320 321 322 323 324
  timeline.Start();
  while ((cur_batch = device_reader_->Next()) > 0) {
    timeline.Pause();
    read_time += timeline.ElapsedSec();
    total_time += timeline.ElapsedSec();
    VLOG(3) << "program config size: " << param_.program_config_size();
D
dongdaxiang 已提交
325
    for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
326 327 328 329
         ++i) {
      uint64_t tid = static_cast<uint64_t>(
          param_.program_config(0).pull_sparse_table_id(i));
      TableParameter table;
330 331 332
      for (auto j : param_.sparse_table()) {
        if (j.table_id() == tid) {
          table = j;
333 334 335 336 337 338 339 340 341
          break;
        }
      }
      timeline.Start();
      fleet_ptr_->PullSparseVarsSync(*thread_scope_, tid,
                                     sparse_key_names_[tid], &features_[tid],
                                     &feature_values_[tid], table.fea_dim());
      timeline.Pause();
      pull_sparse_time += timeline.ElapsedSec();
D
dongdaxiang 已提交
342
      total_time += timeline.ElapsedSec();
D
dongdaxiang 已提交
343
      timeline.Start();
344 345 346
      CollectLabelInfo(i);
      timeline.Pause();
      collect_label_time += timeline.ElapsedSec();
D
dongdaxiang 已提交
347
      total_time += timeline.ElapsedSec();
348 349 350 351
      timeline.Start();
      FillSparseValue(i);
      timeline.Pause();
      fill_sparse_time += timeline.ElapsedSec();
D
dongdaxiang 已提交
352
      total_time += timeline.ElapsedSec();
353 354 355 356 357 358 359 360 361 362
      timeline.Start();
      auto nid_iter = std::find(sparse_value_names_[tid].begin(),
                                sparse_value_names_[tid].end(),
                                adjust_ins_weight_config_.nid_slot());
      if (nid_iter != sparse_value_names_[tid].end()) {
        AdjustInsWeight();
      }
      timeline.Pause();
      adjust_ins_weight_time += timeline.ElapsedSec();
      total_time += timeline.ElapsedSec();
363 364 365 366 367 368 369 370 371 372 373 374 375 376
    }
    VLOG(3) << "Fill sparse value for all sparse table done.";

    int run_op_idx = 0;
    for (auto& op : ops_) {
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      if (!need_skip) {
        timeline.Start();
377
        VLOG(3) << "Going to run op " << op_name[run_op_idx];
378
        op->Run(*thread_scope_, place_);
379
        VLOG(3) << "Op " << op_name[run_op_idx] << " Finished";
380 381 382 383 384 385
        timeline.Pause();
        op_total_time[run_op_idx++] += timeline.ElapsedSec();
        total_time += timeline.ElapsedSec();
      }
    }

386
    if (need_to_push_sparse_) {
D
dongdaxiang 已提交
387 388
      for (int i = 0; i < param_.program_config(0).push_sparse_table_id_size();
           ++i) {
389 390 391 392 393 394 395 396
        uint64_t tid = static_cast<uint64_t>(
            param_.program_config(0).push_sparse_table_id(i));
        TableParameter table;
        for (auto i : param_.sparse_table()) {
          if (i.table_id() == tid) {
            table = i;
            break;
          }
397
        }
398 399 400 401
        timeline.Start();
        fleet_ptr_->PushSparseVarsWithLabelAsync(
            *thread_scope_, tid, features_[tid], feature_labels_[tid],
            sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
T
Thunderbrook 已提交
402 403
            &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
            dump_slot_);
404 405 406
        timeline.Pause();
        push_sparse_time += timeline.ElapsedSec();
        total_time += timeline.ElapsedSec();
407
      }
408 409 410
    }

    if (need_to_push_dense_) {
411
      timeline.Start();
D
dongdaxiang 已提交
412 413
      for (int i = 0; i < param_.program_config(0).push_dense_table_id_size();
           ++i) {
414 415 416
        uint64_t tid = static_cast<uint64_t>(
            param_.program_config(0).push_dense_table_id(i));
        fleet_ptr_->PushDenseVarsAsync(
417 418
            *thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_,
            scale_datanorm_, cur_batch);
419
      }
420
      timeline.Pause();
421
      push_dense_time += timeline.ElapsedSec();
D
dongdaxiang 已提交
422
      total_time += timeline.ElapsedSec();
423 424 425 426 427 428 429 430 431
      VLOG(3) << "push sparse and dense gradient done.";
      int32_t tmp_push_dense_wait_times = -1;
      static uint32_t push_dense_wait_times =
          static_cast<uint32_t>(tmp_push_dense_wait_times);
      if (push_dense_status_.size() >= push_dense_wait_times) {
        for (auto& t : push_dense_status_) {
          t.wait();
        }
        push_dense_status_.resize(0);
432 433
      }

434 435
      if (tmp_push_dense_wait_times == -1) {
        push_dense_status_.resize(0);
436 437 438
      }
    }

439
    if (need_to_push_sparse_) {
440 441 442
      int32_t tmp_push_sparse_wait_times = -1;
      static uint32_t push_sparse_wait_times =
          static_cast<uint32_t>(tmp_push_sparse_wait_times);
443 444 445 446 447 448
      if (push_sparse_status_.size() >= push_sparse_wait_times) {
        for (auto& t : push_sparse_status_) {
          t.wait();
        }
        push_sparse_status_.resize(0);
      }
449

450 451 452
      if (tmp_push_sparse_wait_times == -1) {
        push_sparse_status_.resize(0);
      }
453

454 455 456
      VLOG(3) << "going to increase thread version";
      VLOG(3) << "push dense table id size: "
              << param_.program_config(0).push_dense_table_id_size();
457 458 459
    }

    if (need_to_push_dense_) {
D
dongdaxiang 已提交
460 461
      for (int i = 0; i < param_.program_config(0).push_dense_table_id_size();
           ++i) {
462 463 464 465
        uint64_t tid = static_cast<uint64_t>(
            param_.program_config(0).push_dense_table_id(i));
        pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
      }
466 467
    }

D
dongdaxiang 已提交
468
    PrintFetchVars();
469
    thread_scope_->DropKids();
D
dongdaxiang 已提交
470
    total_inst += cur_batch;
471 472 473 474 475
    ++batch_cnt;

    if (thread_id_ == 0) {
      // should be configured here
      if (batch_cnt > 0 && batch_cnt % 100 == 0) {
476 477
        double op_sum_time = 0;
        std::unordered_map<std::string, double> op_to_time;
478 479 480
        for (size_t i = 0; i < op_total_time.size(); ++i) {
          fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
                  op_name[i].c_str(), op_total_time[i] / batch_cnt);
481 482 483 484 485 486 487 488 489
          if (op_to_time.find(op_name[i]) == op_to_time.end()) {
            op_to_time[op_name[i]] = 0.0;
          }
          op_to_time[op_name[i]] += op_total_time[i];
          op_sum_time += op_total_time[i];
        }
        for (auto& i : op_to_time) {
          fprintf(stderr, "op [%s] run total time: [%f]ms\n", i.first.c_str(),
                  i.second / batch_cnt);
490
        }
491 492 493 494 495 496 497 498 499 500 501
        fprintf(stderr, "op run total time: %fs\n", op_sum_time / batch_cnt);
        fprintf(stderr, "train total time: %fs\n", total_time / batch_cnt);
        fprintf(stderr, "pull sparse time: %fs\n",
                pull_sparse_time / batch_cnt);
        fprintf(stderr, "fill sparse time: %fs\n",
                fill_sparse_time / batch_cnt);
        fprintf(stderr, "push sparse time: %fs\n",
                push_sparse_time / batch_cnt);
        fprintf(stderr, "push dense time: %fs\n", push_dense_time / batch_cnt);
        fprintf(stderr, "collect label time: %fs\n",
                collect_label_time / batch_cnt);
502 503
        fprintf(stderr, "adjust ins weight time: %fs\n",
                adjust_ins_weight_time / batch_cnt);
504 505
        fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
        fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
506
        fprintf(stderr, "op run percent: %f\n", op_sum_time / total_time * 100);
D
dongdaxiang 已提交
507 508
        fprintf(stderr, "pull sparse time percent: %f\n",
                pull_sparse_time / total_time * 100);
509 510
        fprintf(stderr, "adjust ins weight time percent: %f\n",
                adjust_ins_weight_time / total_time * 100);
D
dongdaxiang 已提交
511 512 513 514 515 516 517 518
        fprintf(stderr, "collect label time percent: %f\n",
                collect_label_time / total_time * 100);
        fprintf(stderr, "fill sparse time percent: %f\n",
                fill_sparse_time / total_time * 100);
        fprintf(stderr, "push sparse time percent: %f\n",
                push_sparse_time / total_time * 100);
        fprintf(stderr, "push dense time percent: %f\n",
                push_dense_time / total_time * 100);
D
dongdaxiang 已提交
519
        fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
520 521
      }
    }
D
dongdaxiang 已提交
522
    timeline.Start();
523
  }
524 525
}

void DownpourWorker::TrainFiles() {
D
dongdaxiang 已提交
527
  VLOG(3) << "Begin to train files";
528
  platform::SetNumThreads(1);
529
  device_reader_->Start();
530 531
  int batch_cnt = 0;
  int cur_batch;
532
  while ((cur_batch = device_reader_->Next()) > 0) {
533
    // pull sparse here
D
dongdaxiang 已提交
534
    for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
H
heqiaozhi 已提交
535 536 537 538
         ++i) {
      uint64_t tid = static_cast<uint64_t>(
          param_.program_config(0).pull_sparse_table_id(i));
      TableParameter table;
539 540 541
      for (auto j : param_.sparse_table()) {
        if (j.table_id() == tid) {
          table = j;
H
heqiaozhi 已提交
542 543 544 545 546 547
          break;
        }
      }
      fleet_ptr_->PullSparseVarsSync(*thread_scope_, tid,
                                     sparse_key_names_[tid], &features_[tid],
                                     &feature_values_[tid], table.fea_dim());
548 549
      CollectLabelInfo(i);
      FillSparseValue(i);
550 551 552 553 554 555
      auto nid_iter = std::find(sparse_value_names_[tid].begin(),
                                sparse_value_names_[tid].end(),
                                adjust_ins_weight_config_.nid_slot());
      if (nid_iter != sparse_value_names_[tid].end()) {
        AdjustInsWeight();
      }
556
    }
D
dongdaxiang 已提交
557
    VLOG(3) << "fill sparse value for all sparse table done.";
558 559 560

    // do computation here
    for (auto& op : ops_) {
561 562 563 564 565 566 567 568 569 570
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      if (!need_skip) {
        op->Run(*thread_scope_, place_);
      }
571 572
    }

573 574
    if (need_to_push_sparse_) {
      // push gradients here
D
dongdaxiang 已提交
575 576
      for (int i = 0; i < param_.program_config(0).push_sparse_table_id_size();
           ++i) {
577 578 579 580 581 582 583 584
        uint64_t tid = static_cast<uint64_t>(
            param_.program_config(0).push_sparse_table_id(i));
        TableParameter table;
        for (auto i : param_.sparse_table()) {
          if (i.table_id() == tid) {
            table = i;
            break;
          }
H
heqiaozhi 已提交
585
        }
586 587 588
        fleet_ptr_->PushSparseVarsWithLabelAsync(
            *thread_scope_, tid, features_[tid], feature_labels_[tid],
            sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
T
Thunderbrook 已提交
589 590
            &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
            dump_slot_);
H
heqiaozhi 已提交
591
      }
592 593
    }

594
    if (need_to_push_dense_) {
D
dongdaxiang 已提交
595 596
      for (int i = 0; i < param_.program_config(0).push_dense_table_id_size();
           ++i) {
597 598 599
        uint64_t tid = static_cast<uint64_t>(
            param_.program_config(0).push_dense_table_id(i));
        fleet_ptr_->PushDenseVarsAsync(
600 601
            *thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_,
            scale_datanorm_, cur_batch);
602 603 604
      }

      VLOG(3) << "push dense gradient done.";
605

606 607 608 609 610
      // the following code should be more precise and clean
      // TODO(guru4elephant)
      int32_t tmp_push_dense_wait_times = -1;
      static uint32_t push_dense_wait_times =
          static_cast<uint32_t>(tmp_push_dense_wait_times);
611

612 613 614 615 616
      if (push_dense_status_.size() >= push_dense_wait_times) {
        for (auto& t : push_dense_status_) {
          t.wait();
        }
        push_dense_status_.resize(0);
617 618
      }

619 620 621
      if (tmp_push_dense_wait_times == -1) {
        push_dense_status_.resize(0);
      }
622 623
    }

624 625 626 627 628 629 630 631 632 633
    if (need_to_push_sparse_) {
      VLOG(3) << "push sparse gradient done.";
      int32_t tmp_push_sparse_wait_times = -1;
      static uint32_t push_sparse_wait_times =
          static_cast<uint32_t>(tmp_push_sparse_wait_times);
      if (push_sparse_status_.size() >= push_sparse_wait_times) {
        for (auto& t : push_sparse_status_) {
          t.wait();
        }
        push_sparse_status_.resize(0);
634 635
      }

636 637 638
      if (tmp_push_sparse_wait_times == -1) {
        push_sparse_status_.resize(0);
      }
639 640
    }

641
    if (need_to_push_dense_) {
D
dongdaxiang 已提交
642 643
      for (int i = 0; i < param_.program_config(0).push_dense_table_id_size();
           ++i) {
644 645 646 647
        uint64_t tid = static_cast<uint64_t>(
            param_.program_config(0).push_dense_table_id(i));
        pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
      }
648
    }
649

D
dongdaxiang 已提交
650
    PrintFetchVars();
651 652 653 654 655 656 657
    thread_scope_->DropKids();
    ++batch_cnt;
  }
}

}  // end namespace framework
}  // end namespace paddle