diff --git a/cmake/external/python.cmake b/cmake/external/python.cmake index a3599dd798c07f57ed82e3f25b6bb9fc4f8bdc3a..52ad02a3551bfe0ccf4afa5eea6ae153dfc107b4 100644 --- a/cmake/external/python.cmake +++ b/cmake/external/python.cmake @@ -18,8 +18,8 @@ ENDIF() INCLUDE(python_module) -FIND_PACKAGE(PythonInterp ${PY_VERSION}) -FIND_PACKAGE(PythonLibs ${PY_VERSION}) +FIND_PACKAGE(PythonInterp ${PY_VERSION} REQUIRED) +FIND_PACKAGE(PythonLibs ${PY_VERSION} REQUIRED) if(WIN32) execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c" @@ -79,6 +79,6 @@ IF(PYTHONINTERP_FOUND) "please use pip to upgrade protobuf. pip install -U protobuf") ENDIF() ENDIF(PYTHONINTERP_FOUND) - +message(STATUS ${PYTHON_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${PYTHON_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${PYTHON_NUMPY_INCLUDE_DIR}) diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 92affa0e4ed762cc660baf0f84fd62f13ee3de29..d55e832cc2d9a4a5e2cb7fe5cf451a1205601951 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -15,225 +15,379 @@ limitations under the License. */ #include "paddle/fluid/operators/math/matrix_bit_code.h" #include #include + namespace paddle { namespace operators { namespace math { template -void MatrixBitCodeFunctor::Add(const framework::Tensor& vec, - framework::Tensor* tmat) { - size_t batch_size = tmat->dims()[0]; - size_t width = tmat->dims()[1]; - auto* tmat_data = tmat->data(); - auto* vec_data = vec.data(); - for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - for (int j = 0; j < code_length; ++j) { - size_t index = code->calc_index(j); - tmat_data[i * width + j] += vec_data[index]; +struct MatrixBitCodeFunctorAdd : public boost::static_visitor { + const framework::Tensor &vec_; + framework::Tensor *tmat_; + + MatrixBitCodeFunctorAdd(const framework::Tensor &vec, framework::Tensor *tmat) + : vec_(vec), tmat_(tmat) {} + + template + void operator()(const CodeTable &code_table) { + size_t batch_size = tmat_->dims()[0]; + size_t width = tmat_->dims()[1]; + auto *tmat_data = tmat_->data(); + auto *vec_data = vec_.data(); + for (size_t i = 0; i < batch_size; ++i) { + auto code = code_table.get_code(i); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + tmat_data[i * width + j] += vec_data[index]; + } } } +}; + +template +void MatrixBitCodeFunctor::Add(const framework::Tensor &vec, + framework::Tensor *tmat) { + MatrixBitCodeFunctorAdd func(vec, tmat); + code_table_.apply_visitor(func); } template -void MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, - framework::Tensor* vec) { - size_t batch_size = tmat.dims()[0]; - size_t width = tmat.dims()[1]; - auto* vec_data = vec->data(); - auto* tmat_data = tmat.data(); - for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - for (int j = 0; j < code_length; ++j) { - size_t index = code->calc_index(j); - vec_data[index] += tmat_data[i * width + j]; +struct MatrixBitCodeFunctorAddGrad : public boost::static_visitor { + const framework::Tensor &tmat_; + framework::Tensor *vec_; + MatrixBitCodeFunctorAddGrad(const framework::Tensor &tmat, + framework::Tensor *vec) + : tmat_(tmat), vec_(vec) {} + + template + void operator()(const CodeTable &table) { + size_t batch_size = tmat_.dims()[0]; + size_t width = tmat_.dims()[1]; + auto *vec_data = vec_->data(); + auto *tmat_data = tmat_.data(); + for (size_t i = 0; i < batch_size; ++i) { + auto code = table.get_code(i); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + vec_data[index] += tmat_data[i * width + j]; + } } } +}; + +template +void MatrixBitCodeFunctor::AddGrad(const framework::Tensor &tmat, + framework::Tensor *vec) { + MatrixBitCodeFunctorAddGrad func(tmat, vec); + code_table_.apply_visitor(func); } template -void MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, - framework::SelectedRows* vec) { - size_t batch_size = tmat.dims()[0]; - size_t width = tmat.dims()[1]; - auto* vec_data = vec->mutable_value()->data(); - auto* tmat_data = tmat.data(); - for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - for (int j = 0; j < code_length; ++j) { - size_t index = code->calc_index(j); - int64_t row_index = vec->GetIndexFromId(static_cast(index)); - vec_data[row_index] += tmat_data[i * width + j]; +struct MatrixBitCodeFunctorSelectedRowsAddGrad + : public boost::static_visitor { + const framework::Tensor &tmat_; + framework::SelectedRows *vec_; + + MatrixBitCodeFunctorSelectedRowsAddGrad(const framework::Tensor &tmat, + framework::SelectedRows *vec) + : tmat_(tmat), vec_(vec) {} + + template + void operator()(const CodeTable &code_table) { + size_t batch_size = tmat_.dims()[0]; + size_t width = tmat_.dims()[1]; + auto *vec_data = vec_->mutable_value()->template data(); + auto *tmat_data = tmat_.data(); + for (size_t i = 0; i < batch_size; ++i) { + auto code = code_table.get_code(i); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + int64_t row_index = vec_->GetIndexFromId(static_cast(index)); + vec_data[row_index] += tmat_data[i * width + j]; + } } } +}; + +template +void MatrixBitCodeFunctor::AddGrad(const framework::Tensor &tmat, + framework::SelectedRows *vec) { + MatrixBitCodeFunctorSelectedRowsAddGrad func(tmat, vec); + code_table_.apply_visitor(func); } template -void MatrixBitCodeFunctor::Sum(const framework::Tensor& tmat, - framework::Tensor* sum, T scale_sum) { - size_t num_samples = tmat.dims()[0]; - size_t o_width = tmat.dims()[1]; - auto* tmat_data = tmat.data(); - auto* sum_data = sum->data(); - for (size_t i = 0; i < num_samples; ++i) { - T sm = static_cast(0.0); - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - for (int j = 0; j < code_length; ++j) { - if (code->calc_bit(j)) { - // calc_bit starts from right most bit, while data in tmat[i] is in the - // reverse order. - sm += tmat_data[i * o_width + j]; +struct MatrixBitCodeFunctorSum : public boost::static_visitor { + const framework::Tensor &tmat_; + framework::Tensor *sum_; + T scale_sum_; + + MatrixBitCodeFunctorSum(const framework::Tensor &tmat, framework::Tensor *sum, + T scale_sum) + : tmat_(tmat), sum_(sum), scale_sum_(scale_sum) {} + + template + void operator()(const CodeTable &code_table) { + size_t num_samples = tmat_.dims()[0]; + size_t o_width = tmat_.dims()[1]; + auto *tmat_data = tmat_.data(); + auto *sum_data = sum_->data(); + for (size_t i = 0; i < num_samples; ++i) { + T sm = static_cast(0.0); + auto code = code_table.get_code(i); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + if (code.calc_bit(j)) { + // calc_bit starts from right most bit, while data in tmat[i] is in + // the + // reverse order. + sm += tmat_data[i * o_width + j]; + } } + sum_data[i] = scale_sum_ * sm; } - sum_data[i] = scale_sum * sm; } +}; + +template +void MatrixBitCodeFunctor::Sum(const framework::Tensor &tmat, + framework::Tensor *sum, T scale_sum) { + MatrixBitCodeFunctorSum func(tmat, sum, scale_sum); + code_table_.apply_visitor(func); } template -void MatrixBitCodeFunctor::Mul(framework::Tensor* tmat, - const framework::Tensor& weight, - const framework::Tensor& input) { - auto blas = - GetBlas(platform::CPUDeviceContext()); - size_t num_samples = tmat->dims()[0]; - size_t tmat_width = tmat->dims()[1]; - size_t input_width = input.dims()[1]; - size_t weight_width = weight.dims()[1]; - auto tmat_value = tmat->data(); - auto weight_value = weight.data(); - auto input_value = input.data(); - for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - const T* input_row = input_value + input_width * i; - for (int j = 0; j < code_length; ++j) { - size_t index = code->calc_index(j); - const T* weight_row = weight_value + weight_width * index; - T sum = static_cast(0.0); - sum = blas.DOT(input_width, weight_row, input_row); - tmat_value[i * tmat_width + j] += sum; +struct MatrixBitCodeFunctorMul : public boost::static_visitor { + framework::Tensor *tmat_; + const framework::Tensor &weight_; + const framework::Tensor &input_; + + MatrixBitCodeFunctorMul(framework::Tensor *tmat, + const framework::Tensor &weight, + const framework::Tensor &input) + : tmat_(tmat), weight_(weight), input_(input) {} + + template + void operator()(const CodeTable &code_table) { + auto blas = + GetBlas(platform::CPUDeviceContext()); + size_t num_samples = tmat_->dims()[0]; + size_t tmat_width = tmat_->dims()[1]; + size_t input_width = input_.dims()[1]; + size_t weight_width = weight_.dims()[1]; + auto tmat_value = tmat_->data(); + auto weight_value = weight_.data(); + auto input_value = input_.data(); + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table.get_code(i); + int code_length = code.get_length(); + const T *input_row = input_value + input_width * i; + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + const T *weight_row = weight_value + weight_width * index; + T sum = blas.DOT(input_width, weight_row, input_row); + tmat_value[i * tmat_width + j] += sum; + } } } +}; + +template +void MatrixBitCodeFunctor::Mul(framework::Tensor *tmat, + const framework::Tensor &weight, + const framework::Tensor &input) { + MatrixBitCodeFunctorMul func(tmat, weight, input); + code_table_.apply_visitor(func); } +template +class ReservedVector : public std::vector { + public: + ReservedVector() { this->reserve(N); } +}; + template -void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, - framework::Tensor* weight, - const framework::Tensor& input) { - auto blas = - GetBlas(platform::CPUDeviceContext()); - size_t num_samples = tmat.dims()[0]; - size_t input_width = input.dims()[1]; - size_t tmat_width = tmat.dims()[1]; - size_t weight_width = weight->dims()[1]; - auto tmat_value = tmat.data(); - auto weight_value = weight->data(); - auto input_value = input.data(); - - std::map>> ops; - for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - const T* input_value_row = input_value + input_width * i; - const T* tmat_row = tmat_value + i * tmat_width; - for (int j = 0; j < code_length; ++j) { - ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row); +struct MatrixBitCodeFunctorMulGradWeight : public boost::static_visitor { + const framework::Tensor &tmat_; + framework::Tensor *weight_; + const framework::Tensor &input_; + MatrixBitCodeFunctorMulGradWeight(const framework::Tensor &tmat, + framework::Tensor *weight, + const framework::Tensor &input) + : tmat_(tmat), weight_(weight), input_(input) {} + template + void operator()(const CodeTable &code_table) { + auto blas = + GetBlas(platform::CPUDeviceContext()); + size_t num_samples = tmat_.dims()[0]; + size_t input_width = input_.dims()[1]; + size_t tmat_width = tmat_.dims()[1]; + size_t weight_width = weight_->dims()[1]; + auto tmat_value = tmat_.data(); + auto weight_value = weight_->data(); + auto input_value = input_.data(); + + std::map, 8u>> ops; + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table.get_code(i); + int code_length = code.get_length(); + const T *input_value_row = input_value + input_width * i; + const T *tmat_row = tmat_value + i * tmat_width; + for (int j = 0; j < code_length; ++j) { + ops[code.calc_index(j)].emplace_back(tmat_row[j], input_value_row); + } } - } - for (auto& op : ops) { - auto& op_in_row = op.second; - for (auto& pair : op_in_row) { - auto& scale = pair.first; - auto* input_row = pair.second; - T* weight_row = weight_value + op.first * weight_width; - blas.AXPY(input_width, scale, input_row, weight_row); + for (auto &op : ops) { + auto &op_in_row = op.second; + for (auto &pair : op_in_row) { + auto &scale = pair.first; + auto *input_row = pair.second; + T *weight_row = weight_value + op.first * weight_width; + blas.AXPY(input_width, scale, input_row, weight_row); + } } } +}; + +template +void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor &tmat, + framework::Tensor *weight, + const framework::Tensor &input) { + MatrixBitCodeFunctorMulGradWeight func(tmat, weight, input); + code_table_.apply_visitor(func); } template -void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, - framework::SelectedRows* weight, - const framework::Tensor& input) { - auto blas = - GetBlas(platform::CPUDeviceContext()); - size_t num_samples = tmat.dims()[0]; - size_t input_width = input.dims()[1]; - size_t tmat_width = tmat.dims()[1]; - size_t weight_width = weight->value().dims()[1]; - auto tmat_value = tmat.data(); - auto weight_value = weight->mutable_value()->data(); - auto input_value = input.data(); - - std::unordered_map>> ops; - ops.reserve(weight->rows().size()); - - for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - const T* input_value_row = input_value + input_width * i; - const T* tmat_row = tmat_value + i * tmat_width; - for (int j = 0; j < code_length; ++j) { - ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row); +struct MatrixBitCodeFunctorMulGradWeightSR + : public boost::static_visitor { + const framework::Tensor &tmat_; + framework::SelectedRows *weight_; + const framework::Tensor &input_; + + MatrixBitCodeFunctorMulGradWeightSR(const framework::Tensor &tmat, + framework::SelectedRows *weight, + const framework::Tensor &input) + : tmat_(tmat), weight_(weight), input_(input) {} + + template + void operator()(const CodeTable &code_table) { + auto blas = + GetBlas(platform::CPUDeviceContext()); + size_t num_samples = tmat_.dims()[0]; + size_t input_width = input_.dims()[1]; + size_t tmat_width = tmat_.dims()[1]; + size_t weight_width = weight_->value().dims()[1]; + auto tmat_value = tmat_.data(); + auto weight_value = weight_->mutable_value()->data(); + auto input_value = input_.data(); + + std::unordered_map>> ops; + ops.reserve(weight_->rows().size()); + + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table.get_code(i); + int code_length = code.get_length(); + const T *input_value_row = input_value + input_width * i; + const T *tmat_row = tmat_value + i * tmat_width; + for (int j = 0; j < code_length; ++j) { + ops[code.calc_index(j)].emplace_back(tmat_row[j], input_value_row); + } } - } - for (auto& row : weight->rows()) { - auto& op_in_row = ops[row]; - for (auto& pair : op_in_row) { - auto& scale = pair.first; - auto* input_row = pair.second; - blas.AXPY(input_width, scale, input_row, weight_value); + for (auto &row : weight_->rows()) { + auto &op_in_row = ops[row]; + for (auto &pair : op_in_row) { + auto &scale = pair.first; + auto *input_row = pair.second; + blas.AXPY(input_width, scale, input_row, weight_value); + } + weight_value += weight_width; } - weight_value += weight_width; } +}; + +template +void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor &tmat, + framework::SelectedRows *weight, + const framework::Tensor &input) { + MatrixBitCodeFunctorMulGradWeightSR func(tmat, weight, input); + code_table_.apply_visitor(func); } template -void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, - const framework::Tensor& weight, - framework::Tensor* input) { - size_t num_samples = tmat.dims()[0]; - size_t tmat_width = tmat.dims()[1]; - size_t input_width = input->dims()[1]; - size_t weight_width = weight.dims()[1]; - auto tmat_value = tmat.data(); - auto weight_value = weight.data(); - auto input_value = input->data(); - - for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - for (int j = 0; j < code_length; ++j) { - size_t index = code->calc_index(j); - - for (size_t k = 0; k < input_width; ++k) { - input_value[input_width * i + k] += - tmat_value[i * tmat_width + j] * - weight_value[weight_width * index + k]; +struct MatrixBitCodeFunctorMulGradError : public boost::static_visitor { + const framework::Tensor &tmat_; + const framework::Tensor &weight_; + framework::Tensor *input_; + + MatrixBitCodeFunctorMulGradError(const framework::Tensor &tmat, + const framework::Tensor &weight, + framework::Tensor *input) + : tmat_(tmat), weight_(weight), input_(input) {} + template + void operator()(const CodeTable &code_table) { + size_t num_samples = tmat_.dims()[0]; + size_t tmat_width = tmat_.dims()[1]; + size_t input_width = input_->dims()[1]; + size_t weight_width = weight_.dims()[1]; + auto tmat_value = tmat_.data(); + auto weight_value = weight_.data(); + auto input_value = input_->data(); + + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table.get_code(i); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + + for (size_t k = 0; k < input_width; ++k) { + input_value[input_width * i + k] += + tmat_value[i * tmat_width + j] * + weight_value[weight_width * index + k]; + } } } } +}; + +template +void MatrixBitCodeFunctor::MulGradError(const framework::Tensor &tmat, + const framework::Tensor &weight, + framework::Tensor *input) { + MatrixBitCodeFunctorMulGradError func(tmat, weight, input); + code_table_.apply_visitor(func); } template -void MatrixBitCodeFunctor::Sub(framework::Tensor* tmat) { - size_t num_samples = tmat->dims()[0]; - size_t o_width = tmat->dims()[1]; - auto* tmat_data = tmat->data(); - for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - for (int j = 0; j < code_length; ++j) { - if (code->calc_bit(j)) { - tmat_data[i * o_width + j] -= 1; +struct MatrixBitCodeFunctorSub : public boost::static_visitor { + framework::Tensor *tmat_; + + explicit MatrixBitCodeFunctorSub(framework::Tensor *tmat) : tmat_(tmat) {} + + template + void operator()(const CodeTable &code_table) { + size_t num_samples = tmat_->dims()[0]; + size_t o_width = tmat_->dims()[1]; + auto *tmat_data = tmat_->data(); + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table.get_code(i); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + if (code.calc_bit(j)) { + tmat_data[i * o_width + j] -= 1; + } } } } +}; + +template +void MatrixBitCodeFunctor::Sub(framework::Tensor *tmat) { + MatrixBitCodeFunctorSub func(tmat); + code_table_.apply_visitor(func); } template class MatrixBitCodeFunctor; diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h index cf43ad9d449430749086cc9b13e246d550c3661e..01e4889d34ad6e409f1b8a9c4bf783800187e863 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/variant.h" #if defined(_WIN32) #include @@ -99,24 +100,7 @@ inline int clz(const T& value) { inline size_t FindLastSet(size_t x) { return sizeof(size_t) * 8 - clz(x); } #endif // !_WIN32 -// set a code interface to create multiple code -class Code { - public: - virtual ~Code() {} - virtual size_t calc_index(int bit) const = 0; - virtual bool calc_bit(int bit) const = 0; - virtual int get_length() const = 0; -}; -// set a CodeTable interface to create multiple code table -class CodeTable { - public: - virtual Code* get_code(int64_t code) const = 0; - virtual size_t size() const = 0; - virtual int get_max_code_length() const = 0; - virtual ~CodeTable() {} -}; - -class SimpleCode : public Code { +class SimpleCode { public: SimpleCode(size_t code, size_t num_classes, const int64_t* ids) : c_(static_cast(ids[code]) + num_classes) {} @@ -138,7 +122,7 @@ class SimpleCode : public Code { }; template -class CustomCode : public Code { +class CustomCode { public: CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode, const int64_t* ids, int index) { @@ -155,11 +139,11 @@ class CustomCode : public Code { * Binary classification path is the suffixes of encoding, thus leave out the * left most bit in calc_bit. */ - size_t calc_index(int bit) const override { return ptable_data_[bit]; } - bool calc_bit(int bit) const override { return pcode_data_[bit]; } + size_t calc_index(int bit) const { return ptable_data_[bit]; } + bool calc_bit(int bit) const { return pcode_data_[bit]; } // NOTE: this function is not thread-safe. - int get_length() const override { + int get_length() const { if (length_ < 0) { auto len = seq_len_; length_ = @@ -177,46 +161,32 @@ class CustomCode : public Code { mutable int length_{-1}; }; -class SimpleCodeTable : public CodeTable { +class SimpleCodeTable { public: SimpleCodeTable(size_t num_classes, const int64_t* ids) : num_classes_(num_classes), ids_(ids) {} - Code* get_code(int64_t code) const { - auto it = codes_.find(code); - if (it != codes_.end()) { - return it->second.get(); - } - auto* result = new SimpleCode(code, num_classes_, ids_); - codes_.emplace(code, std::unique_ptr(result)); - return result; + SimpleCode get_code(int64_t code) const { + return SimpleCode(code, num_classes_, ids_); } size_t size() const { return num_classes_; } int get_max_code_length() const { return FindLastSet(num_classes_ - 1); } private: - mutable std::map> codes_; - size_t num_classes_; const int64_t* ids_; }; template -class CustomCodeTable : public CodeTable { +class CustomCodeTable { public: CustomCodeTable(const framework::Tensor& ptable, const framework::Tensor& pcode, const int64_t* ids) : ptable_(ptable), pcode_(pcode), ids_(ids) {} - Code* get_code(int64_t code) const { - auto it = codes_.find(code); - if (it != codes_.end()) { - return it->second.get(); - } - auto* result = new CustomCode(ptable_, pcode_, ids_, code); - codes_.emplace(code, std::unique_ptr(result)); - return result; + CustomCode get_code(int64_t code) const { + return CustomCode(ptable_, pcode_, ids_, code); } size_t size() const { return static_cast(ptable_.dims()[1]); } @@ -225,25 +195,26 @@ class CustomCodeTable : public CodeTable { } private: - mutable std::unordered_map> codes_; const framework::Tensor& ptable_; const framework::Tensor& pcode_; const int64_t* ids_; }; +using CodeTable = boost::variant>; + template class MatrixBitCodeFunctor { public: MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids) : num_classes_(num_classes), ids_(ids), - code_table_(new SimpleCodeTable(num_classes, ids)) {} + code_table_(SimpleCodeTable(num_classes, ids)) {} MatrixBitCodeFunctor(const framework::Tensor& ptable, const framework::Tensor& pcode, const int64_t* ids) : num_classes_(static_cast(ptable.dims()[1])), ids_(ids), - code_table_(new CustomCodeTable(ptable, pcode, ids)) {} + code_table_(CustomCodeTable(ptable, pcode, ids)) {} /* For j < code_length tmat(i, j) += vec(0, index(i, j)) */ @@ -293,7 +264,7 @@ class MatrixBitCodeFunctor { size_t num_classes_; const int64_t* ids_; - std::unique_ptr code_table_; + CodeTable code_table_; }; } // namespace math } // namespace operators