diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index f4f2b769d5e47d8fba8d08476df4cd8e54133551..7262f8cc052ab55ba382f840f16c84018f8ef70e 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -140,58 +140,6 @@ bool SelectedRows::HasKey(int64_t key) const { : true; } -int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown, - bool is_test) { - if (is_test) { - auto iter = id_to_index_.find(key); - if (iter == id_to_index_.end()) { - return -1; - } else { - return iter->second; - } - } - - rwlock_->RDLock(); - auto iter = id_to_index_.find(key); - if (iter == id_to_index_.end()) { - rwlock_->UNLock(); - if (!auto_grown) { - PADDLE_THROW("key %d not found", key); - } - rwlock_->WRLock(); - auto map_size = id_to_index_.size(); - auto vector_size = rows_.size(); - if (map_size != vector_size) { - rwlock_->UNLock(); - PADDLE_THROW( - "id_to_index_ size %d should have the same size with rows_ %d", - map_size, vector_size); - } - auto write_iter = id_to_index_.find(key); - if (write_iter == id_to_index_.end()) { - int row_num = rows_.size(); - if (row_num == value_->dims()[0]) { - rwlock_->UNLock(); - PADDLE_THROW("selected rows is full, then length exceed %d", row_num); - } - // key logic to put a key into id_to_index_ - rows_.push_back(key); - auto index = static_cast(rows_.size() - 1); - id_to_index_[key] = index; - rwlock_->UNLock(); - return index; - } else { - auto index = write_iter->second; - rwlock_->UNLock(); - return index; - } - } else { - auto index = iter->second; - rwlock_->UNLock(); - return index; - } -} - void SelectedRows::SyncIndex() { rwlock_->WRLock(); id_to_index_.clear(); diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index d3e0f2168b7e946739c69e8c433aefb6410d7f2b..6c31dada686b9493f85063039781dc46324d4fe4 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -118,7 +118,55 @@ class SelectedRows { * * @return index of the key. */ - int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false); + int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false) { + if (is_test) { + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + return -1; + } else { + return iter->second; + } + } + rwlock_->RDLock(); + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + rwlock_->UNLock(); + if (!auto_grown) { + PADDLE_THROW("key %d not found", key); + } + rwlock_->WRLock(); + auto map_size = id_to_index_.size(); + auto vector_size = rows_.size(); + if (map_size != vector_size) { + rwlock_->UNLock(); + PADDLE_THROW( + "id_to_index_ size %d should have the same size with rows_ %d", + map_size, vector_size); + } + auto write_iter = id_to_index_.find(key); + if (write_iter == id_to_index_.end()) { + int row_num = rows_.size(); + if (row_num == value_->dims()[0]) { + rwlock_->UNLock(); + PADDLE_THROW("selected rows is full, then length exceed %d", row_num); + } + // key logic to put a key into id_to_index_ + rows_.push_back(key); + auto index = static_cast(rows_.size() - 1); + id_to_index_[key] = index; + rwlock_->UNLock(); + return index; + } else { + auto index = write_iter->second; + rwlock_->UNLock(); + return index; + } + } else { + auto index = iter->second; + rwlock_->UNLock(); + return index; + } + } void SyncIndex(); /* diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 2967586949846149e07d5f48bf156bda151fda8c..9a0cf8701fbb624f7f8d978d075a6f22f78826c0 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -142,7 +142,7 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::LoDTensor& tmat, for (size_t k = 0; k < input_width; ++k) { int64_t row_index = - weight->AutoGrownIndex(static_cast(index), false); + weight->AutoGrownIndex(static_cast(index), false, true); weight_value[row_index * weight_width + k] += tmat_value[i * tmat_width + j] * input_value[input_width * i + k];