/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <iterator>
#include <memory>
#include <mutex>  // NOLINT
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/memcpy.h"

namespace paddle {
namespace framework {

class SelectedRows {
Y
Yancey1989 已提交
33 34
  /*
   * @brief We can use the SelectedRows structure to reproduce a sparse table.
X
Xin Pan 已提交
35
   *  A sparse table is a key-value structure that the key is an `int64_t`,
Y
Yancey1989 已提交
36 37 38 39 40 41 42
   *  and the value is a Tensor which the first dimension is 0.
   *  You can use the following interface to operate the sparse table, and you
   * can find
   *  some detail information from the comments of each interface:
   *
   *  HasKey(key), whether the sparse table has the specified key.
   *  Set(key, value), set a key-value pair into the sparse table.
Y
Yancey1989 已提交
43
   *  Get(keys, value*), get value by given key list and apply it to the given
Y
Yancey1989 已提交
44 45 46 47
   * value pointer
   *    with the specified offset.
   *
   */
Q
qijun 已提交
48 49 50 51
 public:
  SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
      : rows_(rows), height_(height) {
    value_.reset(new Tensor());
52
    rwlock_.reset(new RWLock);
Q
qijun 已提交
53 54
  }

Q
QI JUN 已提交
55 56 57
  SelectedRows() {
    height_ = 0;
    value_.reset(new Tensor());
58
    rwlock_.reset(new RWLock);
Q
QI JUN 已提交
59
  }
Q
qijun 已提交
60

61
  const platform::Place& place() const { return value_->place(); }
Q
qijun 已提交
62

Q
qijun 已提交
63 64 65
  const Tensor& value() const { return *value_; }

  Tensor* mutable_value() { return value_.get(); }
Q
qijun 已提交
66 67 68 69 70

  int64_t height() const { return height_; }

  void set_height(int64_t height) { height_ = height; }

Q
qijun 已提交
71
  const Vector<int64_t>& rows() const { return rows_; }
Q
qijun 已提交
72

Q
QI JUN 已提交
73 74
  Vector<int64_t>* mutable_rows() { return &rows_; }

Q
qijun 已提交
75
  void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
Q
qijun 已提交
76

Y
Yancey1989 已提交
77
  /*
78 79 80 81 82 83 84 85 86 87 88 89 90 91
   * @brief Get the index of key in rows
   *
   * @return -1 if the key does not exists.
   */
  int64_t Index(int64_t key) const {
    auto it = std::find(rows_.begin(), rows_.end(), key);
    if (it == rows_.end()) {
      PADDLE_THROW("id %s not in table", key);
    }
    return static_cast<int64_t>(std::distance(rows_.begin(), it));
  }

  /*
   * @brief whether has the specified key in the table.
Y
Yancey1989 已提交
92 93
   *
   * @return true if the key is exists.
Q
qiaolongfei 已提交
94
   */
Y
Yancey1989 已提交
95 96 97
  bool HasKey(int64_t key) const;

  /*
98 99 100 101
   * @brief Get value by the key list.
   * Note!!! this interface is only used when selected_rows is used as
   * parameters
   * for distribute lookup table.
Y
Yancey1989 已提交
102
   *
Y
Yancey1989 已提交
103 104
   * @return a list of pair which contains the non-exists key and the index in
   * the value
Y
Yancey1989 已提交
105
   */
106
  void Get(const framework::Tensor& ids, framework::Tensor* value,
Q
Qiao Longfei 已提交
107
           bool auto_grown = false, bool is_test = false);
Y
Yancey1989 已提交
108 109

  /*
110 111 112
   * @brief Get the index of the key from id_to_index_ map. If the key not
   * exist,
   * add the key into id_to_index_.
Y
Yancey1989 已提交
113
   *
114 115 116
   * Note!!! this interface is only used when selected_rows is used as
   * parameters
   * for distribute lookup table.
Y
Yancey1989 已提交
117
   *
118
   * @return index of the key.
Y
Yancey1989 已提交
119
   */
J
JiabinYang 已提交
120 121 122 123 124
  int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);

  /*
   * @brief Get the index of the key from id_to_index_ map.
   */
125
  inline int64_t GetIndexFromId(int64_t key) const {
J
JiabinYang 已提交
126 127
    auto iter = id_to_index_.find(key);
    if (iter == id_to_index_.end()) {
J
JiabinYang 已提交
128
      return -1;
J
JiabinYang 已提交
129
    } else {
J
JiabinYang 已提交
130
      return iter->second;
J
JiabinYang 已提交
131 132
    }
  }
Y
Yancey1989 已提交
133

134
  void SyncIndex();
J
JiabinYang 已提交
135 136 137
  /*
   * @brief Get complete Dims before
   */
Q
qijun 已提交
138 139 140 141 142 143 144
  DDim GetCompleteDims() const {
    std::vector<int64_t> dims = vectorize(value_->dims());
    dims[0] = height_;
    return make_ddim(dims);
  }

 private:
Q
qijun 已提交
145
  // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
146
  // SelectedRows are simply concated when adding together. Until a
Q
qijun 已提交
147
  // SelectedRows add a Tensor, will the duplicate rows be handled.
Q
qijun 已提交
148
  Vector<int64_t> rows_;
149
  std::unordered_map<int64_t, int64_t>
J
JiabinYang 已提交
150
      id_to_index_;  // should not be used when rows_ has duplicate member
Q
qijun 已提交
151
  std::unique_ptr<Tensor> value_{nullptr};
J
JiabinYang 已提交
152
  int64_t height_;  // height indicates the underline tensor's height
153
  std::unique_ptr<RWLock> rwlock_{nullptr};
Q
qijun 已提交
154 155
};

156 157 158 159 160 161 162
/*
 * Serialize/Desiralize SelectedRows to std::ostream
 * You can pass ofstream or ostringstream to serilize to file
 * or to a in memory string. GPU tensor will be copied to CPU.
 */
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
                       const platform::DeviceContext& dev_ctx);
Y
Yancey 已提交
163 164
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
                           const platform::DeviceContext& dev_ctx);
165

Q
qijun 已提交
166 167
}  // namespace framework
}  // namespace paddle