/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <map>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

#if defined(_WIN32)
#include <intrin.h>
#ifndef NOMINMAX
#define NOMINMAX  // msvc max/min macro conflict with std::min/max
#endif
#include <windows.h>
#endif  // _WIN32

namespace paddle {
namespace operators {
namespace math {
/**
 * A code table class (SimpleCodeTable, CustomCodeTable) should support
 * 3 functions:
 *
 * size_t size()
 *   return the number of ids
 *
 * int get_max_code_length()
 *   return the maximal code length
 *
 * Code get_code(int64_t i)
 *   return the code of the i-th sample. The Code class is described below.
 *
 * A code class (SimpleCode, CustomCode) should support 3 functions:
 *
 * int get_length()
 *   return the length of the code
 *
 * size_t calc_index(int bit)
 *   bit ranges from 0 to get_length() - 1
 *   return the index for the (1+bit) level parent
 *
 * bool calc_bit(int bit)
 *   return true if the bit level parent is the right child of the (1+bit)
 *   level parent
 *
 */

/**
 * Return the 1-based index of the highest set bit of x, or 0 when x == 0.
 *
 * for x > 0:
 * \f[
 *    FindLastSet(x) = 1 + \floor*{\log_{2}x}
 * \f]
 *
 * Examples: FindLastSet(0) == 0, FindLastSet(1) == 1, FindLastSet(6) == 3.
 */
#if !defined(_WIN32)
inline constexpr size_t FindLastSet(size_t x) {
  // Pick the __builtin_clz* variant whose operand type matches size_t, so the
  // leading-zero count is taken over exactly 8 * sizeof(size_t) bits. The
  // builtins are undefined for 0, hence the explicit `x ? ... : 0` guards.
  return std::is_same<size_t, unsigned int>::value
             ? (x ? 8 * sizeof(x) - __builtin_clz(x) : 0)
             : (std::is_same<size_t, unsigned long>::value  // NOLINT
                    ? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0)
                    : (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0));
}
#else
// windows don't have built-in clz, ctz function; emulate them with the
// _BitScanForward/_BitScanReverse intrinsics from <intrin.h>.

// Returns the position of the lowest set bit of value, or 0 if value == 0.
template <typename T>
inline int ctz(const T& value) {
  DWORD trailing_zero = 0;
  if (_BitScanForward(&trailing_zero, value)) {
    return static_cast<int>(trailing_zero);
  } else {
    return static_cast<int>(0);
  }
}

// NOTE: despite its name this returns sizeof(T) * 8 minus the position of the
// highest set bit (one more than a conventional leading-zero count), and 0
// when value == 0. _BitScanReverse takes a 32-bit unsigned long mask, so only
// the low 32 bits of a wider T are scanned. Kept unchanged for compatibility.
template <typename T>
inline int clz(const T& value) {
  DWORD leadning_zero = 0;
  if (_BitScanReverse(&leadning_zero, value)) {
    return static_cast<int>(sizeof(T) * 8 - leadning_zero);
  } else {
    return static_cast<int>(0);
  }
}

inline size_t FindLastSet(size_t x) {
  if (x == 0) {
    // No bit is set; matches the non-Windows definition above.
    return 0;
  }
  unsigned long index = 0;  // NOLINT: intrinsic signature uses unsigned long
#if defined(_WIN64)
  // size_t is 64 bits here: scan all of them, not just the low 32.
  _BitScanReverse64(&index, static_cast<unsigned __int64>(x));
#else
  _BitScanReverse(&index, static_cast<unsigned long>(x));  // NOLINT
#endif
  return static_cast<size_t>(index) + 1;
}
#endif  // !_WIN32
Y
Yu Yang 已提交
106
class SimpleCode {
107 108 109
 public:
  SimpleCode(size_t code, size_t num_classes, const int64_t* ids)
      : c_(static_cast<size_t>(ids[code]) + num_classes) {}
G
guosheng 已提交
110
  /**
T
tianshuo78520a 已提交
111
   * Here the id of root should be 1 rather than 0, thus the encoding of class c
T
tianshuo78520a 已提交
112
   * is `c + num_classes` and all siblings can get the same weight index using
113 114 115 116 117
   * prefixes.
   * Weight index is the prefixes of encoding, thus leave out the right most
   * bit in calc_index.
   * Binary classification path is the suffixes of encoding, thus leave out the
   * left most bit in calc_bit.
G
guosheng 已提交
118
   */
119 120 121
  size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
  bool calc_bit(int bit) const { return c_ & (1 << bit); }
  int get_length() const { return FindLastSet(c_) - 1; }
Y
Yancey1989 已提交
122 123

 private:
124
  size_t c_;
Y
Yancey1989 已提交
125 126
};

J
JiabinYang 已提交
127
template <typename T>
Y
Yu Yang 已提交
128
class CustomCode {
129
 public:
130 131 132 133 134 135
  CustomCode(const framework::Tensor& path_table,
             const framework::Tensor& path_code, const int64_t* ids,
             int index) {
    seq_len_ = path_table.dims()[1];
    path_table_data_ = path_table.data<T>() + seq_len_ * index;
    path_code_data_ = path_code.data<T>() + seq_len_ * index;
J
JiabinYang 已提交
136
  }
137
  /**
Y
Yu Yang 已提交
138
   * Here the id of root should be 1 rather than 0, thus the encoding of class c
T
tianshuo78520a 已提交
139
   * is `c + num_classes` and all siblings can get the same weight index using
140 141 142 143 144 145
   * prefixes.
   * Weight index is the prefixes of encoding, thus leave out the right most
   * bit in calc_index.
   * Binary classification path is the suffixes of encoding, thus leave out the
   * left most bit in calc_bit.
   */
146 147
  size_t calc_index(int bit) const { return path_table_data_[bit]; }
  bool calc_bit(int bit) const { return path_code_data_[bit]; }
148

Y
Yu Yang 已提交
149
  // NOTE: this function is not thread-safe.
Y
Yu Yang 已提交
150
  int get_length() const {
Y
Yu Yang 已提交
151 152
    if (length_ < 0) {
      auto len = seq_len_;
153 154 155 156
      length_ = static_cast<int>(
          std::find_if(path_table_data_, path_table_data_ + len,
                       [](const T& val) { return val < 0; }) -
          path_table_data_);
157
    }
Y
Yu Yang 已提交
158
    return length_;
159 160 161
  }

 private:
Y
Yu Yang 已提交
162
  int64_t seq_len_;
163 164
  const T* path_table_data_;
  const T* path_code_data_;
Y
Yu Yang 已提交
165
  mutable int length_{-1};
166 167
};

Y
Yu Yang 已提交
168
class SimpleCodeTable {
169
 public:
J
JiabinYang 已提交
170
  SimpleCodeTable(size_t num_classes, const int64_t* ids)
171
      : num_classes_(num_classes), ids_(ids) {}
Y
Yu Yang 已提交
172

Y
Yu Yang 已提交
173 174
  SimpleCode get_code(int64_t code) const {
    return SimpleCode(code, num_classes_, ids_);
Y
Yancey1989 已提交
175
  }
Y
Yu Yang 已提交
176

Y
Yancey1989 已提交
177 178 179 180 181
  size_t size() const { return num_classes_; }
  int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }

 private:
  size_t num_classes_;
182 183 184
  const int64_t* ids_;
};

J
JiabinYang 已提交
185
template <typename T>
Y
Yu Yang 已提交
186
class CustomCodeTable {
187
 public:
188 189 190
  CustomCodeTable(const framework::Tensor& path_table,
                  const framework::Tensor& path_code, const int64_t* ids)
      : ptable_(path_table), pcode_(path_code), ids_(ids) {}
191

Y
Yu Yang 已提交
192 193
  CustomCode<T> get_code(int64_t code) const {
    return CustomCode<T>(ptable_, pcode_, ids_, code);
194 195
  }

J
JiabinYang 已提交
196
  size_t size() const { return static_cast<size_t>(ptable_.dims()[1]); }
197
  int get_max_code_length() const {
J
JiabinYang 已提交
198
    return static_cast<size_t>(ptable_.dims()[1]);
199 200 201
  }

 private:
J
JiabinYang 已提交
202 203
  const framework::Tensor& ptable_;
  const framework::Tensor& pcode_;
204
  const int64_t* ids_;
Y
Yancey1989 已提交
205 206
};

Y
Yu Yang 已提交
207 208
using CodeTable = boost::variant<SimpleCodeTable, CustomCodeTable<int64_t>>;

Y
Yancey1989 已提交
209
template <typename T>
Y
Yancey1989 已提交
210 211
class MatrixBitCodeFunctor {
 public:
J
JiabinYang 已提交
212
  MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
213 214
      : num_classes_(num_classes),
        ids_(ids),
Y
Yu Yang 已提交
215
        code_table_(SimpleCodeTable(num_classes, ids)) {}
216

217 218 219
  MatrixBitCodeFunctor(const framework::Tensor& path_table,
                       const framework::Tensor& path_code, const int64_t* ids)
      : num_classes_(static_cast<size_t>(path_table.dims()[1])),
220
        ids_(ids),
221
        code_table_(CustomCodeTable<int64_t>(path_table, path_code, ids)) {}
Y
Yancey1989 已提交
222 223 224
  /* For j < code_length
       tmat(i, j) += vec(0, index(i, j))
  */
J
JiabinYang 已提交
225
  void Add(const framework::Tensor& vec, framework::Tensor* tmat);
Y
Yancey1989 已提交
226

Y
Yancey1989 已提交
227 228 229
  /* For j < code_length
       vec(0, index(i, j)) += tmat(i, j)
  */
J
JiabinYang 已提交
230
  void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);
Y
Yancey1989 已提交
231 232

  /* For j < code_length
Y
Yancey1989 已提交
233
    sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
Y
Yancey1989 已提交
234
  */
J
JiabinYang 已提交
235
  void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum);
Y
Yancey1989 已提交
236

Y
Yancey1989 已提交
237 238 239
  /* For j < code_length
       tmat(i, j) -= bit(i, j)
  */
J
JiabinYang 已提交
240
  void Sub(framework::Tensor* tmat);
Y
Yancey1989 已提交
241 242 243
  /* For j < code_length
       input.row(i) += tmat(i, j) * weight.row(index(i, j))
  */
J
JiabinYang 已提交
244 245
  void Mul(framework::Tensor* tmat, const framework::Tensor& weight,
           const framework::Tensor& input);
Y
Yancey1989 已提交
246

Y
Yancey1989 已提交
247 248 249
  /* For index(i, j) >= 0:
      weight.row(index(i, j)) += tmat(i, j) * input.row(i)
  */
J
JiabinYang 已提交
250 251
  void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight,
                     const framework::Tensor& input);
J
JiabinYang 已提交
252 253 254
  /* For SelectedRows Weight, For index(i, j) >= 0:
      weight.row(index(i, j)) += tmat(i, j) * input.row(i)
  */
255
  void MulGradWeight(const framework::Tensor& tmat, phi::SelectedRows* weight,
J
JiabinYang 已提交
256
                     const framework::Tensor& input);
Y
Yancey1989 已提交
257 258 259
  /* For j < code_length
    input.row(i) += tmat(i, j) * weight.row(index(i, j))
  */
J
JiabinYang 已提交
260 261
  void MulGradError(const framework::Tensor& tmat,
                    const framework::Tensor& weight, framework::Tensor* input);
W
weixing02 已提交
262

Y
Yancey1989 已提交
263 264
  size_t num_classes_;
  const int64_t* ids_;
Y
Yu Yang 已提交
265
  CodeTable code_table_;
Y
Yancey1989 已提交
266
};
Y
Yancey1989 已提交
267 268 269
}  // namespace math
}  // namespace operators
}  // namespace paddle