matrix_bit_code.cc 6.5 KB
Newer Older
Y
Yancey1989 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "matrix_bit_code.h"

namespace paddle {
namespace operators {
namespace math {

/**
 * CodeTable class should support 3 functions:
 *
 * size_t size()
Y
Yancey1989 已提交
25
 *   return the number of ids
Y
Yancey1989 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 *
 * int getMaxCodeLength()
 *   return the maximal code length
 *
 * Code operator()(size_t i)
 *   return the i-th code. Code class is described below.
 *
 * Code class should support 3 functions (names as called in this file):
 *
 * int get_length()
 *   return the length of the code
 *
 * size_t calc_index(int bit)
 *   bit ranges from 0 to get_length() - 1
 *   return the index for the (1+bit) level parent
 *
 * bool calc_bit(int bit)
 *   return true if the bit level parent is the right child of (1+bit) level
 *   parent
 *
 */

Y
Yancey1989 已提交
48 49 50 51 52 53 54 55
/**
 * Accumulates entries of `vec` into `tmat`, one code bit at a time.
 *
 * For every row i of tmat and every bit j of the code built from ids_[i]:
 *   tmat(i, j) += vec(code.calc_index(j))
 * tmat is addressed as a flat row-major buffer with row length
 * tmat.dims()[1]; vec is addressed as a flat buffer.
 */
template <typename T>
void MatrixBitCodeFunctor<T>::Add(framework::Tensor& tmat,
                                  const framework::Tensor& vec) {
  SimpleCodeTable code_table(num_classes_);
  const size_t rows = tmat.dims()[0];
  const size_t row_stride = tmat.dims()[1];
  for (size_t row = 0; row < rows; ++row) {
    auto code = code_table(static_cast<size_t>(ids_[row]));
    // Offset of the first element of this sample's row in the flat buffer.
    const size_t row_offset = row * row_stride;
    const int bits = code.get_length();
    for (int bit = 0; bit < bits; ++bit) {
      const size_t src = code.calc_index(bit);
      tmat.data<T>()[row_offset + bit] += vec.data<T>()[src];
    }
  }
}

Y
Yancey1989 已提交
64 65 66 67 68 69 70 71
/**
 * Scatters entries of `tmat` back into `vec` (the reverse data flow of Add).
 *
 * For every row i of tmat and every bit j of the code built from ids_[i]:
 *   vec(code.calc_index(j)) += tmat(i, j)
 * tmat is addressed as a flat row-major buffer with row length
 * tmat.dims()[1]; vec is addressed as a flat buffer.
 */
template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(framework::Tensor& tmat,
                                      framework::Tensor& vec) {
  SimpleCodeTable code_table(num_classes_);
  const size_t rows = tmat.dims()[0];
  const size_t row_stride = tmat.dims()[1];
  for (size_t row = 0; row < rows; ++row) {
    auto code = code_table(static_cast<size_t>(ids_[row]));
    // Offset of the first element of this sample's row in the flat buffer.
    const size_t row_offset = row * row_stride;
    const int bits = code.get_length();
    for (int bit = 0; bit < bits; ++bit) {
      const size_t dst = code.calc_index(bit);
      vec.data<T>()[dst] += tmat.data<T>()[row_offset + bit];
    }
  }
}

Y
Yancey1989 已提交
80 81 82 83
/**
 * Writes one scaled partial row sum per sample into `sum`.
 *
 * For every row i of tmat:
 *   sum(i) = scale_sum * (sum of tmat(i, j) over bits j where
 *                         code.calc_bit(j) is true)
 * where code is built from ids_[i]. tmat is addressed as a flat row-major
 * buffer with row length tmat.dims()[1]. sum(i) is overwritten, not
 * accumulated.
 */
template <typename T>
void MatrixBitCodeFunctor<T>::Sum(framework::Tensor& tmat,
                                  framework::Tensor& sum, T scale_sum) {
  SimpleCodeTable code_table(num_classes_);
  const size_t rows = tmat.dims()[0];
  const size_t row_stride = tmat.dims()[1];
  for (size_t row = 0; row < rows; ++row) {
    T selected_total = static_cast<T>(0.0);
    auto code = code_table(static_cast<size_t>(ids_[row]));
    const int bits = code.get_length();
    for (int bit = 0; bit < bits; ++bit) {
      // Only positions whose code bit is set contribute to the total.
      if (!code.calc_bit(bit)) {
        continue;
      }
      selected_total += tmat.data<T>()[row * row_stride + bit];
    }
    sum.data<T>()[row] = scale_sum * selected_total;
  }
}
Y
Yancey1989 已提交
98

Y
Yancey1989 已提交
99
template <typename T>
Y
Yancey1989 已提交
100
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor& tmat,
Y
Yancey1989 已提交
101 102
                                  const framework::Tensor& weight,
                                  const framework::Tensor& input) {
Y
Yancey1989 已提交
103
  SimpleCodeTable code_table(num_classes_);
Y
Yancey1989 已提交
104 105 106
  size_t num_samples = tmat.dims()[0];
  size_t tmat_width = tmat.dims()[1];
  size_t input_width = input.dims()[1];
Y
Yancey1989 已提交
107 108 109 110
  size_t weight_width = weight.dims()[2];
  auto tmat_value = tmat.data<T>();
  auto weight_value = weight.data<T>();
  auto input_value = input.data<T>();
Y
Yancey1989 已提交
111
  for (size_t i = 0; i < num_samples; ++i) {
Y
Yancey1989 已提交
112
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
113 114 115
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);
Y
Yancey1989 已提交
116 117 118

      T sum = static_cast<T>(0.0);
      for (size_t k = 0; k < input_width; ++k) {
Y
Yancey1989 已提交
119 120
        sum += weight_value[weight_width * index + k] *
               input_value[input_width * i + k];
Y
Yancey1989 已提交
121
      }
Y
Yancey1989 已提交
122
      tmat_value[i * tmat_width + j] += sum;
Y
Yancey1989 已提交
123 124 125 126 127
    }
  }
}

template <typename T>
Y
Yancey1989 已提交
128
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
Y
Yancey1989 已提交
129 130
                                            framework::Tensor& weight,
                                            const framework::Tensor& input) {
Y
Yancey1989 已提交
131
  SimpleCodeTable code_table(num_classes_);
Y
Yancey1989 已提交
132 133 134
  size_t num_samples = tmat.dims()[0];
  size_t input_width = input.dims()[1];
  size_t weight_width = weight.dims()[1];
Y
Yancey1989 已提交
135 136 137
  auto tmat_value = tmat.data<T>();
  auto weight_value = weight.data<T>();
  auto input_value = input.data<T>();
Y
Yancey1989 已提交
138
  for (size_t i = 0; i < num_samples; ++i) {
Y
Yancey1989 已提交
139
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
140 141 142
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);
Y
Yancey1989 已提交
143

Y
Yancey1989 已提交
144
      for (size_t k = 0; k < input_width; ++k) {
Y
Yancey1989 已提交
145 146
        weight_value[weight_width * index * k] +=
            tmat_value[i * weight_width * j] * input_value[input_width * i + k];
Y
Yancey1989 已提交
147
      }
Y
Yancey1989 已提交
148
    }
Y
Yancey1989 已提交
149
  }
Y
Yancey1989 已提交
150
}
Y
Yancey1989 已提交
151 152

template <typename T>
Y
Yancey1989 已提交
153
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
Y
Yancey1989 已提交
154 155
                                           const framework::Tensor& weight,
                                           framework::Tensor& input) {
Y
Yancey1989 已提交
156
  SimpleCodeTable code_table(num_classes_);
Y
Yancey1989 已提交
157
  size_t num_samples = tmat.dims()[0];
Y
Yancey1989 已提交
158 159
  size_t input_width = input.dims()[1];
  size_t weight_width = weight.dims()[1];
Y
Yancey1989 已提交
160 161 162
  auto tmat_value = tmat.data<T>();
  auto weight_value = weight.data<T>();
  auto input_value = input.data<T>();
Y
Yancey1989 已提交
163

Y
Yancey1989 已提交
164
  for (size_t i = 0; i < num_samples; ++i) {
Y
Yancey1989 已提交
165
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
166 167
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
Y
Yancey1989 已提交
168 169 170
      size_t index = code.calc_index(j);

      for (size_t k = 0; k < input_width; ++k) {
Y
Yancey1989 已提交
171 172 173
        input_value[weight_width * index * k] +=
            tmat_value[i * weight_width * j] *
            weight_value[weight_width * i + k];
Y
Yancey1989 已提交
174 175 176 177 178 179
      }
    }
  }
}

template <typename T>
Y
Yancey1989 已提交
180 181 182 183 184 185 186 187 188 189 190 191 192
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor& tmat) {
  SimpleCodeTable code_table(num_classes_);
  size_t num_samples = tmat.dims()[0];
  size_t o_width = tmat.dims()[1];
  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      if (code.calc_bit(j)) {
        tmat.data<T>()[i * o_width + j] -= 1;
      }
    }
  }
Y
Yancey1989 已提交
193 194
}

Y
Yancey1989 已提交
195 196 197
// Explicit instantiations: the member templates above are defined in this
// translation unit, so the float and double versions are instantiated here.
template class MatrixBitCodeFunctor<float>;
template class MatrixBitCodeFunctor<double>;

Y
Yancey1989 已提交
198 199 200
}  // namespace math
}  // namespace operators
}  // namespace paddle