matrix_bit_code.h 8.2 KB
Newer Older
Y
Yancey1989 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
Y
Yu Yang 已提交
16
#include <map>
J
JiabinYang 已提交
17 18 19
#include <unordered_map>
#include <utility>
#include <vector>
20

W
weixing02 已提交
21
#include "paddle/fluid/framework/eigen.h"
J
JiabinYang 已提交
22
#include "paddle/fluid/framework/lod_tensor.h"
23
#include "paddle/fluid/framework/selected_rows_utils.h"
W
weixing02 已提交
24 25
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
Y
Yu Yang 已提交
26
#include "paddle/fluid/platform/variant.h"
27
#include "paddle/phi/kernels/funcs/blas/blas.h"
Y
Yancey1989 已提交
28

D
dzhwinter 已提交
29 30
#if defined(_WIN32)
#include <intrin.h>
31 32 33
#ifndef NOMINMAX
#define NOMINMAX  // msvc max/min macro conflict with std::min/max
#endif
D
dzhwinter 已提交
34 35 36
#include <windows.h>
#endif  // _WIN32

Y
Yancey1989 已提交
37 38 39
namespace paddle {
namespace operators {
namespace math {
W
weixing02 已提交
40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
/**
 * A code table class (e.g. SimpleCodeTable) should support 3 functions:
 *
 * size_t size()
 *   return the number of classes
 *
 * int get_max_code_length()
 *   return the maximal code length
 *
 * SimpleCode get_code(int64_t code)
 *   return the code for the code-th id. The code class is described below.
 *
 * A code class (e.g. SimpleCode) should support 3 functions:
 *
 * int get_length()
 *   return the length of the code
 *
 * size_t calc_index(int bit)
 *   bit ranges from 0 to get_length() - 1
 *   return the index for the (1+bit) level parent
 *
 * bool calc_bit(int bit)
 *   return true if the bit level parent is the right child of the (1+bit)
 *   level parent
 *
 */
Y
Yancey1989 已提交
66 67 68 69 70 71

/**
 * Return the 1-based index of the highest set bit of x, or 0 when x == 0.
 *
 * for x > 0:
 * \f[
 *    FindLastSet(x) = 1 + \floor*{\log_{2}x}
 * \f]
 *
 * Examples: FindLastSet(0) == 0, FindLastSet(1) == 1, FindLastSet(4) == 3.
 */
#if !defined(_WIN32)
inline constexpr size_t FindLastSet(size_t x) {
  // Select the clz builtin whose operand width matches size_t. The builtins
  // are undefined for 0, hence the explicit `x ? ... : 0` in every branch.
  return std::is_same<size_t, unsigned int>::value
             ? (x ? 8 * sizeof(x) - __builtin_clz(x) : 0)
             : (std::is_same<size_t, unsigned long>::value  // NOLINT
                    ? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0)
                    : (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0));
}
#else
// MSVC has no __builtin_clz / __builtin_ctz; these wrappers use the
// _BitScanForward / _BitScanReverse intrinsics instead.
//
// NOTE(review): both helpers return 0 when value == 0. clz() returns
// sizeof(T) * 8 - msb_index, which is one more than a true leading-zero count
// (sizeof(T) * 8 - 1 - msb_index). They are kept unchanged in case code
// elsewhere calls them; FindLastSet below does not depend on them.
template <typename T>
inline int ctz(const T& value) {
  DWORD trailing_zero = 0;
  if (_BitScanForward(&trailing_zero, value)) {
    return static_cast<int>(trailing_zero);
  } else {
    return static_cast<int>(0);
  }
}

template <typename T>
inline int clz(const T& value) {
  DWORD leadning_zero = 0;
  if (_BitScanReverse(&leadning_zero, value)) {
    return static_cast<int>(sizeof(T) * 8 - leadning_zero);
  } else {
    return static_cast<int>(0);
  }
}

// _BitScanReverse* stores the 0-based position of the most significant set bit
// and returns 0 when no bit is set. The 64-bit variant is used on _WIN64 so
// the upper half of size_t is not truncated, and 0 is returned for x == 0 to
// match the non-Windows implementation above.
inline size_t FindLastSet(size_t x) {
  unsigned long msb_index = 0;  // NOLINT: the intrinsics take unsigned long*
#if defined(_WIN64)
  const unsigned char found =
      _BitScanReverse64(&msb_index, static_cast<unsigned __int64>(x));
#else
  const unsigned char found =
      _BitScanReverse(&msb_index, static_cast<unsigned long>(x));  // NOLINT
#endif
  return found ? static_cast<size_t>(msb_index) + 1 : 0;
}
#endif  // !_WIN32
Y
Yu Yang 已提交
107
class SimpleCode {
 public:
  /**
   * @param code        index into `ids` selecting which entry to encode.
   * @param num_classes total number of classes; added to the class id to form
   *                    the encoding described below.
   * @param ids         class ids; ids[code] is the class c being encoded.
   */
  SimpleCode(size_t code, size_t num_classes, const int64_t* ids)
      : c_(static_cast<size_t>(ids[code]) + num_classes) {}
  /**
   * Here the id of root should be 1 rather than 0, thus the encoding of class c
   * is `c + num_classes` and all siblings can get the same weight index using
   * prefixes.
   * Weight index is the prefixes of encoding, thus leave out the right most
   * bit in calc_index.
   * Binary classification path is the suffixes of encoding, thus leave out the
   * left most bit in calc_bit.
   */
  size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
  // Shift c_ (a size_t) instead of the int literal 1: `1 << bit` overflows int
  // once bit >= 31, which is reachable when num_classes is very large.
  bool calc_bit(int bit) const { return ((c_ >> bit) & 1) != 0; }
  // Number of bits below the leading 1 of the encoding.
  int get_length() const { return static_cast<int>(FindLastSet(c_) - 1); }

 private:
  size_t c_;  // encoding of the class: ids[code] + num_classes
};

J
JiabinYang 已提交
128
template <typename T>
Y
Yu Yang 已提交
129
class CustomCode {
130
 public:
131 132 133 134 135 136
  CustomCode(const framework::Tensor& path_table,
             const framework::Tensor& path_code, const int64_t* ids,
             int index) {
    seq_len_ = path_table.dims()[1];
    path_table_data_ = path_table.data<T>() + seq_len_ * index;
    path_code_data_ = path_code.data<T>() + seq_len_ * index;
J
JiabinYang 已提交
137
  }
138
  /**
Y
Yu Yang 已提交
139
   * Here the id of root should be 1 rather than 0, thus the encoding of class c
T
tianshuo78520a 已提交
140
   * is `c + num_classes` and all siblings can get the same weight index using
141 142 143 144 145 146
   * prefixes.
   * Weight index is the prefixes of encoding, thus leave out the right most
   * bit in calc_index.
   * Binary classification path is the suffixes of encoding, thus leave out the
   * left most bit in calc_bit.
   */
147 148
  size_t calc_index(int bit) const { return path_table_data_[bit]; }
  bool calc_bit(int bit) const { return path_code_data_[bit]; }
149

Y
Yu Yang 已提交
150
  // NOTE: this function is not thread-safe.
Y
Yu Yang 已提交
151
  int get_length() const {
Y
Yu Yang 已提交
152 153
    if (length_ < 0) {
      auto len = seq_len_;
154 155 156 157
      length_ = static_cast<int>(
          std::find_if(path_table_data_, path_table_data_ + len,
                       [](const T& val) { return val < 0; }) -
          path_table_data_);
158
    }
Y
Yu Yang 已提交
159
    return length_;
160 161 162
  }

 private:
Y
Yu Yang 已提交
163
  int64_t seq_len_;
164 165
  const T* path_table_data_;
  const T* path_code_data_;
Y
Yu Yang 已提交
166
  mutable int length_{-1};
167 168
};

Y
Yu Yang 已提交
169
class SimpleCodeTable {
170
 public:
J
JiabinYang 已提交
171
  SimpleCodeTable(size_t num_classes, const int64_t* ids)
172
      : num_classes_(num_classes), ids_(ids) {}
Y
Yu Yang 已提交
173

Y
Yu Yang 已提交
174 175
  SimpleCode get_code(int64_t code) const {
    return SimpleCode(code, num_classes_, ids_);
Y
Yancey1989 已提交
176
  }
Y
Yu Yang 已提交
177

Y
Yancey1989 已提交
178 179 180 181 182
  size_t size() const { return num_classes_; }
  int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }

 private:
  size_t num_classes_;
183 184 185
  const int64_t* ids_;
};

J
JiabinYang 已提交
186
template <typename T>
Y
Yu Yang 已提交
187
class CustomCodeTable {
188
 public:
189 190 191
  CustomCodeTable(const framework::Tensor& path_table,
                  const framework::Tensor& path_code, const int64_t* ids)
      : ptable_(path_table), pcode_(path_code), ids_(ids) {}
192

Y
Yu Yang 已提交
193 194
  CustomCode<T> get_code(int64_t code) const {
    return CustomCode<T>(ptable_, pcode_, ids_, code);
195 196
  }

J
JiabinYang 已提交
197
  size_t size() const { return static_cast<size_t>(ptable_.dims()[1]); }
198
  int get_max_code_length() const {
J
JiabinYang 已提交
199
    return static_cast<size_t>(ptable_.dims()[1]);
200 201 202
  }

 private:
J
JiabinYang 已提交
203 204
  const framework::Tensor& ptable_;
  const framework::Tensor& pcode_;
205
  const int64_t* ids_;
Y
Yancey1989 已提交
206 207
};

Y
Yu Yang 已提交
208 209
using CodeTable = boost::variant<SimpleCodeTable, CustomCodeTable<int64_t>>;

Y
Yancey1989 已提交
210
template <typename T>
Y
Yancey1989 已提交
211 212
class MatrixBitCodeFunctor {
 public:
J
JiabinYang 已提交
213
  MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
214 215
      : num_classes_(num_classes),
        ids_(ids),
Y
Yu Yang 已提交
216
        code_table_(SimpleCodeTable(num_classes, ids)) {}
217

218 219 220
  MatrixBitCodeFunctor(const framework::Tensor& path_table,
                       const framework::Tensor& path_code, const int64_t* ids)
      : num_classes_(static_cast<size_t>(path_table.dims()[1])),
221
        ids_(ids),
222
        code_table_(CustomCodeTable<int64_t>(path_table, path_code, ids)) {}
Y
Yancey1989 已提交
223 224 225
  /* For j < code_length
       tmat(i, j) += vec(0, index(i, j))
  */
J
JiabinYang 已提交
226
  void Add(const framework::Tensor& vec, framework::Tensor* tmat);
Y
Yancey1989 已提交
227

Y
Yancey1989 已提交
228 229 230
  /* For j < code_length
       vec(0, index(i, j)) += tmat(i, j)
  */
J
JiabinYang 已提交
231
  void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);
Y
Yancey1989 已提交
232 233

  /* For j < code_length
Y
Yancey1989 已提交
234
    sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
Y
Yancey1989 已提交
235
  */
J
JiabinYang 已提交
236
  void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum);
Y
Yancey1989 已提交
237

Y
Yancey1989 已提交
238 239 240
  /* For j < code_length
       tmat(i, j) -= bit(i, j)
  */
J
JiabinYang 已提交
241
  void Sub(framework::Tensor* tmat);
Y
Yancey1989 已提交
242 243 244
  /* For j < code_length
       input.row(i) += tmat(i, j) * weight.row(index(i, j))
  */
J
JiabinYang 已提交
245 246
  void Mul(framework::Tensor* tmat, const framework::Tensor& weight,
           const framework::Tensor& input);
Y
Yancey1989 已提交
247

Y
Yancey1989 已提交
248 249 250
  /* For index(i, j) >= 0:
      weight.row(index(i, j)) += tmat(i, j) * input.row(i)
  */
J
JiabinYang 已提交
251 252
  void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight,
                     const framework::Tensor& input);
J
JiabinYang 已提交
253 254 255
  /* For SelectedRows Weight, For index(i, j) >= 0:
      weight.row(index(i, j)) += tmat(i, j) * input.row(i)
  */
256
  void MulGradWeight(const framework::Tensor& tmat, phi::SelectedRows* weight,
J
JiabinYang 已提交
257
                     const framework::Tensor& input);
Y
Yancey1989 已提交
258 259 260
  /* For j < code_length
    input.row(i) += tmat(i, j) * weight.row(index(i, j))
  */
J
JiabinYang 已提交
261 262
  void MulGradError(const framework::Tensor& tmat,
                    const framework::Tensor& weight, framework::Tensor* input);
W
weixing02 已提交
263

Y
Yancey1989 已提交
264 265
  size_t num_classes_;
  const int64_t* ids_;
Y
Yu Yang 已提交
266
  CodeTable code_table_;
Y
Yancey1989 已提交
267
};
Y
Yancey1989 已提交
268 269 270
}  // namespace math
}  // namespace operators
}  // namespace paddle