// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <condition_variable>  // NOLINT
#include <mutex>               // NOLINT
#include <set>

#include "paddle/fluid/distributed/ps/table/table.h"

#include "paddle/fluid/distributed/common/utils.h"

namespace paddle {
namespace distributed {

// Accumulates fixed-width vectors element-wise and counts how many were
// added, so the running sum can later be turned into a mean with avg().
// The element-wise arithmetic is delegated to the BLAS wrapper returned by
// GetBlas<T>(), which is declared outside this header (presumably in
// paddle/fluid/distributed/common/utils.h).
template <typename T>
struct ReservoirValue {
  std::vector<T> values;  // running element-wise sum, `dim` entries wide
  uint32_t counter;       // number of add() calls since construction/reset()
  uint32_t dim;           // width requested at construction

  // Creates a zero-width accumulator.
  ReservoirValue() : ReservoirValue(0u) {}

  // Creates an accumulator of width `dim` whose sums start value-initialized
  // (zero for arithmetic T).
  ReservoirValue(uint32_t dim) : values(dim), counter(0), dim(dim) {}

  // Adds `numel` elements of `value` into the running sum in place and bumps
  // the sample counter.
  // NOTE(review): nothing checks numel <= values.size(); a larger numel
  // would write past the end of `values`. Callers must guarantee it —
  // TODO confirm whether an enforce/check belongs here.
  void add(const T *value, int numel) {
    GetBlas<T>().VADD(numel, values.data(), value, values.data());
    counter++;
  }

  // Non-const overload kept for source compatibility. It forwards to the
  // const overload so the accumulation logic lives in one place.
  void add(T *value, int numel) { add(static_cast<const T *>(value), numel); }

  // Divides the running sum by the number of add() calls, turning it into a
  // mean. Does nothing when nothing has been added (avoids dividing by 0).
  void avg() {
    if (counter == 0) return;
    const auto scale = 1 / static_cast<T>(counter);
    GetBlas<T>().SCAL(static_cast<int>(values.size()), scale, values.data());
  }

  // Zeroes the running sum and the sample counter. The width is unchanged.
  void reset() {
    std::fill(values.begin(), values.end(), 0);
    counter = 0;
  }
};

class BarrierTable : public Table {
 public:
  BarrierTable() {}
  virtual ~BarrierTable() {}

Z
zhaocaibei123 已提交
74
  virtual void *GetShard(size_t shard_idx) { return 0; }
T
tangwei12 已提交
75

Y
yaoxuefeng 已提交
76 77 78
  virtual int32_t Pull(TableContext &context) { return 0; }
  virtual int32_t Push(TableContext &context) { return 0; }

Z
zhaocaibei123 已提交
79 80 81 82
  int32_t Shrink(const std::string &param) override { return 0; }
  virtual void Clear() {}
  virtual int32_t Flush() { return 0; }
  virtual int32_t Load(const std::string &path, const std::string &param) {
T
tangwei12 已提交
83 84
    return 0;
  }
Z
zhaocaibei123 已提交
85
  virtual int32_t Save(const std::string &path, const std::string &param) {
T
tangwei12 已提交
86 87
    return 0;
  }
Z
zhaocaibei123 已提交
88
  virtual int32_t InitializeShard() { return 0; }
T
tangwei12 已提交
89

Z
zhaocaibei123 已提交
90
  virtual int32_t Initialize() override;
T
tangwei12 已提交
91 92
  // only for barrier
  // 0: send_barrier 1: recv_barrier 2: complete
Z
zhaocaibei123 已提交
93
  virtual int32_t Barrier(const uint32_t trainer_id,
T
tangwei12 已提交
94 95
                          const std::string barrier_type) override;

Z
zhaocaibei123 已提交
96
  virtual int32_t SetTableMap(
T
tangwei12 已提交
97 98 99 100 101 102 103 104 105 106 107 108 109
      std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) override;

 private:
  std::mutex mutex_;
  std::condition_variable trainer_wait_;
  std::set<uint64_t> trainer_ids_;
  std::set<uint64_t> trainer_all_;
  std::atomic<int> trigger_;
  std::atomic<bool> exit_;
  std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map_;
};
}  // namespace distributed
}  // namespace paddle