From 42470f14b77e71a53c25cf318c69c4ca019bb593 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Fri, 23 Nov 2018 06:43:42 +0000 Subject: [PATCH] test=develop --- paddle/fluid/framework/selected_rows.cc | 52 ------------------- paddle/fluid/framework/selected_rows.h | 50 +++++++++++++++++- .../fluid/operators/math/matrix_bit_code.cc | 2 +- 3 files changed, 50 insertions(+), 54 deletions(-) diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index f4f2b769d5..7262f8cc05 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -140,58 +140,6 @@ bool SelectedRows::HasKey(int64_t key) const { : true; } -int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown, - bool is_test) { - if (is_test) { - auto iter = id_to_index_.find(key); - if (iter == id_to_index_.end()) { - return -1; - } else { - return iter->second; - } - } - - rwlock_->RDLock(); - auto iter = id_to_index_.find(key); - if (iter == id_to_index_.end()) { - rwlock_->UNLock(); - if (!auto_grown) { - PADDLE_THROW("key %d not found", key); - } - rwlock_->WRLock(); - auto map_size = id_to_index_.size(); - auto vector_size = rows_.size(); - if (map_size != vector_size) { - rwlock_->UNLock(); - PADDLE_THROW( - "id_to_index_ size %d should have the same size with rows_ %d", - map_size, vector_size); - } - auto write_iter = id_to_index_.find(key); - if (write_iter == id_to_index_.end()) { - int row_num = rows_.size(); - if (row_num == value_->dims()[0]) { - rwlock_->UNLock(); - PADDLE_THROW("selected rows is full, then length exceed %d", row_num); - } - // key logic to put a key into id_to_index_ - rows_.push_back(key); - auto index = static_cast(rows_.size() - 1); - id_to_index_[key] = index; - rwlock_->UNLock(); - return index; - } else { - auto index = write_iter->second; - rwlock_->UNLock(); - return index; - } - } else { - auto index = iter->second; - rwlock_->UNLock(); - return index; - } -} - void SelectedRows::SyncIndex() { rwlock_->WRLock(); id_to_index_.clear(); diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index d3e0f2168b..6c31dada68 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -118,7 +118,55 @@ class SelectedRows { * * @return index of the key. */ - int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false); + int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false) { + if (is_test) { + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + return -1; + } else { + return iter->second; + } + } + rwlock_->RDLock(); + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + rwlock_->UNLock(); + if (!auto_grown) { + PADDLE_THROW("key %d not found", key); + } + rwlock_->WRLock(); + auto map_size = id_to_index_.size(); + auto vector_size = rows_.size(); + if (map_size != vector_size) { + rwlock_->UNLock(); + PADDLE_THROW( + "id_to_index_ size %d should have the same size with rows_ %d", + map_size, vector_size); + } + auto write_iter = id_to_index_.find(key); + if (write_iter == id_to_index_.end()) { + int row_num = rows_.size(); + if (row_num == value_->dims()[0]) { + rwlock_->UNLock(); + PADDLE_THROW("selected rows is full, then length exceed %d", row_num); + } + // key logic to put a key into id_to_index_ + rows_.push_back(key); + auto index = static_cast(rows_.size() - 1); + id_to_index_[key] = index; + rwlock_->UNLock(); + return index; + } else { + auto index = write_iter->second; + rwlock_->UNLock(); + return index; + } + } else { + auto index = iter->second; + rwlock_->UNLock(); + return index; + } + } void SyncIndex(); /* diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 2967586949..9a0cf8701f 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -142,7 +142,7 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, for (size_t k = 0; k < input_width; ++k) { int64_t row_index = - weight->AutoGrownIndex(static_cast(index), false); + weight->AutoGrownIndex(static_cast(index), false, true); weight_value[row_index * weight_width + k] += tmat_value[i * tmat_width + j] * input_value[input_width * i + k]; -- GitLab