diff --git a/cmake/external/python.cmake b/cmake/external/python.cmake index a3599dd798c07f57ed82e3f25b6bb9fc4f8bdc3a..623c53f4f75bbd217c157bcdda0cb12c510269ee 100644 --- a/cmake/external/python.cmake +++ b/cmake/external/python.cmake @@ -18,8 +18,8 @@ ENDIF() INCLUDE(python_module) -FIND_PACKAGE(PythonInterp ${PY_VERSION}) -FIND_PACKAGE(PythonLibs ${PY_VERSION}) +FIND_PACKAGE(PythonInterp ${PY_VERSION} REQUIRED) +FIND_PACKAGE(PythonLibs ${PY_VERSION} REQUIRED) if(WIN32) execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c" @@ -79,6 +79,5 @@ IF(PYTHONINTERP_FOUND) "please use pip to upgrade protobuf. pip install -U protobuf") ENDIF() ENDIF(PYTHONINTERP_FOUND) - INCLUDE_DIRECTORIES(${PYTHON_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${PYTHON_NUMPY_INCLUDE_DIR}) diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index b73a32af89e882ac02623dd1d312f400a78fc47a..d212e6f8437e69e71c010b6af27a33ff5e39e1e1 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -150,19 +150,27 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { label.data())); } - auto& place = *ctx.template device_context().eigen_device(); - auto pre_out_mat = EigenMatrix::From(pre_out); - auto pre_out_grad_mat = EigenMatrix::From(pre_out_grad); - auto out_grad_mat = EigenMatrix::From(out_grad); + // softrelu derivative - Eigen::array bcast{1, static_cast(pre_out_grad.dims()[1])}; + auto blas = math::GetBlas(ctx); - // softrelu derivative - pre_out_grad_mat.device(place) = - static_cast(1.0) - static_cast(1.0) / pre_out_mat.exp(); + auto* pre_out_grad_data = pre_out_grad.data(); + auto* pre_out_data = pre_out.data(); + auto n = pre_out.numel(); + blas.VEXP(n, pre_out_data, pre_out_grad_data); + blas.VINV(n, pre_out_grad_data, pre_out_grad_data); + for (int64_t i = 0; i < n; ++i) { + pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i]; + } bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b) - pre_out_grad_mat.device(place) = - pre_out_grad_mat * out_grad_mat.broadcast(bcast); + auto* out_grad_data = out_grad.data(); + + int64_t dim0 = pre_out_grad.dims()[0]; + int64_t dim1 = pre_out_grad.dims()[1]; + for (int64_t i = 0; i < dim0; ++i) { + T tmp = out_grad_data[i]; + blas.SCAL(dim1, tmp, pre_out_grad_data + i * dim1); + } // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to // be consistent with the clipping in forward. diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 9f3a81f22cc52bef719f472e43f91bc81dfe2af6..f67f57827bc03e134bf87edd5bf033adb5098916 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -181,6 +181,9 @@ class Blas { const framework::Tensor& mat_b, const MatDescriptor& dim_b, T alpha, framework::Tensor* mat_out, T beta) const; + template + void VINV(int n, const T* a, T* y) const; + private: const DeviceContext& context_; }; @@ -282,6 +285,11 @@ class BlasT : private Blas { Base()->template BatchedGEMM(args...); } + template + void VINV(ARGS... args) const { + Base()->template VINV(args...); + } + private: const Blas* Base() const { return static_cast*>(this); diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index c84087bb1e4849b27d53e05f046c93f631150f6f..972366bc093f4b7f0a090cf31213f75ccd89fd82 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -118,6 +118,11 @@ struct CBlas { static void VPOW(ARGS... args) { platform::dynload::vsPowx(args...); } + + template + static void VINV(ARGS... args) { + platform::dynload::vsInv(args...); + } }; template <> @@ -213,6 +218,11 @@ struct CBlas { static void VPOW(ARGS... args) { platform::dynload::vdPowx(args...); } + + template + static void VINV(ARGS... args) { + platform::dynload::vdInv(args...); + } }; #else @@ -603,6 +613,17 @@ void Blas::MatMul(const framework::Tensor &mat_a, dim_a.stride_, dim_b.stride_); } } +template +template +void Blas::VINV(int n, const T *a, T *y) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VINV(n, a, y); +#else + for (int i = 0; i < n; ++i) { + y[i] = 1.0 / a[i]; + } +#endif +} } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 5a6e64b6f87d33249f0153e5f391deaf78e53de5..d55e832cc2d9a4a5e2cb7fe5cf451a1205601951 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -14,218 +14,380 @@ limitations under the License. */ #include "paddle/fluid/operators/math/matrix_bit_code.h" #include +#include + namespace paddle { namespace operators { namespace math { template -void MatrixBitCodeFunctor::Add(const framework::Tensor& vec, - framework::Tensor* tmat) { - size_t batch_size = tmat->dims()[0]; - size_t width = tmat->dims()[1]; - for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - for (int j = 0; j < code_length; ++j) { - size_t index = code->calc_index(j); - tmat->data()[i * width + j] += vec.data()[index]; +struct MatrixBitCodeFunctorAdd : public boost::static_visitor { + const framework::Tensor &vec_; + framework::Tensor *tmat_; + + MatrixBitCodeFunctorAdd(const framework::Tensor &vec, framework::Tensor *tmat) + : vec_(vec), tmat_(tmat) {} + + template + void operator()(const CodeTable &code_table) { + size_t batch_size = tmat_->dims()[0]; + size_t width = tmat_->dims()[1]; + auto *tmat_data = tmat_->data(); + auto *vec_data = vec_.data(); + for (size_t i = 0; i < batch_size; ++i) { + auto code = code_table.get_code(i); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + tmat_data[i * width + j] += vec_data[index]; + } } } +}; + +template +void MatrixBitCodeFunctor::Add(const framework::Tensor &vec, + framework::Tensor *tmat) { + MatrixBitCodeFunctorAdd func(vec, tmat); + code_table_.apply_visitor(func); } template -void MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, - framework::Tensor* vec) { - size_t batch_size = tmat.dims()[0]; - size_t width = tmat.dims()[1]; - for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - for (int j = 0; j < code_length; ++j) { - size_t index = code->calc_index(j); - vec->data()[index] += tmat.data()[i * width + j]; +struct MatrixBitCodeFunctorAddGrad : public boost::static_visitor { + const framework::Tensor &tmat_; + framework::Tensor *vec_; + MatrixBitCodeFunctorAddGrad(const framework::Tensor &tmat, + framework::Tensor *vec) + : tmat_(tmat), vec_(vec) {} + + template + void operator()(const CodeTable &table) { + size_t batch_size = tmat_.dims()[0]; + size_t width = tmat_.dims()[1]; + auto *vec_data = vec_->data(); + auto *tmat_data = tmat_.data(); + for (size_t i = 0; i < batch_size; ++i) { + auto code = table.get_code(i); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + vec_data[index] += tmat_data[i * width + j]; + } } } +}; + +template +void MatrixBitCodeFunctor::AddGrad(const framework::Tensor &tmat, + framework::Tensor *vec) { + MatrixBitCodeFunctorAddGrad func(tmat, vec); + code_table_.apply_visitor(func); } template -void MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, - framework::SelectedRows* vec) { - size_t batch_size = tmat.dims()[0]; - size_t width = tmat.dims()[1]; - for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - for (int j = 0; j < code_length; ++j) { - size_t index = code->calc_index(j); - int64_t row_index = vec->GetIndexFromId(static_cast(index)); - vec->mutable_value()->data()[row_index] += - tmat.data()[i * width + j]; +struct MatrixBitCodeFunctorSelectedRowsAddGrad + : public boost::static_visitor { + const framework::Tensor &tmat_; + framework::SelectedRows *vec_; + + MatrixBitCodeFunctorSelectedRowsAddGrad(const framework::Tensor &tmat, + framework::SelectedRows *vec) + : tmat_(tmat), vec_(vec) {} + + template + void operator()(const CodeTable &code_table) { + size_t batch_size = tmat_.dims()[0]; + size_t width = tmat_.dims()[1]; + auto *vec_data = vec_->mutable_value()->template data(); + auto *tmat_data = tmat_.data(); + for (size_t i = 0; i < batch_size; ++i) { + auto code = code_table.get_code(i); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + int64_t row_index = vec_->GetIndexFromId(static_cast(index)); + vec_data[row_index] += tmat_data[i * width + j]; + } } } +}; + +template +void MatrixBitCodeFunctor::AddGrad(const framework::Tensor &tmat, + framework::SelectedRows *vec) { + MatrixBitCodeFunctorSelectedRowsAddGrad func(tmat, vec); + code_table_.apply_visitor(func); } template -void MatrixBitCodeFunctor::Sum(const framework::Tensor& tmat, - framework::Tensor* sum, T scale_sum) { - size_t num_samples = tmat.dims()[0]; - size_t o_width = tmat.dims()[1]; - for (size_t i = 0; i < num_samples; ++i) { - T sm = static_cast(0.0); - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - for (int j = 0; j < code_length; ++j) { - if (code->calc_bit(j)) { - // calc_bit starts from right most bit, while data in tmat[i] is in the - // reverse order. - sm += tmat.data()[i * o_width + j]; +struct MatrixBitCodeFunctorSum : public boost::static_visitor { + const framework::Tensor &tmat_; + framework::Tensor *sum_; + T scale_sum_; + + MatrixBitCodeFunctorSum(const framework::Tensor &tmat, framework::Tensor *sum, + T scale_sum) + : tmat_(tmat), sum_(sum), scale_sum_(scale_sum) {} + + template + void operator()(const CodeTable &code_table) { + size_t num_samples = tmat_.dims()[0]; + size_t o_width = tmat_.dims()[1]; + auto *tmat_data = tmat_.data(); + auto *sum_data = sum_->data(); + for (size_t i = 0; i < num_samples; ++i) { + T sm = static_cast(0.0); + auto code = code_table.get_code(i); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + if (code.calc_bit(j)) { + // calc_bit starts from right most bit, while data in tmat[i] is in + // the + // reverse order. + sm += tmat_data[i * o_width + j]; + } } + sum_data[i] = scale_sum_ * sm; } - sum->data()[i] = scale_sum * sm; } +}; + +template +void MatrixBitCodeFunctor::Sum(const framework::Tensor &tmat, + framework::Tensor *sum, T scale_sum) { + MatrixBitCodeFunctorSum func(tmat, sum, scale_sum); + code_table_.apply_visitor(func); } template -void MatrixBitCodeFunctor::Mul(framework::Tensor* tmat, - const framework::Tensor& weight, - const framework::Tensor& input) { - auto blas = - GetBlas(platform::CPUDeviceContext()); - size_t num_samples = tmat->dims()[0]; - size_t tmat_width = tmat->dims()[1]; - size_t input_width = input.dims()[1]; - size_t weight_width = weight.dims()[1]; - auto tmat_value = tmat->data(); - auto weight_value = weight.data(); - auto input_value = input.data(); - for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - const T* input_row = input_value + input_width * i; - for (int j = 0; j < code_length; ++j) { - size_t index = code->calc_index(j); - const T* weight_row = weight_value + weight_width * index; - T sum = static_cast(0.0); - sum = blas.DOT(input_width, weight_row, input_row); - tmat_value[i * tmat_width + j] += sum; +struct MatrixBitCodeFunctorMul : public boost::static_visitor { + framework::Tensor *tmat_; + const framework::Tensor &weight_; + const framework::Tensor &input_; + + MatrixBitCodeFunctorMul(framework::Tensor *tmat, + const framework::Tensor &weight, + const framework::Tensor &input) + : tmat_(tmat), weight_(weight), input_(input) {} + + template + void operator()(const CodeTable &code_table) { + auto blas = + GetBlas(platform::CPUDeviceContext()); + size_t num_samples = tmat_->dims()[0]; + size_t tmat_width = tmat_->dims()[1]; + size_t input_width = input_.dims()[1]; + size_t weight_width = weight_.dims()[1]; + auto tmat_value = tmat_->data(); + auto weight_value = weight_.data(); + auto input_value = input_.data(); + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table.get_code(i); + int code_length = code.get_length(); + const T *input_row = input_value + input_width * i; + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + const T *weight_row = weight_value + weight_width * index; + T sum = blas.DOT(input_width, weight_row, input_row); + tmat_value[i * tmat_width + j] += sum; + } } } +}; + +template +void MatrixBitCodeFunctor::Mul(framework::Tensor *tmat, + const framework::Tensor &weight, + const framework::Tensor &input) { + MatrixBitCodeFunctorMul func(tmat, weight, input); + code_table_.apply_visitor(func); } +template +class ReservedVector : public std::vector { + public: + ReservedVector() { this->reserve(N); } +}; + template -void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, - framework::Tensor* weight, - const framework::Tensor& input) { - auto blas = - GetBlas(platform::CPUDeviceContext()); - size_t num_samples = tmat.dims()[0]; - size_t input_width = input.dims()[1]; - size_t tmat_width = tmat.dims()[1]; - size_t weight_width = weight->dims()[1]; - auto tmat_value = tmat.data(); - auto weight_value = weight->data(); - auto input_value = input.data(); - - std::unordered_map>> ops; - - for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - const T* input_value_row = input_value + input_width * i; - const T* tmat_row = tmat_value + i * tmat_width; - for (int j = 0; j < code_length; ++j) { - ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row); +struct MatrixBitCodeFunctorMulGradWeight : public boost::static_visitor { + const framework::Tensor &tmat_; + framework::Tensor *weight_; + const framework::Tensor &input_; + MatrixBitCodeFunctorMulGradWeight(const framework::Tensor &tmat, + framework::Tensor *weight, + const framework::Tensor &input) + : tmat_(tmat), weight_(weight), input_(input) {} + template + void operator()(const CodeTable &code_table) { + auto blas = + GetBlas(platform::CPUDeviceContext()); + size_t num_samples = tmat_.dims()[0]; + size_t input_width = input_.dims()[1]; + size_t tmat_width = tmat_.dims()[1]; + size_t weight_width = weight_->dims()[1]; + auto tmat_value = tmat_.data(); + auto weight_value = weight_->data(); + auto input_value = input_.data(); + + std::map, 8u>> ops; + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table.get_code(i); + int code_length = code.get_length(); + const T *input_value_row = input_value + input_width * i; + const T *tmat_row = tmat_value + i * tmat_width; + for (int j = 0; j < code_length; ++j) { + ops[code.calc_index(j)].emplace_back(tmat_row[j], input_value_row); + } } - } - for (auto& op : ops) { - auto& op_in_row = op.second; - for (auto& pair : op_in_row) { - auto& scale = pair.first; - auto* input_row = pair.second; - T* weight_row = weight_value + op.first * weight_width; - blas.AXPY(input_width, scale, input_row, weight_row); + for (auto &op : ops) { + auto &op_in_row = op.second; + for (auto &pair : op_in_row) { + auto &scale = pair.first; + auto *input_row = pair.second; + T *weight_row = weight_value + op.first * weight_width; + blas.AXPY(input_width, scale, input_row, weight_row); + } } } +}; + +template +void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor &tmat, + framework::Tensor *weight, + const framework::Tensor &input) { + MatrixBitCodeFunctorMulGradWeight func(tmat, weight, input); + code_table_.apply_visitor(func); } template -void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, - framework::SelectedRows* weight, - const framework::Tensor& input) { - auto blas = - GetBlas(platform::CPUDeviceContext()); - size_t num_samples = tmat.dims()[0]; - size_t input_width = input.dims()[1]; - size_t tmat_width = tmat.dims()[1]; - size_t weight_width = weight->value().dims()[1]; - auto tmat_value = tmat.data(); - auto weight_value = weight->mutable_value()->data(); - auto input_value = input.data(); - - std::unordered_map>> ops; - ops.reserve(weight->rows().size()); - - for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - const T* input_value_row = input_value + input_width * i; - const T* tmat_row = tmat_value + i * tmat_width; - for (int j = 0; j < code_length; ++j) { - ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row); +struct MatrixBitCodeFunctorMulGradWeightSR + : public boost::static_visitor { + const framework::Tensor &tmat_; + framework::SelectedRows *weight_; + const framework::Tensor &input_; + + MatrixBitCodeFunctorMulGradWeightSR(const framework::Tensor &tmat, + framework::SelectedRows *weight, + const framework::Tensor &input) + : tmat_(tmat), weight_(weight), input_(input) {} + + template + void operator()(const CodeTable &code_table) { + auto blas = + GetBlas(platform::CPUDeviceContext()); + size_t num_samples = tmat_.dims()[0]; + size_t input_width = input_.dims()[1]; + size_t tmat_width = tmat_.dims()[1]; + size_t weight_width = weight_->value().dims()[1]; + auto tmat_value = tmat_.data(); + auto weight_value = weight_->mutable_value()->data(); + auto input_value = input_.data(); + + std::unordered_map>> ops; + ops.reserve(weight_->rows().size()); + + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table.get_code(i); + int code_length = code.get_length(); + const T *input_value_row = input_value + input_width * i; + const T *tmat_row = tmat_value + i * tmat_width; + for (int j = 0; j < code_length; ++j) { + ops[code.calc_index(j)].emplace_back(tmat_row[j], input_value_row); + } } - } - for (auto& row : weight->rows()) { - auto& op_in_row = ops[row]; - for (auto& pair : op_in_row) { - auto& scale = pair.first; - auto* input_row = pair.second; - blas.AXPY(input_width, scale, input_row, weight_value); + for (auto &row : weight_->rows()) { + auto &op_in_row = ops[row]; + for (auto &pair : op_in_row) { + auto &scale = pair.first; + auto *input_row = pair.second; + blas.AXPY(input_width, scale, input_row, weight_value); + } + weight_value += weight_width; } - weight_value += weight_width; } +}; + +template +void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor &tmat, + framework::SelectedRows *weight, + const framework::Tensor &input) { + MatrixBitCodeFunctorMulGradWeightSR func(tmat, weight, input); + code_table_.apply_visitor(func); } template -void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, - const framework::Tensor& weight, - framework::Tensor* input) { - size_t num_samples = tmat.dims()[0]; - size_t tmat_width = tmat.dims()[1]; - size_t input_width = input->dims()[1]; - size_t weight_width = weight.dims()[1]; - auto tmat_value = tmat.data(); - auto weight_value = weight.data(); - auto input_value = input->data(); - - for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - for (int j = 0; j < code_length; ++j) { - size_t index = code->calc_index(j); - - for (size_t k = 0; k < input_width; ++k) { - input_value[input_width * i + k] += - tmat_value[i * tmat_width + j] * - weight_value[weight_width * index + k]; +struct MatrixBitCodeFunctorMulGradError : public boost::static_visitor { + const framework::Tensor &tmat_; + const framework::Tensor &weight_; + framework::Tensor *input_; + + MatrixBitCodeFunctorMulGradError(const framework::Tensor &tmat, + const framework::Tensor &weight, + framework::Tensor *input) + : tmat_(tmat), weight_(weight), input_(input) {} + template + void operator()(const CodeTable &code_table) { + size_t num_samples = tmat_.dims()[0]; + size_t tmat_width = tmat_.dims()[1]; + size_t input_width = input_->dims()[1]; + size_t weight_width = weight_.dims()[1]; + auto tmat_value = tmat_.data(); + auto weight_value = weight_.data(); + auto input_value = input_->data(); + + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table.get_code(i); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + + for (size_t k = 0; k < input_width; ++k) { + input_value[input_width * i + k] += + tmat_value[i * tmat_width + j] * + weight_value[weight_width * index + k]; + } } } } +}; + +template +void MatrixBitCodeFunctor::MulGradError(const framework::Tensor &tmat, + const framework::Tensor &weight, + framework::Tensor *input) { + MatrixBitCodeFunctorMulGradError func(tmat, weight, input); + code_table_.apply_visitor(func); } template -void MatrixBitCodeFunctor::Sub(framework::Tensor* tmat) { - size_t num_samples = tmat->dims()[0]; - size_t o_width = tmat->dims()[1]; - for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table_->get_code(i); - int code_length = code->get_length(); - for (int j = 0; j < code_length; ++j) { - if (code->calc_bit(j)) { - tmat->data()[i * o_width + j] -= 1; +struct MatrixBitCodeFunctorSub : public boost::static_visitor { + framework::Tensor *tmat_; + + explicit MatrixBitCodeFunctorSub(framework::Tensor *tmat) : tmat_(tmat) {} + + template + void operator()(const CodeTable &code_table) { + size_t num_samples = tmat_->dims()[0]; + size_t o_width = tmat_->dims()[1]; + auto *tmat_data = tmat_->data(); + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table.get_code(i); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + if (code.calc_bit(j)) { + tmat_data[i * o_width + j] -= 1; + } } } } +}; + +template +void MatrixBitCodeFunctor::Sub(framework::Tensor *tmat) { + MatrixBitCodeFunctorSub func(tmat); + code_table_.apply_visitor(func); } template class MatrixBitCodeFunctor; diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h index 35ca73802b48982ddf3ed7485b56f50221c9f28c..01e4889d34ad6e409f1b8a9c4bf783800187e863 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include #include #include @@ -22,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/variant.h" #if defined(_WIN32) #include @@ -98,24 +100,7 @@ inline int clz(const T& value) { inline size_t FindLastSet(size_t x) { return sizeof(size_t) * 8 - clz(x); } #endif // !_WIN32 -// set a code interface to create multiple code -class Code { - public: - virtual ~Code() {} - virtual size_t calc_index(int bit) const = 0; - virtual bool calc_bit(int bit) const = 0; - virtual int get_length() const = 0; -}; -// set a CodeTable interface to create multiple code table -class CodeTable { - public: - virtual std::unique_ptr get_code(int64_t code) const = 0; - virtual size_t size() const = 0; - virtual int get_max_code_length() const = 0; - virtual ~CodeTable() {} -}; - -class SimpleCode : public Code { +class SimpleCode { public: SimpleCode(size_t code, size_t num_classes, const int64_t* ids) : c_(static_cast(ids[code]) + num_classes) {} @@ -137,16 +122,16 @@ class SimpleCode : public Code { }; template -class CustomCode : public Code { +class CustomCode { public: CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode, - const int64_t* ids, int index) - : ids_(ids), index_(index) { - ptable_ = ptable.Slice(index, index + 1); - pcode_ = pcode.Slice(index, index + 1); + const int64_t* ids, int index) { + seq_len_ = ptable.dims()[1]; + ptable_data_ = ptable.data() + seq_len_ * index; + pcode_data_ = pcode.data() + seq_len_ * index; } /** - * Here the id of root shoud be 1 rather than 0, thus the encoding of class c + * Here the id of root should be 1 rather than 0, thus the encoding of class c * is `c + num_classes` and all siblings can get the same weight indice using * prefixes. * Weight index is the prefixes of encoding, thus leave out the right most @@ -154,36 +139,37 @@ class CustomCode : public Code { * Binary classification path is the suffixes of encoding, thus leave out the * left most bit in calc_bit. */ - size_t calc_index(int bit) const { return ptable_.data()[bit]; } - bool calc_bit(int bit) const { return pcode_.data()[bit]; } - int get_length() const { - int length = 0; + size_t calc_index(int bit) const { return ptable_data_[bit]; } + bool calc_bit(int bit) const { return pcode_data_[bit]; } - for (int i = 0; i < static_cast(ptable_.dims()[1]); i++) { - if (ptable_.data()[i] >= 0) { - length++; - } else { - return length; - } + // NOTE: this function is not thread-safe. + int get_length() const { + if (length_ < 0) { + auto len = seq_len_; + length_ = + static_cast(std::find_if(ptable_data_, ptable_data_ + len, + [](const T& val) { return val < 0; }) - + ptable_data_); } - return length; + return length_; } private: - framework::Tensor ptable_; - framework::Tensor pcode_; - const int64_t* ids_; - const int index_; + int64_t seq_len_; + const T* ptable_data_; + const T* pcode_data_; + mutable int length_{-1}; }; -class SimpleCodeTable : public CodeTable { +class SimpleCodeTable { public: SimpleCodeTable(size_t num_classes, const int64_t* ids) : num_classes_(num_classes), ids_(ids) {} - std::unique_ptr get_code(int64_t code) const { - std::unique_ptr coder(new SimpleCode(code, num_classes_, ids_)); - return coder; + + SimpleCode get_code(int64_t code) const { + return SimpleCode(code, num_classes_, ids_); } + size_t size() const { return num_classes_; } int get_max_code_length() const { return FindLastSet(num_classes_ - 1); } @@ -193,15 +179,14 @@ class SimpleCodeTable : public CodeTable { }; template -class CustomCodeTable : public CodeTable { +class CustomCodeTable { public: CustomCodeTable(const framework::Tensor& ptable, const framework::Tensor& pcode, const int64_t* ids) : ptable_(ptable), pcode_(pcode), ids_(ids) {} - std::unique_ptr get_code(int64_t code) const { - std::unique_ptr coder(new CustomCode(ptable_, pcode_, ids_, code)); - return coder; + CustomCode get_code(int64_t code) const { + return CustomCode(ptable_, pcode_, ids_, code); } size_t size() const { return static_cast(ptable_.dims()[1]); } @@ -215,19 +200,21 @@ class CustomCodeTable : public CodeTable { const int64_t* ids_; }; +using CodeTable = boost::variant>; + template class MatrixBitCodeFunctor { public: MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids) : num_classes_(num_classes), ids_(ids), - code_table_(new SimpleCodeTable(num_classes, ids)) {} + code_table_(SimpleCodeTable(num_classes, ids)) {} MatrixBitCodeFunctor(const framework::Tensor& ptable, const framework::Tensor& pcode, const int64_t* ids) : num_classes_(static_cast(ptable.dims()[1])), ids_(ids), - code_table_(new CustomCodeTable(ptable, pcode, ids)) {} + code_table_(CustomCodeTable(ptable, pcode, ids)) {} /* For j < code_length tmat(i, j) += vec(0, index(i, j)) */ @@ -277,7 +264,7 @@ class MatrixBitCodeFunctor { size_t num_classes_; const int64_t* ids_; - std::unique_ptr code_table_; + CodeTable code_table_; }; } // namespace math } // namespace operators diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index f0a973662360fd9ff35e1006cce937d86f3e563c..c3f9433503accf98d30ccaa57b9b4b8f3c68666a 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -82,6 +82,8 @@ extern void* mklml_dso_handle; __macro(vdSqr); \ __macro(vsPowx); \ __macro(vdPowx); \ + __macro(vsInv); \ + __macro(vdInv); \ __macro(MKL_Set_Num_Threads) MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);