matrix_bit_code.h 8.3 KB
Newer Older
Y
Yancey1989 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
Y
Yu Yang 已提交
16
#include <map>
J
JiabinYang 已提交
17 18 19
#include <unordered_map>
#include <utility>
#include <vector>
20

W
weixing02 已提交
21
#include "paddle/fluid/framework/eigen.h"
J
JiabinYang 已提交
22
#include "paddle/fluid/framework/lod_tensor.h"
23
#include "paddle/fluid/framework/selected_rows_utils.h"
W
weixing02 已提交
24 25
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
26

27
#include "paddle/phi/kernels/funcs/blas/blas.h"
Y
Yancey1989 已提交
28

D
dzhwinter 已提交
29 30
#if defined(_WIN32)
#include <intrin.h>
31 32 33
#ifndef NOMINMAX
#define NOMINMAX  // msvc max/min macro conflict with std::min/max
#endif
D
dzhwinter 已提交
34 35 36
#include <windows.h>
#endif  // _WIN32

Y
Yancey1989 已提交
37 38 39
namespace paddle {
namespace operators {
namespace math {
W
weixing02 已提交
40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
/**
 * SimpleCodeTable class should support 3 functions:
 *
 * size_t size()
 *   return the number of ids
 *
 * int get_max_code_length()
 *   return the maximal code length
 *
 * SimpleCode get_code(int64_t i)
 *   return the i-th code. The code class is described below.
 *
 * SimpleCode class should support 3 functions:
 *
 * int get_length()
 *   return the length of the code
 *
 * size_t calc_index(int bit)
 *   bit ranges from 0 to get_length() - 1
 *   return the index for the (1+bit) level parent
 *
 * bool calc_bit(int bit)
 *   return true if the bit level parent is the right child of (1+bit) level
 *   parent
 *
 */
Y
Yancey1989 已提交
66 67 68 69 70 71

/**
 * return the 1-based index of the highest bit set
 *
 * for x > 0:
 * \f[
W
weixing02 已提交
72
 *    FindLastSet(x) = 1 + \lfloor \log_{2} x \rfloor
Y
Yancey1989 已提交
73 74
 * \f]
 */
D
dzhwinter 已提交
75
#if !defined(_WIN32)
Y
Yancey1989 已提交
76 77 78 79 80 81
// Returns the 1-based position of the most significant set bit of x, and 0
// when x is zero. The count-leading-zeros builtin is chosen by the concrete
// type behind size_t (unsigned int, unsigned long, or anything else, which is
// routed to the unsigned long long variant). Zero is tested first because the
// builtins are never invoked with a zero argument.
inline constexpr size_t FindLastSet(size_t x) {
  return x == 0
             ? 0
             : std::is_same<size_t, unsigned int>::value
                   ? 8 * sizeof(x) - __builtin_clz(x)
                   : std::is_same<size_t, unsigned long>::value  // NOLINT
                         ? 8 * sizeof(x) - __builtin_clzl(x)
                         : 8 * sizeof(x) - __builtin_clzll(x);
}
D
dzhwinter 已提交
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
#else
// Windows builds do not have the built-in clz/ctz functions.
template <typename T>
inline int ctz(const T& value) {
  DWORD trailing_zero = 0;
  if (_BitScanForward(&trailing_zero, value)) {
    return static_cast<int>(trailing_zero);
  } else {
    return static_cast<int>(0);
  }
}

template <typename T>
inline int clz(const T& value) {
  DWORD leadning_zero = 0;
  if (_BitScanReverse(&leadning_zero, value)) {
    return static_cast<int>(sizeof(T) * 8 - leadning_zero);
  } else {
    return static_cast<int>(0);
  }
}

105
// Windows counterpart of the builtin-based FindLastSet: the 1-based position
// of the most significant set bit of x, or 0 when x is zero.
//
// For non-zero x, clz(x) yields (bit width - highest bit index), so
// 1 + bit width - clz(x) is the 1-based position. clz(0) returns 0, which
// would turn that expression into bit width + 1; zero is therefore handled
// explicitly so both platform variants agree on FindLastSet(0) == 0.
inline size_t FindLastSet(size_t x) {
  return x == 0 ? 0 : 1 + sizeof(size_t) * 8 - clz(x);
}
D
dzhwinter 已提交
106
#endif  // !_WIN32
Y
Yu Yang 已提交
107
/**
 * Code of a single class in the implicit (non-custom) coding scheme.
 *
 * The encoding of class c is `c + num_classes`, so that the id of the root is
 * 1 rather than 0 and all siblings obtain the same weight index from a common
 * prefix of their encodings.
 */
class SimpleCode {
 public:
  /**
   * @param code         position in `ids` of the class id to encode.
   * @param num_classes  offset added to the class id to form the encoding.
   * @param ids          array of class ids; only ids[code] is read.
   */
  SimpleCode(size_t code, size_t num_classes, const int64_t* ids)
      : c_(static_cast<size_t>(ids[code]) + num_classes) {}

  /**
   * Weight index for level `bit`. Weight indices are the prefixes of the
   * encoding, hence the right-most (bit + 1) bits are dropped; the trailing
   * -1 turns the 1-based node id into a 0-based index.
   */
  size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }

  /**
   * Binary decision at level `bit`, i.e. bit number `bit` of the encoding.
   *
   * The encoding is shifted down instead of building the mask `1 << bit`:
   * the literal 1 is an int, so that mask overflows once bit reaches the
   * width of int although c_ itself is a size_t.
   */
  bool calc_bit(int bit) const { return ((c_ >> bit) & 1) != 0; }

  /**
   * Number of decisions on the path: the bit length of the encoding minus
   * the left-most bit, which is not part of the path.
   */
  int get_length() const { return static_cast<int>(FindLastSet(c_) - 1); }

 private:
  size_t c_;  // class id + num_classes
};

J
JiabinYang 已提交
128
template <typename T>
Y
Yu Yang 已提交
129
class CustomCode {
130
 public:
131
  CustomCode(const framework::Tensor& path_table,
132 133
             const framework::Tensor& path_code,
             const int64_t* ids,
134 135 136 137
             int index) {
    seq_len_ = path_table.dims()[1];
    path_table_data_ = path_table.data<T>() + seq_len_ * index;
    path_code_data_ = path_code.data<T>() + seq_len_ * index;
J
JiabinYang 已提交
138
  }
139
  /**
Y
Yu Yang 已提交
140
   * Here the id of root should be 1 rather than 0, thus the encoding of class c
T
tianshuo78520a 已提交
141
   * is `c + num_classes` and all siblings can get the same weight index using
142 143 144 145 146 147
   * prefixes.
   * Weight index is the prefixes of encoding, thus leave out the right most
   * bit in calc_index.
   * Binary classification path is the suffixes of encoding, thus leave out the
   * left most bit in calc_bit.
   */
148 149
  size_t calc_index(int bit) const { return path_table_data_[bit]; }
  bool calc_bit(int bit) const { return path_code_data_[bit]; }
150

Y
Yu Yang 已提交
151
  // NOTE: this function is not thread-safe.
Y
Yu Yang 已提交
152
  int get_length() const {
Y
Yu Yang 已提交
153 154
    if (length_ < 0) {
      auto len = seq_len_;
155 156 157 158 159
      length_ =
          static_cast<int>(std::find_if(path_table_data_,
                                        path_table_data_ + len,
                                        [](const T& val) { return val < 0; }) -
                           path_table_data_);
160
    }
Y
Yu Yang 已提交
161
    return length_;
162 163 164
  }

 private:
Y
Yu Yang 已提交
165
  int64_t seq_len_;
166 167
  const T* path_table_data_;
  const T* path_code_data_;
Y
Yu Yang 已提交
168
  mutable int length_{-1};
169 170
};

Y
Yu Yang 已提交
171
class SimpleCodeTable {
172
 public:
J
JiabinYang 已提交
173
  SimpleCodeTable(size_t num_classes, const int64_t* ids)
174
      : num_classes_(num_classes), ids_(ids) {}
Y
Yu Yang 已提交
175

Y
Yu Yang 已提交
176 177
  SimpleCode get_code(int64_t code) const {
    return SimpleCode(code, num_classes_, ids_);
Y
Yancey1989 已提交
178
  }
Y
Yu Yang 已提交
179

Y
Yancey1989 已提交
180 181 182 183 184
  size_t size() const { return num_classes_; }
  int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }

 private:
  size_t num_classes_;
185 186 187
  const int64_t* ids_;
};

J
JiabinYang 已提交
188
template <typename T>
Y
Yu Yang 已提交
189
class CustomCodeTable {
190
 public:
191
  CustomCodeTable(const framework::Tensor& path_table,
192 193
                  const framework::Tensor& path_code,
                  const int64_t* ids)
194
      : ptable_(path_table), pcode_(path_code), ids_(ids) {}
195

Y
Yu Yang 已提交
196 197
  CustomCode<T> get_code(int64_t code) const {
    return CustomCode<T>(ptable_, pcode_, ids_, code);
198 199
  }

J
JiabinYang 已提交
200
  size_t size() const { return static_cast<size_t>(ptable_.dims()[1]); }
201
  int get_max_code_length() const {
J
JiabinYang 已提交
202
    return static_cast<size_t>(ptable_.dims()[1]);
203 204 205
  }

 private:
J
JiabinYang 已提交
206 207
  const framework::Tensor& ptable_;
  const framework::Tensor& pcode_;
208
  const int64_t* ids_;
Y
Yancey1989 已提交
209 210
};

R
Ruibiao Chen 已提交
211
// Holds either a SimpleCodeTable or a CustomCodeTable<int64_t>;
// MatrixBitCodeFunctor stores one of these in its code_table_ member.
using CodeTable = paddle::variant<SimpleCodeTable, CustomCodeTable<int64_t>>;
Y
Yu Yang 已提交
212

Y
Yancey1989 已提交
213
template <typename T>
Y
Yancey1989 已提交
214 215
class MatrixBitCodeFunctor {
 public:
J
JiabinYang 已提交
216
  MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
217 218
      : num_classes_(num_classes),
        ids_(ids),
Y
Yu Yang 已提交
219
        code_table_(SimpleCodeTable(num_classes, ids)) {}
220

221
  MatrixBitCodeFunctor(const framework::Tensor& path_table,
222 223
                       const framework::Tensor& path_code,
                       const int64_t* ids)
224
      : num_classes_(static_cast<size_t>(path_table.dims()[1])),
225
        ids_(ids),
226
        code_table_(CustomCodeTable<int64_t>(path_table, path_code, ids)) {}
Y
Yancey1989 已提交
227 228 229
  /* For j < code_length
       tmat(i, j) += vec(0, index(i, j))
  */
J
JiabinYang 已提交
230
  void Add(const framework::Tensor& vec, framework::Tensor* tmat);
Y
Yancey1989 已提交
231

Y
Yancey1989 已提交
232 233 234
  /* For j < code_length
       vec(0, index(i, j)) += tmat(i, j)
  */
J
JiabinYang 已提交
235
  void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);
Y
Yancey1989 已提交
236 237

  /* For j < code_length
Y
Yancey1989 已提交
238
    sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
Y
Yancey1989 已提交
239
  */
J
JiabinYang 已提交
240
  void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum);
Y
Yancey1989 已提交
241

Y
Yancey1989 已提交
242 243 244
  /* For j < code_length
       tmat(i, j) -= bit(i, j)
  */
J
JiabinYang 已提交
245
  void Sub(framework::Tensor* tmat);
Y
Yancey1989 已提交
246 247 248
  /* For j < code_length
       input.row(i) += tmat(i, j) * weight.row(index(i, j))
  */
249 250
  void Mul(framework::Tensor* tmat,
           const framework::Tensor& weight,
J
JiabinYang 已提交
251
           const framework::Tensor& input);
Y
Yancey1989 已提交
252

Y
Yancey1989 已提交
253 254 255
  /* For index(i, j) >= 0:
      weight.row(index(i, j)) += tmat(i, j) * input.row(i)
  */
256 257
  void MulGradWeight(const framework::Tensor& tmat,
                     framework::Tensor* weight,
J
JiabinYang 已提交
258
                     const framework::Tensor& input);
J
JiabinYang 已提交
259 260 261
  /* For SelectedRows Weight, For index(i, j) >= 0:
      weight.row(index(i, j)) += tmat(i, j) * input.row(i)
  */
262 263
  void MulGradWeight(const framework::Tensor& tmat,
                     phi::SelectedRows* weight,
J
JiabinYang 已提交
264
                     const framework::Tensor& input);
Y
Yancey1989 已提交
265 266 267
  /* For j < code_length
    input.row(i) += tmat(i, j) * weight.row(index(i, j))
  */
J
JiabinYang 已提交
268
  void MulGradError(const framework::Tensor& tmat,
269 270
                    const framework::Tensor& weight,
                    framework::Tensor* input);
W
weixing02 已提交
271

Y
Yancey1989 已提交
272 273
  size_t num_classes_;
  const int64_t* ids_;
Y
Yu Yang 已提交
274
  CodeTable code_table_;
Y
Yancey1989 已提交
275
};
Y
Yancey1989 已提交
276 277 278
}  // namespace math
}  // namespace operators
}  // namespace paddle