matrix_bit_code.cc 6.1 KB
Newer Older
Y
Yancey1989 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

W
weixing02 已提交
15 16
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include <iostream>
Y
Yancey1989 已提交
17 18 19 20
namespace paddle {
namespace operators {
namespace math {

Y
Yancey1989 已提交
21
template <typename T>
W
weixing02 已提交
22
void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
Y
Yancey1989 已提交
23 24
                                  const framework::Tensor& vec) {
  SimpleCodeTable code_table(num_classes_);
W
weixing02 已提交
25 26
  size_t batch_size = tmat->dims()[0];
  size_t width = tmat->dims()[1];
Y
Yancey1989 已提交
27 28
  for (size_t i = 0; i < batch_size; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
29
    int code_length = code.get_length();
Y
Yancey1989 已提交
30
    for (int j = 0; j < code_length; ++j) {
Y
Yancey1989 已提交
31
      size_t index = code.calc_index(j);
W
weixing02 已提交
32
      tmat->data<T>()[i * width + j] += vec.data<T>()[index];
Y
Yancey1989 已提交
33 34 35 36
    }
  }
}

Y
Yancey1989 已提交
37
template <typename T>
W
weixing02 已提交
38 39
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
                                      framework::Tensor* vec) {
Y
Yancey1989 已提交
40 41 42 43 44
  SimpleCodeTable code_table(num_classes_);
  size_t batch_size = tmat.dims()[0];
  size_t width = tmat.dims()[1];
  for (size_t i = 0; i < batch_size; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
45 46
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
Y
Yancey1989 已提交
47
      size_t index = code.calc_index(j);
W
weixing02 已提交
48
      vec->data<T>()[index] += tmat.data<T>()[i * width + j];
Y
Yancey1989 已提交
49 50
    }
  }
Y
Yancey1989 已提交
51 52
}

Y
Yancey1989 已提交
53
template <typename T>
W
weixing02 已提交
54 55
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
                                  framework::Tensor* sum, T scale_sum) {
Y
Yancey1989 已提交
56
  SimpleCodeTable code_table(num_classes_);
Y
Yancey1989 已提交
57 58 59
  size_t num_samples = tmat.dims()[0];
  size_t o_width = tmat.dims()[1];
  for (size_t i = 0; i < num_samples; ++i) {
Y
Yancey1989 已提交
60
    T sm = static_cast<T>(0.0);
Y
Yancey1989 已提交
61
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
62 63 64
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      if (code.calc_bit(j)) {
65 66
        // calc_bit starts from right most bit, while data in tmat[i] is in the
        // reverse order.
Y
Yancey1989 已提交
67 68 69
        sm += tmat.data<T>()[i * o_width + j];
      }
    }
W
weixing02 已提交
70
    sum->data<T>()[i] = scale_sum * sm;
Y
Yancey1989 已提交
71 72
  }
}
Y
Yancey1989 已提交
73

Y
Yancey1989 已提交
74
template <typename T>
W
weixing02 已提交
75
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
Y
Yancey1989 已提交
76 77
                                  const framework::Tensor& weight,
                                  const framework::Tensor& input) {
Y
Yancey1989 已提交
78
  SimpleCodeTable code_table(num_classes_);
W
weixing02 已提交
79 80
  size_t num_samples = tmat->dims()[0];
  size_t tmat_width = tmat->dims()[1];
Y
Yancey1989 已提交
81
  size_t input_width = input.dims()[1];
W
weixing02 已提交
82 83
  size_t weight_width = weight.dims()[1];
  auto tmat_value = tmat->data<T>();
Y
Yancey1989 已提交
84 85
  auto weight_value = weight.data<T>();
  auto input_value = input.data<T>();
Y
Yancey1989 已提交
86
  for (size_t i = 0; i < num_samples; ++i) {
Y
Yancey1989 已提交
87
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
88 89 90
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);
Y
Yancey1989 已提交
91 92
      T sum = static_cast<T>(0.0);
      for (size_t k = 0; k < input_width; ++k) {
Y
Yancey1989 已提交
93 94
        sum += weight_value[weight_width * index + k] *
               input_value[input_width * i + k];
Y
Yancey1989 已提交
95
      }
Y
Yancey1989 已提交
96
      tmat_value[i * tmat_width + j] += sum;
Y
Yancey1989 已提交
97 98 99 100 101
    }
  }
}

template <typename T>
Y
Yancey1989 已提交
102
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
W
weixing02 已提交
103
                                            framework::Tensor* weight,
Y
Yancey1989 已提交
104
                                            const framework::Tensor& input) {
Y
Yancey1989 已提交
105
  SimpleCodeTable code_table(num_classes_);
Y
Yancey1989 已提交
106 107
  size_t num_samples = tmat.dims()[0];
  size_t input_width = input.dims()[1];
W
weixing02 已提交
108 109
  size_t tmat_width = tmat.dims()[1];
  size_t weight_width = weight->dims()[1];
Y
Yancey1989 已提交
110
  auto tmat_value = tmat.data<T>();
W
weixing02 已提交
111
  auto weight_value = weight->data<T>();
Y
Yancey1989 已提交
112
  auto input_value = input.data<T>();
Y
Yancey1989 已提交
113
  for (size_t i = 0; i < num_samples; ++i) {
Y
Yancey1989 已提交
114
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
115 116 117
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);
Y
Yancey1989 已提交
118

Y
Yancey1989 已提交
119
      for (size_t k = 0; k < input_width; ++k) {
W
weixing02 已提交
120 121
        weight_value[weight_width * index + k] +=
            tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
Y
Yancey1989 已提交
122
      }
Y
Yancey1989 已提交
123
    }
Y
Yancey1989 已提交
124
  }
Y
Yancey1989 已提交
125
}
Y
Yancey1989 已提交
126 127

template <typename T>
Y
Yancey1989 已提交
128
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
Y
Yancey1989 已提交
129
                                           const framework::Tensor& weight,
W
weixing02 已提交
130
                                           framework::Tensor* input) {
Y
Yancey1989 已提交
131
  SimpleCodeTable code_table(num_classes_);
Y
Yancey1989 已提交
132
  size_t num_samples = tmat.dims()[0];
W
weixing02 已提交
133 134
  size_t tmat_width = tmat.dims()[1];
  size_t input_width = input->dims()[1];
Y
Yancey1989 已提交
135
  size_t weight_width = weight.dims()[1];
Y
Yancey1989 已提交
136 137
  auto tmat_value = tmat.data<T>();
  auto weight_value = weight.data<T>();
W
weixing02 已提交
138
  auto input_value = input->data<T>();
Y
Yancey1989 已提交
139

Y
Yancey1989 已提交
140
  for (size_t i = 0; i < num_samples; ++i) {
Y
Yancey1989 已提交
141
    auto code = code_table(static_cast<size_t>(ids_[i]));
Y
Yancey1989 已提交
142 143
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
Y
Yancey1989 已提交
144 145 146
      size_t index = code.calc_index(j);

      for (size_t k = 0; k < input_width; ++k) {
W
weixing02 已提交
147 148 149
        input_value[input_width * i + k] +=
            tmat_value[i * tmat_width + j] *
            weight_value[weight_width * index + k];
Y
Yancey1989 已提交
150 151 152 153 154 155
      }
    }
  }
}

template <typename T>
W
weixing02 已提交
156
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
Y
Yancey1989 已提交
157
  SimpleCodeTable code_table(num_classes_);
W
weixing02 已提交
158 159
  size_t num_samples = tmat->dims()[0];
  size_t o_width = tmat->dims()[1];
Y
Yancey1989 已提交
160 161 162 163 164
  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      if (code.calc_bit(j)) {
W
weixing02 已提交
165
        tmat->data<T>()[i * o_width + j] -= 1;
Y
Yancey1989 已提交
166 167 168
      }
    }
  }
Y
Yancey1989 已提交
169 170
}

Y
Yancey1989 已提交
171 172 173
// Explicit instantiations: the member functions above are defined in this
// translation unit, so the float and double specializations are emitted here.
template class MatrixBitCodeFunctor<float>;
template class MatrixBitCodeFunctor<double>;

Y
Yancey1989 已提交
174 175 176
}  // namespace math
}  // namespace operators
}  // namespace paddle