selected_rows_utils.h 5.3 KB
Newer Older
1 2
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16

Y
Yancey1989 已提交
17
#include <algorithm>
18
#include <memory>
Q
qiaolongfei 已提交
19
#include <mutex>  // NOLINT
20
#include <unordered_map>
Y
Yancey1989 已提交
21
#include <utility>
22 23
#include <vector>

Y
Yi Wang 已提交
24
#include "paddle/fluid/framework/lod_tensor.h"
25
#include "paddle/fluid/framework/rw_lock.h"
Y
Yi Wang 已提交
26
#include "paddle/fluid/framework/tensor.h"
Y
Yancey1989 已提交
27
#include "paddle/fluid/memory/memcpy.h"
28
#include "paddle/fluid/platform/place.h"
W
wanghuancoder 已提交
29

Q
qijun 已提交
30 31 32 33
namespace paddle {
namespace framework {

class SelectedRows {
Y
Yancey1989 已提交
34 35
  /*
   * @brief We can use the SelectedRows structure to reproduce a sparse table.
X
Xin Pan 已提交
36
   *  A sparse table is a key-value structure that the key is an `int64_t`,
Y
Yancey1989 已提交
37 38 39 40 41 42 43
   *  and the value is a Tensor which the first dimension is 0.
   *  You can use the following interface to operate the sparse table, and you
   * can find
   *  some detail information from the comments of each interface:
   *
   *  HasKey(key), whether the sparse table has the specified key.
   *  Set(key, value), set a key-value pair into the sparse table.
Y
Yancey1989 已提交
44
   *  Get(keys, value*), get value by given key list and apply it to the given
Y
Yancey1989 已提交
45 46 47 48
   * value pointer
   *    with the specified offset.
   *
   */
Q
qijun 已提交
49 50 51 52
 public:
  SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
      : rows_(rows), height_(height) {
    value_.reset(new Tensor());
53
    rwlock_.reset(new RWLock);
Q
qijun 已提交
54 55
  }

Q
QI JUN 已提交
56 57 58
  SelectedRows() {
    height_ = 0;
    value_.reset(new Tensor());
59
    rwlock_.reset(new RWLock);
Q
QI JUN 已提交
60
  }
Q
qijun 已提交
61

62
  const platform::Place& place() const { return value_->place(); }
Q
qijun 已提交
63

Q
qijun 已提交
64 65 66
  const Tensor& value() const { return *value_; }

  Tensor* mutable_value() { return value_.get(); }
Q
qijun 已提交
67 68 69 70 71

  int64_t height() const { return height_; }

  void set_height(int64_t height) { height_ = height; }

Q
qijun 已提交
72
  const Vector<int64_t>& rows() const { return rows_; }
Q
qijun 已提交
73

Q
QI JUN 已提交
74 75
  Vector<int64_t>* mutable_rows() { return &rows_; }

Q
qijun 已提交
76
  void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
Q
qijun 已提交
77

Y
Yancey1989 已提交
78
  /*
79 80 81 82 83 84 85
   * @brief Get the index of key in rows
   *
   * @return -1 if the key does not exists.
   */
  int64_t Index(int64_t key) const {
    auto it = std::find(rows_.begin(), rows_.end(), key);
    if (it == rows_.end()) {
86 87
      PADDLE_THROW(platform::errors::NotFound(
          "Input id (%lld) is not in current rows table.", key));
88 89 90 91 92 93
    }
    return static_cast<int64_t>(std::distance(rows_.begin(), it));
  }

  /*
   * @brief whether has the specified key in the table.
Y
Yancey1989 已提交
94 95
   *
   * @return true if the key is exists.
Q
qiaolongfei 已提交
96
   */
Y
Yancey1989 已提交
97 98 99
  bool HasKey(int64_t key) const;

  /*
100 101 102 103
   * @brief Get value by the key list.
   * Note!!! this interface is only used when selected_rows is used as
   * parameters
   * for distribute lookup table.
Y
Yancey1989 已提交
104
   *
Y
Yancey1989 已提交
105 106
   * @return a list of pair which contains the non-exists key and the index in
   * the value
Y
Yancey1989 已提交
107
   */
108
  void Get(const framework::Tensor& ids, framework::Tensor* value,
Q
Qiao Longfei 已提交
109
           bool auto_grown = false, bool is_test = false);
Y
Yancey1989 已提交
110 111

  /*
112 113 114
   * @brief Get the index of the key from id_to_index_ map. If the key not
   * exist,
   * add the key into id_to_index_.
Y
Yancey1989 已提交
115
   *
116 117 118
   * Note!!! this interface is only used when selected_rows is used as
   * parameters
   * for distribute lookup table.
Y
Yancey1989 已提交
119
   *
120
   * @return index of the key.
Y
Yancey1989 已提交
121
   */
J
JiabinYang 已提交
122 123 124 125 126
  int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);

  /*
   * @brief Get the index of the key from id_to_index_ map.
   */
127
  inline int64_t GetIndexFromId(int64_t key) const {
J
JiabinYang 已提交
128 129
    auto iter = id_to_index_.find(key);
    if (iter == id_to_index_.end()) {
J
JiabinYang 已提交
130
      return -1;
J
JiabinYang 已提交
131
    } else {
J
JiabinYang 已提交
132
      return iter->second;
J
JiabinYang 已提交
133 134
    }
  }
Y
Yancey1989 已提交
135

136
  void SyncIndex();
J
JiabinYang 已提交
137 138 139
  /*
   * @brief Get complete Dims before
   */
Q
qijun 已提交
140 141 142 143 144 145 146
  DDim GetCompleteDims() const {
    std::vector<int64_t> dims = vectorize(value_->dims());
    dims[0] = height_;
    return make_ddim(dims);
  }

 private:
Q
qijun 已提交
147
  // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
148
  // SelectedRows are simply concated when adding together. Until a
Q
qijun 已提交
149
  // SelectedRows add a Tensor, will the duplicate rows be handled.
Q
qijun 已提交
150
  Vector<int64_t> rows_;
151
  std::unordered_map<int64_t, int64_t>
J
JiabinYang 已提交
152
      id_to_index_;  // should not be used when rows_ has duplicate member
Q
qijun 已提交
153
  std::unique_ptr<Tensor> value_{nullptr};
J
JiabinYang 已提交
154
  int64_t height_;  // height indicates the underline tensor's height
155
  std::unique_ptr<RWLock> rwlock_{nullptr};
Q
qijun 已提交
156 157
};

158 159 160 161 162 163 164
/*
 * Serialize/Desiralize SelectedRows to std::ostream
 * You can pass ofstream or ostringstream to serilize to file
 * or to a in memory string. GPU tensor will be copied to CPU.
 */
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
                       const platform::DeviceContext& dev_ctx);
Y
Yancey 已提交
165 166
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
                           const platform::DeviceContext& dev_ctx);
167

168 169
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows);

170
void DeserializeFromStream(std::istream& os, SelectedRows* selected_rows);
171

Q
qijun 已提交
172 173
}  // namespace framework
}  // namespace paddle