提交 42470f14 编写于 作者: J JiabinYang

test=develop

上级 0fca1684
......@@ -140,58 +140,6 @@ bool SelectedRows::HasKey(int64_t key) const {
: true;
}
int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown,
bool is_test) {
if (is_test) {
auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) {
return -1;
} else {
return iter->second;
}
}
rwlock_->RDLock();
auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) {
rwlock_->UNLock();
if (!auto_grown) {
PADDLE_THROW("key %d not found", key);
}
rwlock_->WRLock();
auto map_size = id_to_index_.size();
auto vector_size = rows_.size();
if (map_size != vector_size) {
rwlock_->UNLock();
PADDLE_THROW(
"id_to_index_ size %d should have the same size with rows_ %d",
map_size, vector_size);
}
auto write_iter = id_to_index_.find(key);
if (write_iter == id_to_index_.end()) {
int row_num = rows_.size();
if (row_num == value_->dims()[0]) {
rwlock_->UNLock();
PADDLE_THROW("selected rows is full, then length exceed %d", row_num);
}
// key logic to put a key into id_to_index_
rows_.push_back(key);
auto index = static_cast<int64_t>(rows_.size() - 1);
id_to_index_[key] = index;
rwlock_->UNLock();
return index;
} else {
auto index = write_iter->second;
rwlock_->UNLock();
return index;
}
} else {
auto index = iter->second;
rwlock_->UNLock();
return index;
}
}
void SelectedRows::SyncIndex() {
rwlock_->WRLock();
id_to_index_.clear();
......
......@@ -118,7 +118,55 @@ class SelectedRows {
*
* @return index of the key.
*/
int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);
int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false) {
if (is_test) {
auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) {
return -1;
} else {
return iter->second;
}
}
rwlock_->RDLock();
auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) {
rwlock_->UNLock();
if (!auto_grown) {
PADDLE_THROW("key %d not found", key);
}
rwlock_->WRLock();
auto map_size = id_to_index_.size();
auto vector_size = rows_.size();
if (map_size != vector_size) {
rwlock_->UNLock();
PADDLE_THROW(
"id_to_index_ size %d should have the same size with rows_ %d",
map_size, vector_size);
}
auto write_iter = id_to_index_.find(key);
if (write_iter == id_to_index_.end()) {
int row_num = rows_.size();
if (row_num == value_->dims()[0]) {
rwlock_->UNLock();
PADDLE_THROW("selected rows is full, then length exceed %d", row_num);
}
// key logic to put a key into id_to_index_
rows_.push_back(key);
auto index = static_cast<int64_t>(rows_.size() - 1);
id_to_index_[key] = index;
rwlock_->UNLock();
return index;
} else {
auto index = write_iter->second;
rwlock_->UNLock();
return index;
}
} else {
auto index = iter->second;
rwlock_->UNLock();
return index;
}
}
void SyncIndex();
/*
......
......@@ -142,7 +142,7 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::LoDTensor& tmat,
for (size_t k = 0; k < input_width; ++k) {
int64_t row_index =
weight->AutoGrownIndex(static_cast<int64_t>(index), false);
weight->AutoGrownIndex(static_cast<int64_t>(index), false, true);
weight_value[row_index * weight_width + k] +=
tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册