diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 88279f8d8a781ac3a7291572b40392cd0a7d17e0..090c0cca366074958c5189e0d203116cc36fd68d 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -119,6 +119,33 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, } } +// template +// void MatrixBitCodeFunctor::MulGradSparseWeight(const framework::Tensor& +// tmat, +// framework::SelectedRows* weight, +// const framework::Tensor& input) { +// size_t num_samples = tmat.dims()[0]; +// size_t input_width = input.dims()[1]; +// size_t tmat_width = tmat.dims()[1]; +// size_t weight_width = weight->dims()[1]; +// auto tmat_value = tmat.data(); +// auto weight_value = weight->data(); +// auto input_value = input.data(); +// for (size_t i = 0; i < num_samples; ++i) { +// auto code = code_table->get_code(i); +// int code_length = code->get_length(); +// for (int j = 0; j < code_length; ++j) { +// // size_t index = code->calc_index(j); + +// for (size_t k = 0; k < input_width; ++k) { +// weight_value[j * weight_width + k] += +// tmat_value[i * tmat_width + j] * input_value[input_width * i + +// k]; +// } +// } +// } +// } + template void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, const framework::Tensor& weight,