diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc
index 5a6e64b6f87d33249f0153e5f391deaf78e53de5..dbf4f5e3325c324b0676393e346cae486859b993 100644
--- a/paddle/fluid/operators/math/matrix_bit_code.cc
+++ b/paddle/fluid/operators/math/matrix_bit_code.cc
@@ -23,12 +23,14 @@ void MatrixBitCodeFunctor<T>::Add(const framework::Tensor& vec,
                                   framework::Tensor* tmat) {
   size_t batch_size = tmat->dims()[0];
   size_t width = tmat->dims()[1];
+  auto* tmat_data = tmat->data<T>();
+  auto* vec_data = vec.data<T>();
   for (size_t i = 0; i < batch_size; ++i) {
     auto code = code_table_->get_code(i);
     int code_length = code->get_length();
     for (int j = 0; j < code_length; ++j) {
       size_t index = code->calc_index(j);
-      tmat->data<T>()[i * width + j] += vec.data<T>()[index];
+      tmat_data[i * width + j] += vec_data[index];
     }
   }
 }
@@ -38,12 +40,14 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
                                       framework::Tensor* vec) {
   size_t batch_size = tmat.dims()[0];
   size_t width = tmat.dims()[1];
+  auto* vec_data = vec->data<T>();
+  auto* tmat_data = tmat.data<T>();
   for (size_t i = 0; i < batch_size; ++i) {
     auto code = code_table_->get_code(i);
     int code_length = code->get_length();
     for (int j = 0; j < code_length; ++j) {
       size_t index = code->calc_index(j);
-      vec->data<T>()[index] += tmat.data<T>()[i * width + j];
+      vec_data[index] += tmat_data[i * width + j];
     }
   }
 }
@@ -53,14 +57,15 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
                                       framework::SelectedRows* vec) {
   size_t batch_size = tmat.dims()[0];
   size_t width = tmat.dims()[1];
+  auto* vec_data = vec->mutable_value()->data<T>();
+  auto* tmat_data = tmat.data<T>();
   for (size_t i = 0; i < batch_size; ++i) {
     auto code = code_table_->get_code(i);
     int code_length = code->get_length();
     for (int j = 0; j < code_length; ++j) {
       size_t index = code->calc_index(j);
       int64_t row_index = vec->GetIndexFromId(static_cast<int64_t>(index));
-      vec->mutable_value()->data<T>()[row_index] +=
-          tmat.data<T>()[i * width + j];
+      vec_data[row_index] += tmat_data[i * width + j];
     }
   }
 }
@@ -70,6 +75,8 @@ void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
                                   framework::Tensor* sum, T scale_sum) {
   size_t num_samples = tmat.dims()[0];
   size_t o_width = tmat.dims()[1];
+  auto* tmat_data = tmat.data<T>();
+  auto* sum_data = sum->data<T>();
   for (size_t i = 0; i < num_samples; ++i) {
     T sm = static_cast<T>(0.0);
     auto code = code_table_->get_code(i);
@@ -78,10 +85,10 @@ void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
       if (code->calc_bit(j)) {
         // calc_bit starts from right most bit, while data in tmat[i] is in the
         // reverse order.
-        sm += tmat.data<T>()[i * o_width + j];
+        sm += tmat_data[i * o_width + j];
       }
     }
-    sum->data<T>()[i] = scale_sum * sm;
+    sum_data[i] = scale_sum * sm;
   }
 }
 
@@ -217,12 +224,13 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
   size_t num_samples = tmat->dims()[0];
   size_t o_width = tmat->dims()[1];
+  auto* tmat_data = tmat->data<T>();
   for (size_t i = 0; i < num_samples; ++i) {
     auto code = code_table_->get_code(i);
     int code_length = code->get_length();
     for (int j = 0; j < code_length; ++j) {
       if (code->calc_bit(j)) {
-        tmat->data<T>()[i * o_width + j] -= 1;
+        tmat_data[i * o_width + j] -= 1;
       }
     }
   }
diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h
index 35ca73802b48982ddf3ed7485b56f50221c9f28c..ba1745b86dabf2505a0874bb318d3d94bdb5af18 100644
--- a/paddle/fluid/operators/math/matrix_bit_code.h
+++ b/paddle/fluid/operators/math/matrix_bit_code.h
@@ -140,13 +140,13 @@ template <typename T>
 class CustomCode : public Code {
  public:
   CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode,
-             const int64_t* ids, int index)
-      : ids_(ids), index_(index) {
-    ptable_ = ptable.Slice(index, index + 1);
-    pcode_ = pcode.Slice(index, index + 1);
+             const int64_t* ids, int index) {
+    seq_len_ = ptable.dims()[1];
+    ptable_data_ = ptable.data<T>() + seq_len_ * index;
+    pcode_data_ = pcode.data<T>() + seq_len_ * index;
   }
   /**
-   * Here the id of root shoud be 1 rather than 0, thus the encoding of class c
+   * Here the id of root should be 1 rather than 0, thus the encoding of class c
    * is `c + num_classes` and all siblings can get the same weight indice using
    * prefixes.
    * Weight index is the prefixes of encoding, thus leave out the right most
@@ -154,26 +154,26 @@ class CustomCode : public Code {
    * Binary classification path is the suffixes of encoding, thus leave out the
    * left most bit in calc_bit.
    */
-  size_t calc_index(int bit) const { return ptable_.data<T>()[bit]; }
-  bool calc_bit(int bit) const { return pcode_.data<T>()[bit]; }
-  int get_length() const {
-    int length = 0;
+  size_t calc_index(int bit) const override { return ptable_data_[bit]; }
+  bool calc_bit(int bit) const override { return pcode_data_[bit]; }
 
-    for (int i = 0; i < static_cast<int>(ptable_.dims()[1]); i++) {
-      if (ptable_.data<T>()[i] >= 0) {
-        length++;
-      } else {
-        return length;
-      }
+  // NOTE: this function is not thread-safe.
+  int get_length() const override {
+    if (length_ < 0) {
+      auto len = seq_len_;
+      length_ =
+          static_cast<int>(std::find_if(ptable_data_, ptable_data_ + len,
+                                        [](const T& val) { return val < 0; }) -
+                           ptable_data_);
     }
-    return length;
+    return length_;
   }
 
  private:
-  framework::Tensor ptable_;
-  framework::Tensor pcode_;
-  const int64_t* ids_;
-  const int index_;
+  int64_t seq_len_;
+  const T* ptable_data_;
+  const T* pcode_data_;
+  mutable int length_{-1};
 };
 
 class SimpleCodeTable : public CodeTable {