// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <atomic>
#include <condition_variable>  // NOLINT
#include <memory>
#include <mutex>               // NOLINT
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/ps/table/table.h"

namespace paddle {
namespace distributed {

// Element-wise accumulator for fixed-width dense vectors: add() sums incoming
// vectors into `values`, avg() rescales the sum by 1/counter in place, and
// reset() zeroes the accumulator while keeping its width.
template <typename T>
struct ReservoirValue {
  std::vector<T> values;  // element-wise running sum, size == dim
  uint32_t counter;       // number of vectors accumulated since last reset
  uint32_t dim;           // width of each accumulated vector

  ReservoirValue() : counter(0), dim(0) {}

  ReservoirValue(uint32_t dim) : counter(0), dim(dim) { values.resize(dim); }

  // Accumulate value[0..numel) into values[0..numel).
  // NOTE(review): numel is trusted to be <= dim; no bounds check here.
  void add(const T *value, int numel) {
    GetBlas<T>().VADD(numel, values.data(), value, values.data());
    counter++;
  }

  // In-place average: values *= 1/counter. Early return when nothing has
  // been added, which also avoids dividing by zero.
  void avg() {
    if (counter == 0) return;
    auto scale = 1 / static_cast<T>(counter);
    GetBlas<T>().SCAL(values.size(), scale, values.data());
  }

  // Zero the accumulated sum and the counter; the vector width is kept.
  void reset() {
    std::fill(values.begin(), values.end(), 0);
    counter = 0;
  }
};

class SparseTable : public Table {
 public:
  SparseTable() {}
  virtual ~SparseTable() {}

Z
zhaocaibei123 已提交
74
  virtual void *GetShard(size_t shard_idx) { return 0; }
T
tangwei12 已提交
75

Z
zhaocaibei123 已提交
76
  int32_t PullDense(float *values, size_t num) override { return 0; }
T
tangwei12 已提交
77

Z
zhaocaibei123 已提交
78
  int32_t PushDense(const float *values, size_t num) override { return 0; }
T
tangwei12 已提交
79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99

  static int32_t sparse_local_shard_num(uint32_t shard_num,
                                        uint32_t server_num) {
    if (shard_num % server_num == 0) {
      return shard_num / server_num;
    }
    size_t local_shard_num = shard_num / server_num + 1;
    return local_shard_num;
  }

  static size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
                                 uint64_t key) {
    return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
  }
};

class DenseTable : public Table {
 public:
  DenseTable() {}
  virtual ~DenseTable() {}

Z
zhaocaibei123 已提交
100 101 102
  virtual void *GetShard(size_t shard_idx) { return 0; }
  int32_t PullSparse(float *values,
                     const PullSparseValue &pull_value) override {
T
tangwei12 已提交
103 104
    return 0;
  }
Z
zhaocaibei123 已提交
105 106
  int32_t PushSparse(const uint64_t *keys, const float *values,
                     size_t num) override {
T
tangwei12 已提交
107 108
    return 0;
  }
Z
zhaocaibei123 已提交
109 110
  int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
  int32_t Shrink(const std::string &param) override { return 0; }
T
tangwei12 已提交
111 112 113 114 115 116 117
};

class BarrierTable : public Table {
 public:
  BarrierTable() {}
  virtual ~BarrierTable() {}

Z
zhaocaibei123 已提交
118
  virtual void *GetShard(size_t shard_idx) { return 0; }
T
tangwei12 已提交
119

Y
yaoxuefeng 已提交
120 121 122
  virtual int32_t Pull(TableContext &context) { return 0; }
  virtual int32_t Push(TableContext &context) { return 0; }

Z
zhaocaibei123 已提交
123
  int32_t PullDense(float *values, size_t num) override { return 0; }
T
tangwei12 已提交
124

Z
zhaocaibei123 已提交
125
  int32_t PushDense(const float *values, size_t num) override { return 0; }
T
tangwei12 已提交
126

Z
zhaocaibei123 已提交
127 128
  int32_t PullSparse(float *values,
                     const PullSparseValue &pull_value) override {
T
tangwei12 已提交
129 130
    return 0;
  }
Z
zhaocaibei123 已提交
131 132
  int32_t PushSparse(const uint64_t *keys, const float *values,
                     size_t num) override {
T
tangwei12 已提交
133 134
    return 0;
  }
Z
zhaocaibei123 已提交
135 136 137 138 139
  int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
  int32_t Shrink(const std::string &param) override { return 0; }
  virtual void Clear() {}
  virtual int32_t Flush() { return 0; }
  virtual int32_t Load(const std::string &path, const std::string &param) {
T
tangwei12 已提交
140 141
    return 0;
  }
Z
zhaocaibei123 已提交
142
  virtual int32_t Save(const std::string &path, const std::string &param) {
T
tangwei12 已提交
143 144
    return 0;
  }
Z
zhaocaibei123 已提交
145
  virtual int32_t InitializeShard() { return 0; }
T
tangwei12 已提交
146

Z
zhaocaibei123 已提交
147
  virtual int32_t Initialize() override;
T
tangwei12 已提交
148 149
  // only for barrier
  // 0: send_barrier 1: recv_barrier 2: complete
Z
zhaocaibei123 已提交
150
  virtual int32_t Barrier(const uint32_t trainer_id,
T
tangwei12 已提交
151 152
                          const std::string barrier_type) override;

Z
zhaocaibei123 已提交
153
  virtual int32_t SetTableMap(
T
tangwei12 已提交
154 155 156 157 158 159 160 161 162 163 164 165 166
      std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) override;

 private:
  std::mutex mutex_;
  std::condition_variable trainer_wait_;
  std::set<uint64_t> trainer_ids_;
  std::set<uint64_t> trainer_all_;
  std::atomic<int> trigger_;
  std::atomic<bool> exit_;
  std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map_;
};
}  // namespace distributed
}  // namespace paddle