提交 ba9ff508 编写于 作者: J JiabinYang

temp fix

上级 a507845a
...@@ -119,6 +119,33 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat, ...@@ -119,6 +119,33 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
} }
} }
// template <typename T>
// void MatrixBitCodeFunctor<T>::MulGradSparseWeight(const framework::Tensor&
// tmat,
// framework::SelectedRows* weight,
// const framework::Tensor& input) {
// size_t num_samples = tmat.dims()[0];
// size_t input_width = input.dims()[1];
// size_t tmat_width = tmat.dims()[1];
// size_t weight_width = weight->dims()[1];
// auto tmat_value = tmat.data<T>();
// auto weight_value = weight->data<T>();
// auto input_value = input.data<T>();
// for (size_t i = 0; i < num_samples; ++i) {
// auto code = code_table->get_code(i);
// int code_length = code->get_length();
// for (int j = 0; j < code_length; ++j) {
// // size_t index = code->calc_index(j);
// for (size_t k = 0; k < input_width; ++k) {
// weight_value[j * weight_width + k] +=
// tmat_value[i * tmat_width + j] * input_value[input_width * i +
// k];
// }
// }
// }
// }
template <typename T> template <typename T>
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat, void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
const framework::Tensor& weight, const framework::Tensor& weight,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册