selected_rows.h 5.3 KB
Newer Older
1 2
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Q
qijun 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
6

Q
qijun 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
8

Q
qijun 已提交
9 10 11 12 13 14 15
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16

Y
Yancey1989 已提交
17
#include <algorithm>
18
#include <memory>
Q
qiaolongfei 已提交
19
#include <mutex>  // NOLINT
20
#include <unordered_map>
Y
Yancey1989 已提交
21
#include <utility>
22 23
#include <vector>

Y
Yi Wang 已提交
24
#include "paddle/fluid/framework/lod_tensor.h"
25
#include "paddle/fluid/framework/rw_lock.h"
Y
Yi Wang 已提交
26
#include "paddle/fluid/framework/tensor.h"
Y
Yancey1989 已提交
27
#include "paddle/fluid/memory/memcpy.h"
Q
qijun 已提交
28

W
wanghuancoder 已提交
29 30 31 32 33 34 35
// Forward declarations of the platform types that the declarations in this
// header mention only by reference (SelectedRows::place() and the
// serialization functions).
namespace paddle {
namespace platform {
class DeviceContext;
class Place;
}  // namespace platform
}  // namespace paddle

Q
qijun 已提交
36 37 38
namespace paddle {
namespace framework {

W
wanghuancoder 已提交
39 40
// Forward declaration of framework::Tensor, the value type held by
// SelectedRows.
class Tensor;

Q
qijun 已提交
41
class SelectedRows {
Y
Yancey1989 已提交
42 43
  /*
   * @brief We can use the SelectedRows structure to reproduce a sparse table.
X
Xin Pan 已提交
44
   *  A sparse table is a key-value structure that the key is an `int64_t`,
Y
Yancey1989 已提交
45 46 47 48 49 50 51
   *  and the value is a Tensor which the first dimension is 0.
   *  You can use the following interface to operate the sparse table, and you
   * can find
   *  some detail information from the comments of each interface:
   *
   *  HasKey(key), whether the sparse table has the specified key.
   *  Set(key, value), set a key-value pair into the sparse table.
Y
Yancey1989 已提交
52
   *  Get(keys, value*), get value by given key list and apply it to the given
Y
Yancey1989 已提交
53 54 55 56
   * value pointer
   *    with the specified offset.
   *
   */
Q
qijun 已提交
57 58 59 60
 public:
  SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
      : rows_(rows), height_(height) {
    value_.reset(new Tensor());
61
    rwlock_.reset(new RWLock);
Q
qijun 已提交
62 63
  }

Q
QI JUN 已提交
64 65 66
  SelectedRows() {
    height_ = 0;
    value_.reset(new Tensor());
67
    rwlock_.reset(new RWLock);
Q
QI JUN 已提交
68
  }
Q
qijun 已提交
69

70
  const platform::Place& place() const { return value_->place(); }
Q
qijun 已提交
71

Q
qijun 已提交
72 73 74
  const Tensor& value() const { return *value_; }

  Tensor* mutable_value() { return value_.get(); }
Q
qijun 已提交
75 76 77 78 79

  int64_t height() const { return height_; }

  void set_height(int64_t height) { height_ = height; }

Q
qijun 已提交
80
  const Vector<int64_t>& rows() const { return rows_; }
Q
qijun 已提交
81

Q
QI JUN 已提交
82 83
  Vector<int64_t>* mutable_rows() { return &rows_; }

Q
qijun 已提交
84
  void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
Q
qijun 已提交
85

Y
Yancey1989 已提交
86
  /*
87 88 89 90 91 92 93
   * @brief Get the index of key in rows
   *
   * @return -1 if the key does not exists.
   */
  int64_t Index(int64_t key) const {
    auto it = std::find(rows_.begin(), rows_.end(), key);
    if (it == rows_.end()) {
94 95
      PADDLE_THROW(platform::errors::NotFound(
          "Input id (%lld) is not in current rows table.", key));
96 97 98 99 100 101
    }
    return static_cast<int64_t>(std::distance(rows_.begin(), it));
  }

  /*
   * @brief whether has the specified key in the table.
Y
Yancey1989 已提交
102 103
   *
   * @return true if the key is exists.
Q
qiaolongfei 已提交
104
   */
Y
Yancey1989 已提交
105 106 107
  bool HasKey(int64_t key) const;

  /*
108 109 110 111
   * @brief Get value by the key list.
   * Note!!! this interface is only used when selected_rows is used as
   * parameters
   * for distribute lookup table.
Y
Yancey1989 已提交
112
   *
Y
Yancey1989 已提交
113 114
   * @return a list of pair which contains the non-exists key and the index in
   * the value
Y
Yancey1989 已提交
115
   */
116
  void Get(const framework::Tensor& ids, framework::Tensor* value,
Q
Qiao Longfei 已提交
117
           bool auto_grown = false, bool is_test = false);
Y
Yancey1989 已提交
118 119

  /*
120 121 122
   * @brief Get the index of the key from id_to_index_ map. If the key not
   * exist,
   * add the key into id_to_index_.
Y
Yancey1989 已提交
123
   *
124 125 126
   * Note!!! this interface is only used when selected_rows is used as
   * parameters
   * for distribute lookup table.
Y
Yancey1989 已提交
127
   *
128
   * @return index of the key.
Y
Yancey1989 已提交
129
   */
J
JiabinYang 已提交
130 131 132 133 134
  int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);

  /*
   * @brief Get the index of the key from id_to_index_ map.
   */
135
  inline int64_t GetIndexFromId(int64_t key) const {
J
JiabinYang 已提交
136 137
    auto iter = id_to_index_.find(key);
    if (iter == id_to_index_.end()) {
J
JiabinYang 已提交
138
      return -1;
J
JiabinYang 已提交
139
    } else {
J
JiabinYang 已提交
140
      return iter->second;
J
JiabinYang 已提交
141 142
    }
  }
Y
Yancey1989 已提交
143

144
  void SyncIndex();
J
JiabinYang 已提交
145 146 147
  /*
   * @brief Get complete Dims before
   */
Q
qijun 已提交
148 149 150 151 152 153 154
  DDim GetCompleteDims() const {
    std::vector<int64_t> dims = vectorize(value_->dims());
    dims[0] = height_;
    return make_ddim(dims);
  }

 private:
Q
qijun 已提交
155
  // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
156
  // SelectedRows are simply concated when adding together. Until a
Q
qijun 已提交
157
  // SelectedRows add a Tensor, will the duplicate rows be handled.
Q
qijun 已提交
158
  Vector<int64_t> rows_;
159
  std::unordered_map<int64_t, int64_t>
J
JiabinYang 已提交
160
      id_to_index_;  // should not be used when rows_ has duplicate member
Q
qijun 已提交
161
  std::unique_ptr<Tensor> value_{nullptr};
J
JiabinYang 已提交
162
  int64_t height_;  // height indicates the underline tensor's height
163
  std::unique_ptr<RWLock> rwlock_{nullptr};
Q
qijun 已提交
164 165
};

166 167 168 169 170 171 172
/*
 * Serialize/Desiralize SelectedRows to std::ostream
 * You can pass ofstream or ostringstream to serilize to file
 * or to a in memory string. GPU tensor will be copied to CPU.
 */
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
                       const platform::DeviceContext& dev_ctx);
Y
Yancey 已提交
173 174
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
                           const platform::DeviceContext& dev_ctx);
175

Q
qijun 已提交
176 177
}  // namespace framework
}  // namespace paddle