matrix_bit_code.cc 6.0 KB
Newer Older
Y
Yancey1989 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

W
weixing02 已提交
15 16
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include <iostream>
Y
Yancey1989 已提交
17 18 19 20
namespace paddle {
namespace operators {
namespace math {

Y
Yancey1989 已提交
21
template <typename T>
W
weixing02 已提交
22
void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
Y
Yancey1989 已提交
23 24
                                  const framework::Tensor& vec) {
  SimpleCodeTable code_table(num_classes_);
W
weixing02 已提交
25 26
  size_t batch_size = tmat->dims()[0];
  size_t width = tmat->dims()[1];
Y
Yancey1989 已提交
27 28
  for (size_t i = 0; i < batch_size; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
29
    int code_length = code.get_length();
Y
Yancey1989 已提交
30
    for (int j = 0; j < code_length; ++j) {
Y
Yancey1989 已提交
31
      size_t index = code.calc_index(j);
W
weixing02 已提交
32
      tmat->data<T>()[i * width + j] += vec.data<T>()[index];
Y
Yancey1989 已提交
33 34 35 36
    }
  }
}

Y
Yancey1989 已提交
37
template <typename T>
W
weixing02 已提交
38 39
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
                                      framework::Tensor* vec) {
Y
Yancey1989 已提交
40 41 42 43 44
  SimpleCodeTable code_table(num_classes_);
  size_t batch_size = tmat.dims()[0];
  size_t width = tmat.dims()[1];
  for (size_t i = 0; i < batch_size; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
45 46
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
Y
Yancey1989 已提交
47
      size_t index = code.calc_index(j);
W
weixing02 已提交
48
      vec->data<T>()[index] += tmat.data<T>()[i * width + j];
Y
Yancey1989 已提交
49 50
    }
  }
Y
Yancey1989 已提交
51 52
}

Y
Yancey1989 已提交
53
template <typename T>
W
weixing02 已提交
54 55
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
                                  framework::Tensor* sum, T scale_sum) {
Y
Yancey1989 已提交
56
  SimpleCodeTable code_table(num_classes_);
Y
Yancey1989 已提交
57 58 59
  size_t num_samples = tmat.dims()[0];
  size_t o_width = tmat.dims()[1];
  for (size_t i = 0; i < num_samples; ++i) {
Y
Yancey1989 已提交
60
    T sm = static_cast<T>(0.0);
Y
Yancey1989 已提交
61
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
62 63 64 65 66 67
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      if (code.calc_bit(j)) {
        sm += tmat.data<T>()[i * o_width + j];
      }
    }
W
weixing02 已提交
68
    sum->data<T>()[i] = scale_sum * sm;
Y
Yancey1989 已提交
69 70
  }
}
Y
Yancey1989 已提交
71

Y
Yancey1989 已提交
72
template <typename T>
W
weixing02 已提交
73
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
Y
Yancey1989 已提交
74 75
                                  const framework::Tensor& weight,
                                  const framework::Tensor& input) {
Y
Yancey1989 已提交
76
  SimpleCodeTable code_table(num_classes_);
W
weixing02 已提交
77 78
  size_t num_samples = tmat->dims()[0];
  size_t tmat_width = tmat->dims()[1];
Y
Yancey1989 已提交
79
  size_t input_width = input.dims()[1];
W
weixing02 已提交
80 81
  size_t weight_width = weight.dims()[1];
  auto tmat_value = tmat->data<T>();
Y
Yancey1989 已提交
82 83
  auto weight_value = weight.data<T>();
  auto input_value = input.data<T>();
Y
Yancey1989 已提交
84
  for (size_t i = 0; i < num_samples; ++i) {
Y
Yancey1989 已提交
85
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
86 87 88
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);
Y
Yancey1989 已提交
89 90
      T sum = static_cast<T>(0.0);
      for (size_t k = 0; k < input_width; ++k) {
Y
Yancey1989 已提交
91 92
        sum += weight_value[weight_width * index + k] *
               input_value[input_width * i + k];
Y
Yancey1989 已提交
93
      }
Y
Yancey1989 已提交
94
      tmat_value[i * tmat_width + j] += sum;
Y
Yancey1989 已提交
95 96 97 98 99
    }
  }
}

template <typename T>
Y
Yancey1989 已提交
100
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
W
weixing02 已提交
101
                                            framework::Tensor* weight,
Y
Yancey1989 已提交
102
                                            const framework::Tensor& input) {
Y
Yancey1989 已提交
103
  SimpleCodeTable code_table(num_classes_);
Y
Yancey1989 已提交
104 105
  size_t num_samples = tmat.dims()[0];
  size_t input_width = input.dims()[1];
W
weixing02 已提交
106 107
  size_t tmat_width = tmat.dims()[1];
  size_t weight_width = weight->dims()[1];
Y
Yancey1989 已提交
108
  auto tmat_value = tmat.data<T>();
W
weixing02 已提交
109
  auto weight_value = weight->data<T>();
Y
Yancey1989 已提交
110
  auto input_value = input.data<T>();
Y
Yancey1989 已提交
111
  for (size_t i = 0; i < num_samples; ++i) {
Y
Yancey1989 已提交
112
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
113 114 115
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);
Y
Yancey1989 已提交
116

Y
Yancey1989 已提交
117
      for (size_t k = 0; k < input_width; ++k) {
W
weixing02 已提交
118 119
        weight_value[weight_width * index + k] +=
            tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
Y
Yancey1989 已提交
120
      }
Y
Yancey1989 已提交
121
    }
Y
Yancey1989 已提交
122
  }
Y
Yancey1989 已提交
123
}
Y
Yancey1989 已提交
124 125

template <typename T>
Y
Yancey1989 已提交
126
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
Y
Yancey1989 已提交
127
                                           const framework::Tensor& weight,
W
weixing02 已提交
128
                                           framework::Tensor* input) {
Y
Yancey1989 已提交
129
  SimpleCodeTable code_table(num_classes_);
Y
Yancey1989 已提交
130
  size_t num_samples = tmat.dims()[0];
W
weixing02 已提交
131 132
  size_t tmat_width = tmat.dims()[1];
  size_t input_width = input->dims()[1];
Y
Yancey1989 已提交
133
  size_t weight_width = weight.dims()[1];
Y
Yancey1989 已提交
134 135
  auto tmat_value = tmat.data<T>();
  auto weight_value = weight.data<T>();
W
weixing02 已提交
136
  auto input_value = input->data<T>();
Y
Yancey1989 已提交
137

Y
Yancey1989 已提交
138
  for (size_t i = 0; i < num_samples; ++i) {
Y
Yancey1989 已提交
139
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
140 141
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
Y
Yancey1989 已提交
142 143 144
      size_t index = code.calc_index(j);

      for (size_t k = 0; k < input_width; ++k) {
W
weixing02 已提交
145 146 147
        input_value[input_width * i + k] +=
            tmat_value[i * tmat_width + j] *
            weight_value[weight_width * index + k];
Y
Yancey1989 已提交
148 149 150 151 152 153
      }
    }
  }
}

template <typename T>
W
weixing02 已提交
154
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
Y
Yancey1989 已提交
155
  SimpleCodeTable code_table(num_classes_);
W
weixing02 已提交
156 157
  size_t num_samples = tmat->dims()[0];
  size_t o_width = tmat->dims()[1];
Y
Yancey1989 已提交
158 159 160 161 162
  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      if (code.calc_bit(j)) {
W
weixing02 已提交
163
        tmat->data<T>()[i * o_width + j] -= 1;
Y
Yancey1989 已提交
164 165 166
      }
    }
  }
Y
Yancey1989 已提交
167 168
}

Y
Yancey1989 已提交
169 170 171
// Explicit instantiations: the member definitions above live in this .cc
// file, so this translation unit emits the float and double versions of
// MatrixBitCodeFunctor for other translation units to link against.
template class MatrixBitCodeFunctor<float>;
template class MatrixBitCodeFunctor<double>;

Y
Yancey1989 已提交
172 173 174
}  // namespace math
}  // namespace operators
}  // namespace paddle