/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16

Y
Yancey1989 已提交
17
#include <algorithm>
18
#include <memory>
Q
qiaolongfei 已提交
19
#include <mutex>  // NOLINT
20
#include <unordered_map>
Y
Yancey1989 已提交
21
#include <utility>
22 23
#include <vector>

Y
Yi Wang 已提交
24
#include "paddle/fluid/framework/lod_tensor.h"
25
#include "paddle/fluid/framework/rw_lock.h"
Y
Yi Wang 已提交
26
#include "paddle/fluid/framework/tensor.h"
Y
Yancey1989 已提交
27
#include "paddle/fluid/memory/memcpy.h"
Q
qijun 已提交
28 29 30 31 32

namespace paddle {
namespace framework {

class SelectedRows {
Y
Yancey1989 已提交
33 34 35 36 37 38 39 40 41 42 43
  /*
   * @brief We can use the SelectedRows structure to reproduce a sparse table.
   *  A sparse table is a key-value structure that the key is an `int64_t`
   * number,
   *  and the value is a Tensor which the first dimension is 0.
   *  You can use the following interface to operate the sparse table, and you
   * can find
   *  some detail information from the comments of each interface:
   *
   *  HasKey(key), whether the sparse table has the specified key.
   *  Set(key, value), set a key-value pair into the sparse table.
Y
Yancey1989 已提交
44
   *  Get(keys, value*), get value by given key list and apply it to the given
Y
Yancey1989 已提交
45 46 47 48
   * value pointer
   *    with the specified offset.
   *
   */
Q
qijun 已提交
49 50 51 52
 public:
  SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
      : rows_(rows), height_(height) {
    value_.reset(new Tensor());
53
    rwlock_.reset(new RWLock);
Q
qijun 已提交
54 55
  }

Q
QI JUN 已提交
56 57 58
  SelectedRows() {
    height_ = 0;
    value_.reset(new Tensor());
59
    rwlock_.reset(new RWLock);
Q
QI JUN 已提交
60
  }
Q
qijun 已提交
61 62 63

  platform::Place place() const { return value_->place(); }

Q
qijun 已提交
64 65 66
  const Tensor& value() const { return *value_; }

  Tensor* mutable_value() { return value_.get(); }
Q
qijun 已提交
67 68 69 70 71

  int64_t height() const { return height_; }

  void set_height(int64_t height) { height_ = height; }

Q
qijun 已提交
72
  const Vector<int64_t>& rows() const { return rows_; }
Q
qijun 已提交
73

Q
QI JUN 已提交
74 75
  Vector<int64_t>* mutable_rows() { return &rows_; }

Q
qijun 已提交
76
  void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
Q
qijun 已提交
77

Y
Yancey1989 已提交
78
  /*
79 80 81 82 83 84 85 86 87 88 89 90 91 92
   * @brief Get the index of key in rows
   *
   * @return -1 if the key does not exists.
   */
  int64_t Index(int64_t key) const {
    auto it = std::find(rows_.begin(), rows_.end(), key);
    if (it == rows_.end()) {
      PADDLE_THROW("id %s not in table", key);
    }
    return static_cast<int64_t>(std::distance(rows_.begin(), it));
  }

  /*
   * @brief whether has the specified key in the table.
Y
Yancey1989 已提交
93 94
   *
   * @return true if the key is exists.
Q
qiaolongfei 已提交
95
   */
Y
Yancey1989 已提交
96 97 98
  bool HasKey(int64_t key) const;

  /*
99 100 101 102
   * @brief Get value by the key list.
   * Note!!! this interface is only used when selected_rows is used as
   * parameters
   * for distribute lookup table.
Y
Yancey1989 已提交
103
   *
Y
Yancey1989 已提交
104 105
   * @return a list of pair which contains the non-exists key and the index in
   * the value
Y
Yancey1989 已提交
106
   */
107
  void Get(const framework::Tensor& ids, framework::Tensor* value,
Q
Qiao Longfei 已提交
108
           bool auto_grown = false, bool is_test = false);
Y
Yancey1989 已提交
109 110

  /*
111 112 113
   * @brief Get the index of the key from id_to_index_ map. If the key not
   * exist,
   * add the key into id_to_index_.
Y
Yancey1989 已提交
114
   *
115 116 117
   * Note!!! this interface is only used when selected_rows is used as
   * parameters
   * for distribute lookup table.
Y
Yancey1989 已提交
118
   *
119
   * @return index of the key.
Y
Yancey1989 已提交
120
   */
Q
Qiao Longfei 已提交
121
  int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);
Y
Yancey1989 已提交
122

123
  void SyncIndex();
J
JiabinYang 已提交
124 125 126
  /*
   * @brief Get complete Dims before
   */
Q
qijun 已提交
127 128 129 130 131 132 133
  DDim GetCompleteDims() const {
    std::vector<int64_t> dims = vectorize(value_->dims());
    dims[0] = height_;
    return make_ddim(dims);
  }

 private:
Q
qijun 已提交
134
  // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
135
  // SelectedRows are simply concated when adding together. Until a
Q
qijun 已提交
136
  // SelectedRows add a Tensor, will the duplicate rows be handled.
Q
qijun 已提交
137
  Vector<int64_t> rows_;
138 139
  std::unordered_map<int64_t, int64_t>
      id_to_index_;  // should not be used when ids has duplicate member
Q
qijun 已提交
140
  std::unique_ptr<Tensor> value_{nullptr};
J
JiabinYang 已提交
141
  int64_t height_;  // height indicates the underline tensor's height
142
  std::unique_ptr<RWLock> rwlock_{nullptr};
Q
qijun 已提交
143 144
};

145 146 147 148 149 150 151
/*
 * Serialize/Desiralize SelectedRows to std::ostream
 * You can pass ofstream or ostringstream to serilize to file
 * or to a in memory string. GPU tensor will be copied to CPU.
 */
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
                       const platform::DeviceContext& dev_ctx);
Y
Yancey 已提交
152 153
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
                           const platform::DeviceContext& dev_ctx);
154

Q
qijun 已提交
155 156
}  // namespace framework
}  // namespace paddle