diff --git a/mindspore/ccsrc/dataset/util/auto_index.h b/mindspore/ccsrc/dataset/util/auto_index.h index e34d9e1288f7d001386c2f59c14c051cbcbc0415..2b4c2d68833c45d00a6becebea9c69e17269ca5d 100644 --- a/mindspore/ccsrc/dataset/util/auto_index.h +++ b/mindspore/ccsrc/dataset/util/auto_index.h @@ -18,6 +18,7 @@ #include #include +#include #include #include "dataset/util/btree.h" @@ -25,19 +26,20 @@ namespace mindspore { namespace dataset { -// This is a B+ tree with generated uint64_t value as key. -// Use minKey() function to query the min key. -// Use maxKey() function to query the max key. -// @tparam T -template -class AutoIndexObj : public BPlusTree { +/// This is a B+ tree with generated int64_t value as key. +/// Use minKey() function to query the min key. +/// Use maxKey() function to query the max key. +/// @tparam T +template > +class AutoIndexObj : public BPlusTree { public: - using my_tree = BPlusTree; + using my_tree = BPlusTree; using key_type = typename my_tree::key_type; using value_type = typename my_tree::value_type; - explicit AutoIndexObj(const typename my_tree::value_allocator &alloc = Allocator{std::make_shared()}) - : my_tree::BPlusTree(alloc), inx_(kMinKey) {} + AutoIndexObj() : my_tree::BPlusTree(), inx_(kMinKey) {} + + explicit AutoIndexObj(const Allocator &alloc) : my_tree::BPlusTree(alloc), inx_(kMinKey) {} ~AutoIndexObj() = default; @@ -52,6 +54,14 @@ class AutoIndexObj : public BPlusTree { return my_tree::DoInsert(my_inx, val); } + Status insert(std::unique_ptr &&val, key_type *key = nullptr) { + key_type my_inx = inx_.fetch_add(1); + if (key) { + *key = my_inx; + } + return my_tree::DoInsert(my_inx, std::move(val)); + } + // Insert a vector of objects into the tree. // @param v // @return diff --git a/mindspore/ccsrc/dataset/util/btree.h b/mindspore/ccsrc/dataset/util/btree.h index 42c0499e5fd33058f9883d418218bc87bcf91611..72e3d16351c84d702bad811e91972436163b83f7 100644 --- a/mindspore/ccsrc/dataset/util/btree.h +++ b/mindspore/ccsrc/dataset/util/btree.h @@ -44,12 +44,14 @@ struct BPlusTreeTraits { static constexpr bool kAppendMode = false; }; -// Implementation of B+ tree -// @tparam K -// @tparam V -// @tparam C -// @tparam T -template , typename T = BPlusTreeTraits> +/// Implementation of B+ tree +/// @tparam K -- the type of key +/// @tparam V -- the type of value +/// @tparam A -- allocator +/// @tparam C -- comparison class +/// @tparam T -- trait +template , typename C = std::less, + typename T = BPlusTreeTraits> class BPlusTree { public: enum class IndexRc : char { @@ -87,11 +89,13 @@ class BPlusTree { using key_compare = C; using slot_type = typename T::slot_type; using traits = T; - using key_allocator = Allocator; - using value_allocator = Allocator; - using slot_allocator = Allocator; + using value_allocator = A; + using key_allocator = typename value_allocator::template rebind::other; + using slot_allocator = typename value_allocator::template rebind::other; - explicit BPlusTree(const value_allocator &alloc); + BPlusTree(); + + explicit BPlusTree(const Allocator &alloc); ~BPlusTree() noexcept; @@ -109,10 +113,15 @@ class BPlusTree { bool empty() const { return (size() == 0); } - // @param key - // @param value - // @return + /// @param key + /// @param value + /// @return Status DoInsert(const key_type &key, const value_type &value); + Status DoInsert(const key_type &key, std::unique_ptr &&value); + + // Update a new value for a given key. + std::unique_ptr DoUpdate(const key_type &key, const value_type &new_value); + std::unique_ptr DoUpdate(const key_type &key, std::unique_ptr &&new_value); void PopulateNumKeys(); @@ -144,7 +153,7 @@ class BPlusTree { virtual ~BaseNode() = default; protected: - RWLock rw_lock_; + mutable RWLock rw_lock_; value_allocator alloc_; private: @@ -267,7 +276,7 @@ class BPlusTree { // 50/50 split IndexRc Split(LeafNode *to); - IndexRc InsertIntoSlot(LockPathCB *insCB, slot_type slot, const key_type &key, std::shared_ptr value); + IndexRc InsertIntoSlot(LockPathCB *insCB, slot_type slot, const key_type &key, std::unique_ptr &&value); explicit LeafNode(const value_allocator &alloc) : BaseNode::BaseNode(alloc), slotuse_(0) {} @@ -275,11 +284,11 @@ class BPlusTree { slot_type slot_dir_[traits::kLeafSlots]; key_type keys_[traits::kLeafSlots]; - std::shared_ptr data_[traits::kLeafSlots]; + std::unique_ptr data_[traits::kLeafSlots]; slot_type slotuse_; }; - RWLock rw_lock_; + mutable RWLock rw_lock_; value_allocator alloc_; // All the leaf nodes. Used by the iterator to traverse all the key/values. List leaf_nodes_; @@ -319,8 +328,8 @@ class BPlusTree { return lo; } - IndexRc LeafInsertKeyValue(LockPathCB *ins_cb, LeafNode *node, const key_type &key, std::shared_ptr value, - key_type *split_key, LeafNode **split_node); + IndexRc LeafInsertKeyValue(LockPathCB *ins_cb, LeafNode *node, const key_type &key, + std::unique_ptr &&value, key_type *split_key, LeafNode **split_node); IndexRc InnerInsertKeyChild(InnerNode *node, const key_type &key, BaseNode *ptr, key_type *split_key, InnerNode **split_node); @@ -335,10 +344,11 @@ class BPlusTree { return child; } - IndexRc InsertKeyValue(LockPathCB *ins_cb, BaseNode *n, const key_type &key, std::shared_ptr value, + IndexRc InsertKeyValue(LockPathCB *ins_cb, BaseNode *n, const key_type &key, std::unique_ptr &&value, key_type *split_key, BaseNode **split_node); - IndexRc Locate(BaseNode *top, const key_type &key, LeafNode **ln, slot_type *s) const; + IndexRc Locate(RWLock *parent_lock, bool forUpdate, BaseNode *top, const key_type &key, LeafNode **ln, + slot_type *s) const; public: class Iterator : public std::iterator { @@ -346,19 +356,27 @@ class BPlusTree { using reference = BPlusTree::value_type &; using pointer = BPlusTree::value_type *; - explicit Iterator(BPlusTree *btree) : cur_(btree->leaf_nodes_.head), slot_(0) {} + explicit Iterator(BPlusTree *btree) : cur_(btree->leaf_nodes_.head), slot_(0), locked_(false) {} + + Iterator(LeafNode *leaf, slot_type slot, bool locked = false) : cur_(leaf), slot_(slot), locked_(locked) {} + + ~Iterator(); + + explicit Iterator(const Iterator &); + + Iterator &operator=(const Iterator &lhs); - Iterator(LeafNode *leaf, slot_type slot) : cur_(leaf), slot_(slot) {} + Iterator(Iterator &&); - ~Iterator() = default; + Iterator &operator=(Iterator &&lhs); pointer operator->() const { return cur_->data_[cur_->slot_dir_[slot_]].get(); } reference operator*() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } - const key_type &key() { return cur_->keys_[cur_->slot_dir_[slot_]]; } + const key_type &key() const { return cur_->keys_[cur_->slot_dir_[slot_]]; } - const value_type &value() { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } + value_type &value() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } // Prefix++ Iterator &operator++(); @@ -379,6 +397,7 @@ class BPlusTree { private: typename BPlusTree::LeafNode *cur_; slot_type slot_; + bool locked_; }; class ConstIterator : public std::iterator { @@ -386,11 +405,20 @@ class BPlusTree { using reference = BPlusTree::value_type &; using pointer = BPlusTree::value_type *; - explicit ConstIterator(const BPlusTree *btree) : cur_(btree->leaf_nodes_.head), slot_(0) {} + explicit ConstIterator(const BPlusTree *btree) : cur_(btree->leaf_nodes_.head), slot_(0), locked_(false) {} + + ~ConstIterator(); + + ConstIterator(const LeafNode *leaf, slot_type slot, bool locked = false) + : cur_(leaf), slot_(slot), locked_(locked) {} + + explicit ConstIterator(const ConstIterator &); + + ConstIterator &operator=(const ConstIterator &lhs); - ~ConstIterator() = default; + ConstIterator(ConstIterator &&); - ConstIterator(const LeafNode *leaf, slot_type slot) : cur_(leaf), slot_(slot) {} + ConstIterator &operator=(ConstIterator &&lhs); pointer operator->() const { return cur_->data_[cur_->slot_dir_[slot_]].get(); } @@ -398,7 +426,7 @@ class BPlusTree { const key_type &key() const { return cur_->keys_[cur_->slot_dir_[slot_]]; } - const value_type &value() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } + value_type &value() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } // Prefix++ ConstIterator &operator++(); @@ -419,6 +447,7 @@ class BPlusTree { private: const typename BPlusTree::LeafNode *cur_; slot_type slot_; + bool locked_; }; Iterator begin(); @@ -435,6 +464,7 @@ class BPlusTree { // Locate the entry with key ConstIterator Search(const key_type &key) const; + Iterator Search(const key_type &key); value_type operator[](key_type key); }; diff --git a/mindspore/ccsrc/dataset/util/btree_impl.tpp b/mindspore/ccsrc/dataset/util/btree_impl.tpp index dab94002d5e167fd0b14519c9e708f925c88a2e6..f38070f367e7006b73db9b8d773daf837fb4d2ff 100644 --- a/mindspore/ccsrc/dataset/util/btree_impl.tpp +++ b/mindspore/ccsrc/dataset/util/btree_impl.tpp @@ -19,10 +19,10 @@ namespace mindspore { namespace dataset { -template -typename BPlusTree::IndexRc BPlusTree::InnerNode::Sort() { +template +typename BPlusTree::IndexRc BPlusTree::InnerNode::Sort() { // Build an inverse map. Basically it means keys[i] should be relocated to keys[inverse[i]]; - Allocator alloc(this->alloc_); + slot_allocator alloc(this->alloc_); slot_type *inverse = nullptr; try { inverse = alloc.allocate(traits::kInnerSlots); @@ -51,15 +51,15 @@ typename BPlusTree::IndexRc BPlusTree::InnerNode::Sort() slot_dir_[i] = i; } if (inverse != nullptr) { - alloc.deallocate(inverse); + alloc.deallocate(inverse, traits::kInnerSlots); inverse = nullptr; } return IndexRc::kOk; } -template -typename BPlusTree::IndexRc BPlusTree::InnerNode::Split(BPlusTree::InnerNode *to, - key_type *split_key) { +template +typename BPlusTree::IndexRc BPlusTree::InnerNode::Split( + BPlusTree::InnerNode *to, key_type *split_key) { DS_ASSERT(to); DS_ASSERT(to->slotuse_ == 0); // It is simpler to sort first, then split. Other alternative is to move key by key to the @@ -72,7 +72,7 @@ typename BPlusTree::IndexRc BPlusTree::InnerNode::Split( if (err != EOK) { return IndexRc::kUnexpectedError; } - err = memcpy_s(to->data_, sizeof(to->data_), data_ + mid + 1, (num_keys_to_move + 1) * sizeof(BaseNode * )); + err = memcpy_s(to->data_, sizeof(to->data_), data_ + mid + 1, (num_keys_to_move + 1) * sizeof(BaseNode *)); if (err != EOK) { return IndexRc::kUnexpectedError; } @@ -84,10 +84,9 @@ typename BPlusTree::IndexRc BPlusTree::InnerNode::Split( return IndexRc::kOk; } -template -typename BPlusTree::IndexRc -BPlusTree::InnerNode::InsertIntoSlot(slot_type slot, const key_type &key, - BPlusTree::BaseNode *ptr) { +template +typename BPlusTree::IndexRc BPlusTree::InnerNode::InsertIntoSlot( + slot_type slot, const key_type &key, BPlusTree::BaseNode *ptr) { if (is_full()) { return IndexRc::kSlotFull; } @@ -111,10 +110,10 @@ BPlusTree::InnerNode::InsertIntoSlot(slot_type slot, const key_type return IndexRc::kOk; } -template -typename BPlusTree::IndexRc BPlusTree::LeafNode::Sort() { +template +typename BPlusTree::IndexRc BPlusTree::LeafNode::Sort() { // Build an inverse map. Basically it means keys[i] should be relocated to keys[inverse[i]]; - Allocator alloc(this->alloc_); + slot_allocator alloc(this->alloc_); slot_type *inverse = nullptr; try { inverse = alloc.allocate(traits::kLeafSlots); @@ -143,14 +142,15 @@ typename BPlusTree::IndexRc BPlusTree::LeafNode::Sort() slot_dir_[i] = i; } if (inverse != nullptr) { - alloc.deallocate(inverse); + alloc.deallocate(inverse, traits::kLeafSlots); inverse = nullptr; } return IndexRc::kOk; } -template -typename BPlusTree::IndexRc BPlusTree::LeafNode::Split(BPlusTree::LeafNode *to) { +template +typename BPlusTree::IndexRc BPlusTree::LeafNode::Split( + BPlusTree::LeafNode *to) { DS_ASSERT(to); DS_ASSERT(to->slotuse_ == 0); // It is simpler to sort first, then split. Other alternative is to move key by key to the @@ -171,11 +171,10 @@ typename BPlusTree::IndexRc BPlusTree::LeafNode::Split(B return IndexRc::kOk; } -template -typename BPlusTree::IndexRc -BPlusTree::LeafNode::InsertIntoSlot(BPlusTree::LockPathCB *insCB, slot_type slot, - const key_type &key, - std::shared_ptr value) { +template +typename BPlusTree::IndexRc BPlusTree::LeafNode::InsertIntoSlot( + BPlusTree::LockPathCB *insCB, slot_type slot, const key_type &key, + std::unique_ptr &&value) { if (is_full()) { // If we need to do node split, we need to ensure all the intermediate nodes are locked exclusive. // Otherwise we need to do a retry. @@ -210,8 +209,9 @@ BPlusTree::LeafNode::InsertIntoSlot(BPlusTree::LockPathC return IndexRc::kOk; } -template -typename BPlusTree::IndexRc BPlusTree::AllocateInner(BPlusTree::InnerNode **p) { +template +typename BPlusTree::IndexRc BPlusTree::AllocateInner( + BPlusTree::InnerNode **p) { if (p == nullptr) { return IndexRc::kNullPointer; } @@ -224,14 +224,15 @@ typename BPlusTree::IndexRc BPlusTree::AllocateInner(BPl } catch (std::exception &e) { return IndexRc::kUnexpectedError; } - *p = new(ptr) InnerNode(alloc_); + *p = new (ptr) InnerNode(alloc_); all_.Prepend(ptr); stats_.inner_nodes_++; return IndexRc::kOk; } -template -typename BPlusTree::IndexRc BPlusTree::AllocateLeaf(BPlusTree::LeafNode **p) { +template +typename BPlusTree::IndexRc BPlusTree::AllocateLeaf( + BPlusTree::LeafNode **p) { if (p == nullptr) { return IndexRc::kNullPointer; } @@ -244,24 +245,22 @@ typename BPlusTree::IndexRc BPlusTree::AllocateLeaf(BPlu } catch (std::exception &e) { return IndexRc::kUnexpectedError; } - *p = new(ptr) LeafNode(alloc_); + *p = new (ptr) LeafNode(alloc_); all_.Prepend(ptr); stats_.leaves_++; return IndexRc::kOk; } -template -typename BPlusTree::IndexRc -BPlusTree::LeafInsertKeyValue(BPlusTree::LockPathCB *ins_cb, - BPlusTree::LeafNode *node, const key_type &key, - std::shared_ptr value, key_type *split_key, - BPlusTree::LeafNode **split_node) { +template +typename BPlusTree::IndexRc BPlusTree::LeafInsertKeyValue( + BPlusTree::LockPathCB *ins_cb, BPlusTree::LeafNode *node, const key_type &key, + std::unique_ptr &&value, key_type *split_key, BPlusTree::LeafNode **split_node) { bool duplicate; slot_type slot = FindSlot(node, key, &duplicate); if (duplicate) { return IndexRc::kDuplicateKey; } - IndexRc rc = node->InsertIntoSlot(ins_cb, slot, key, value); + IndexRc rc = node->InsertIntoSlot(ins_cb, slot, key, std::move(value)); if (rc == IndexRc::kSlotFull) { LeafNode *new_leaf = nullptr; rc = AllocateLeaf(&new_leaf); @@ -273,7 +272,7 @@ BPlusTree::LeafInsertKeyValue(BPlusTree::LockPathCB *ins *split_key = key; // Just insert the new key to the new leaf. No further need to move the keys // from one leaf to the other. - rc = new_leaf->InsertIntoSlot(nullptr, 0, key, value); + rc = new_leaf->InsertIntoSlot(nullptr, 0, key, std::move(value)); RETURN_IF_BAD_RC(rc); } else { // 50/50 split @@ -281,11 +280,11 @@ BPlusTree::LeafInsertKeyValue(BPlusTree::LockPathCB *ins RETURN_IF_BAD_RC(rc); *split_key = new_leaf->keys_[0]; if (LessThan(key, *split_key)) { - rc = node->InsertIntoSlot(nullptr, slot, key, value); + rc = node->InsertIntoSlot(nullptr, slot, key, std::move(value)); RETURN_IF_BAD_RC(rc); } else { slot -= node->slotuse_; - rc = new_leaf->InsertIntoSlot(nullptr, slot, key, value); + rc = new_leaf->InsertIntoSlot(nullptr, slot, key, std::move(value)); RETURN_IF_BAD_RC(rc); } } @@ -293,11 +292,10 @@ BPlusTree::LeafInsertKeyValue(BPlusTree::LockPathCB *ins return rc; } -template -typename BPlusTree::IndexRc -BPlusTree::InnerInsertKeyChild(BPlusTree::InnerNode *node, const key_type &key, - BPlusTree::BaseNode *ptr, - key_type *split_key, BPlusTree::InnerNode **split_node) { +template +typename BPlusTree::IndexRc BPlusTree::InnerInsertKeyChild( + BPlusTree::InnerNode *node, const key_type &key, BPlusTree::BaseNode *ptr, + key_type *split_key, BPlusTree::InnerNode **split_node) { bool duplicate; slot_type slot = FindSlot(node, key, &duplicate); if (duplicate) { @@ -333,12 +331,10 @@ BPlusTree::InnerInsertKeyChild(BPlusTree::InnerNode *nod return rc; } -template -typename BPlusTree::IndexRc -BPlusTree::InsertKeyValue(BPlusTree::LockPathCB *ins_cb, BPlusTree::BaseNode *n, - const key_type &key, - std::shared_ptr value, key_type *split_key, - BPlusTree::BaseNode **split_node) { +template +typename BPlusTree::IndexRc BPlusTree::InsertKeyValue( + BPlusTree::LockPathCB *ins_cb, BPlusTree::BaseNode *n, const key_type &key, + std::unique_ptr &&value, key_type *split_key, BPlusTree::BaseNode **split_node) { if (split_key == nullptr || split_node == nullptr) { return IndexRc::kUnexpectedError; } @@ -378,17 +374,36 @@ BPlusTree::InsertKeyValue(BPlusTree::LockPathCB *ins_cb, return IndexRc::kOk; } -template -typename BPlusTree::IndexRc -BPlusTree::Locate(BPlusTree::BaseNode *top, const key_type &key, - BPlusTree::LeafNode **ln, - slot_type *s) const { +template +typename BPlusTree::IndexRc BPlusTree::Locate(RWLock *parent_lock, + bool forUpdate, + BPlusTree::BaseNode *top, + const key_type &key, + BPlusTree::LeafNode **ln, + slot_type *s) const { if (ln == nullptr || s == nullptr) { return IndexRc::kNullPointer; } if (top == nullptr) { return IndexRc::kKeyNotFound; } + RWLock *myLock = nullptr; + if (parent_lock != nullptr) { + // Crabbing. Lock this node first, then unlock the parent. + myLock = &top->rw_lock_; + if (top->is_leafnode()) { + if (forUpdate) { + // We are holding the parent lock in S and try to lock this node with X. It is not possible to run + // into deadlock because no one will hold the child in X and trying to lock the parent in that order. + myLock->LockExclusive(); + } else { + myLock->LockShared(); + } + } else { + myLock->LockShared(); + } + parent_lock->Unlock(); + } if (top->is_leafnode()) { bool duplicate; auto *leaf = static_cast(top); @@ -398,22 +413,29 @@ BPlusTree::Locate(BPlusTree::BaseNode *top, const key_ty *ln = leaf; *s = slot; } else { + if (myLock != nullptr) { + myLock->Unlock(); + } return IndexRc::kKeyNotFound; } } else { auto *inner = static_cast(top); slot_type slot = FindSlot(inner, key); - return Locate(FindBranch(inner, slot), key, ln, s); + return Locate(myLock, forUpdate, FindBranch(inner, slot), key, ln, s); } + // We still have a S lock on the leaf node. Leave it there. The iterator will unlock it for us. return IndexRc::kOk; } -template -BPlusTree::BPlusTree(const value_allocator &alloc) +template +BPlusTree::BPlusTree() : leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr) {} + +template +BPlusTree::BPlusTree(const Allocator &alloc) : alloc_(alloc), leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr) {} -template -BPlusTree::~BPlusTree() noexcept { +template +BPlusTree::~BPlusTree() noexcept { // We have a list of all the nodes allocated. Traverse them and free all the memory BaseNode *n = all_.head; BaseNode *t = nullptr; @@ -436,8 +458,8 @@ BPlusTree::~BPlusTree() noexcept { root_ = nullptr; } -template -Status BPlusTree::DoInsert(const key_type &key, const value_type &value) { +template +Status BPlusTree::DoInsert(const key_type &key, std::unique_ptr &&value) { IndexRc rc; if (root_ == nullptr) { UniqueLock lck(&rw_lock_); @@ -464,10 +486,7 @@ Status BPlusTree::DoInsert(const key_type &key, const value_type &va retry = false; BaseNode *new_child = nullptr; key_type new_key = key_type(); - // We don't store the value directly into the leaf node as it is expensive to move it during node split. - // Rather we store a pointer instead. The value_type must support the copy constructor. - std::shared_ptr ptr_value = std::make_shared(value); - rc = InsertKeyValue(&InsCB, root_, key, std::move(ptr_value), &new_key, &new_child); + rc = InsertKeyValue(&InsCB, root_, key, std::move(value), &new_key, &new_child); if (rc == IndexRc::kRetry) { retry = true; } else if (rc != IndexRc::kOk) { @@ -489,12 +508,50 @@ Status BPlusTree::DoInsert(const key_type &key, const value_type &va } } } while (retry); - (void) stats_.size_++; + (void)stats_.size_++; return Status::OK(); } -template -void BPlusTree::PopulateNumKeys() { +template +Status BPlusTree::DoInsert(const key_type &key, const value_type &value) { + // We don't store the value directly into the leaf node as it is expensive to move it during node split. + // Rather we store a pointer instead. + return DoInsert(key, std::make_unique(value)); +} + +template +std::unique_ptr BPlusTree::DoUpdate(const key_type &key, const value_type &new_value) { + return DoUpdate(key, std::make_unique(new_value)); +} + +template +std::unique_ptr BPlusTree::DoUpdate(const key_type &key, std::unique_ptr &&new_value) { + if (root_ != nullptr) { + LeafNode *leaf = nullptr; + slot_type slot; + RWLock *myLock = &this->rw_lock_; + // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. + myLock->LockShared(); + IndexRc rc = Locate(myLock, true, root_, key, &leaf, &slot); + if (rc == IndexRc::kOk) { + // All locks from the tree to the parent of leaf are all gone. We still have a X lock + // on the leaf. + // Swap out the old value and replace it with new value. + std::unique_ptr old = std::move(leaf->data_[leaf->slot_dir_[slot]]); + leaf->data_[leaf->slot_dir_[slot]] = std::move(new_value); + leaf->rw_lock_.Unlock(); + return old; + } else { + MS_LOG(INFO) << "Key not found. rc = " << static_cast(rc) << "."; + return nullptr; + } + } else { + return nullptr; + } +} + +template +void BPlusTree::PopulateNumKeys() { // Start from the root and we calculate how many leaf nodes as pointed to by each inner node. // The results are stored in the numKeys array in each inner node. (void)PopulateNumKeys(root_); @@ -502,8 +559,8 @@ void BPlusTree::PopulateNumKeys() { stats_.num_keys_array_valid_ = true; } -template -uint64_t BPlusTree::PopulateNumKeys(BPlusTree::BaseNode *n) { +template +uint64_t BPlusTree::PopulateNumKeys(BPlusTree::BaseNode *n) { if (n->is_leafnode()) { auto *leaf = static_cast(n); return leaf->slotuse_; @@ -518,8 +575,8 @@ uint64_t BPlusTree::PopulateNumKeys(BPlusTree::BaseNode } } -template -typename BPlusTree::key_type BPlusTree::KeyAtPos(uint64_t inx) { +template +typename BPlusTree::key_type BPlusTree::KeyAtPos(uint64_t inx) { if (stats_.num_keys_array_valid_ == false) { // We need exclusive access to the tree. If concurrent insert is going on, it is hard to get accurate numbers UniqueLock lck(&rw_lock_); @@ -532,8 +589,9 @@ typename BPlusTree::key_type BPlusTree::KeyAtPos(uint64_ return KeyAtPos(root_, inx); } -template -typename BPlusTree::key_type BPlusTree::KeyAtPos(BPlusTree::BaseNode *n, uint64_t inx) { +template +typename BPlusTree::key_type BPlusTree::KeyAtPos(BPlusTree::BaseNode *n, + uint64_t inx) { if (n->is_leafnode()) { auto *leaf = static_cast(n); return leaf->keys_[leaf->slot_dir_[inx]]; @@ -546,7 +604,7 @@ typename BPlusTree::key_type BPlusTree::KeyAtPos(BPlusTr } for (auto i = 0; i < inner->slotuse_; i++) { if ((inx + 1) > inner->num_keys_[inner->slot_dir_[i] + 1]) { - inx -= inner->num_keys_[inner->slot_dir_[i]+1]; + inx -= inner->num_keys_[inner->slot_dir_[i] + 1]; } else { return KeyAtPos(inner->data_[inner->slot_dir_[i] + 1], inx); } diff --git a/mindspore/ccsrc/dataset/util/btree_iterator.tpp b/mindspore/ccsrc/dataset/util/btree_iterator.tpp index 5677b581bb98cbeb33d80e9dbf5c90c5b6def1d0..b7d6b09fda2a7fc4f87db3198138754e9bedbdc3 100644 --- a/mindspore/ccsrc/dataset/util/btree_iterator.tpp +++ b/mindspore/ccsrc/dataset/util/btree_iterator.tpp @@ -21,11 +21,23 @@ namespace mindspore { namespace dataset { -template -typename BPlusTree::Iterator &BPlusTree::Iterator::operator++() { +template +BPlusTree::Iterator::~Iterator() { + if (locked_) { + cur_->rw_lock_.Unlock(); + locked_ = false; + } +} + +template +typename BPlusTree::Iterator &BPlusTree::Iterator::operator++() { if (slot_ + 1u < cur_->slotuse_) { ++slot_; } else if (cur_->link_.next) { + if (locked_) { + cur_->link_.next->rw_lock_.LockShared(); + cur_->rw_lock_.Unlock(); + } cur_ = cur_->link_.next; slot_ = 0; } else { @@ -34,12 +46,16 @@ typename BPlusTree::Iterator &BPlusTree::Iterator::opera return *this; } -template -typename BPlusTree::Iterator BPlusTree::Iterator::operator++(int) { +template +typename BPlusTree::Iterator BPlusTree::Iterator::operator++(int) { Iterator tmp = *this; if (slot_ + 1u < cur_->slotuse_) { ++slot_; } else if (cur_->link_.next) { + if (locked_) { + cur_->link_.next->rw_lock_.LockShared(); + cur_->rw_lock_.Unlock(); + } cur_ = cur_->link_.next; slot_ = 0; } else { @@ -48,11 +64,15 @@ typename BPlusTree::Iterator BPlusTree::Iterator::operat return tmp; } -template -typename BPlusTree::Iterator &BPlusTree::Iterator::operator--() { +template +typename BPlusTree::Iterator &BPlusTree::Iterator::operator--() { if (slot_ > 0) { --slot_; } else if (cur_->link_.prev) { + if (locked_) { + cur_->link_.prev->rw_lock_.LockShared(); + cur_->rw_lock_.Unlock(); + } cur_ = cur_->link_.prev; slot_ = cur_->slotuse_ - 1; } else { @@ -61,12 +81,16 @@ typename BPlusTree::Iterator &BPlusTree::Iterator::opera return *this; } -template -typename BPlusTree::Iterator BPlusTree::Iterator::operator--(int) { +template +typename BPlusTree::Iterator BPlusTree::Iterator::operator--(int) { Iterator tmp = *this; if (slot_ > 0) { --slot_; } else if (cur_->link_.prev) { + if (locked_) { + cur_->link_.prev->rw_lock_.LockShared(); + cur_->rw_lock_.Unlock(); + } cur_ = cur_->link_.prev; slot_ = cur_->slotuse_ - 1; } else { @@ -75,11 +99,77 @@ typename BPlusTree::Iterator BPlusTree::Iterator::operat return tmp; } -template -typename BPlusTree::ConstIterator &BPlusTree::ConstIterator::operator++() { +template +BPlusTree::Iterator::Iterator(const BPlusTree::Iterator &lhs) { + this->cur_ = lhs.cur_; + this->slot_ = lhs.slot_; + this->locked_ = lhs.locked_; + if (this->locked_) { + this->cur_->rw_lock_.LockShared(); + } +} + +template +BPlusTree::Iterator::Iterator(BPlusTree::Iterator &&lhs) { + this->cur_ = lhs.cur_; + this->slot_ = lhs.slot_; + this->locked_ = lhs.locked_; + lhs.locked_ = false; + lhs.slot_ = 0; + lhs.cur_ = nullptr; +} + +template +typename BPlusTree::Iterator &BPlusTree::Iterator::operator=( + const BPlusTree::Iterator &lhs) { + if (*this != lhs) { + if (this->locked_) { + this->cur_->rw_lock_.Unlock(); + } + this->cur_ = lhs.cur_; + this->slot_ = lhs.slot_; + this->locked_ = lhs.locked_; + if (this->locked_) { + this->cur_->rw_lock_.LockShared(); + } + } + return *this; +} + +template +typename BPlusTree::Iterator &BPlusTree::Iterator::operator=( + BPlusTree::Iterator &&lhs) { + if (*this != lhs) { + if (this->locked_) { + this->cur_->rw_lock_.Unlock(); + } + this->cur_ = lhs.cur_; + this->slot_ = lhs.slot_; + this->locked_ = lhs.locked_; + lhs.locked_ = false; + lhs.slot_ = 0; + lhs.cur_ = nullptr; + } + return *this; +} + +template +BPlusTree::ConstIterator::~ConstIterator() { + if (locked_) { + cur_->rw_lock_.Unlock(); + locked_ = false; + } +} + +template +typename BPlusTree::ConstIterator &BPlusTree::ConstIterator::operator++() { if (slot_ + 1u < cur_->slotuse_) { ++slot_; } else if (cur_->link_.next) { + if (locked_) { + cur_->link_.next->rw_lock_.LockShared(); + cur_->rw_lock_.Unlock(); + } cur_ = cur_->link_.next; slot_ = 0; } else { @@ -88,12 +178,16 @@ typename BPlusTree::ConstIterator &BPlusTree::ConstItera return *this; } -template -typename BPlusTree::ConstIterator BPlusTree::ConstIterator::operator++(int) { +template +typename BPlusTree::ConstIterator BPlusTree::ConstIterator::operator++(int) { Iterator tmp = *this; if (slot_ + 1u < cur_->slotuse_) { ++slot_; } else if (cur_->link_.next) { + if (locked_) { + cur_->link_.next->rw_lock_.LockShared(); + cur_->rw_lock_.Unlock(); + } cur_ = cur_->link_.next; slot_ = 0; } else { @@ -102,11 +196,15 @@ typename BPlusTree::ConstIterator BPlusTree::ConstIterat return tmp; } -template -typename BPlusTree::ConstIterator &BPlusTree::ConstIterator::operator--() { +template +typename BPlusTree::ConstIterator &BPlusTree::ConstIterator::operator--() { if (slot_ > 0) { --slot_; } else if (cur_->link_.prev) { + if (locked_) { + cur_->link_.prev->rw_lock_.LockShared(); + cur_->rw_lock_.Unlock(); + } cur_ = cur_->link_.prev; slot_ = cur_->slotuse_ - 1; } else { @@ -115,12 +213,16 @@ typename BPlusTree::ConstIterator &BPlusTree::ConstItera return *this; } -template -typename BPlusTree::ConstIterator BPlusTree::ConstIterator::operator--(int) { +template +typename BPlusTree::ConstIterator BPlusTree::ConstIterator::operator--(int) { Iterator tmp = *this; if (slot_ > 0) { --slot_; } else if (cur_->link_.prev) { + if (locked_) { + cur_->link_.prev->rw_lock_.LockShared(); + cur_->rw_lock_.Unlock(); + } cur_ = cur_->link_.prev; slot_ = cur_->slotuse_ - 1; } else { @@ -129,14 +231,95 @@ typename BPlusTree::ConstIterator BPlusTree::ConstIterat return tmp; } -template -typename BPlusTree::ConstIterator BPlusTree::Search(const key_type &key) const { +template +BPlusTree::ConstIterator::ConstIterator(const BPlusTree::ConstIterator &lhs) { + this->cur_ = lhs.cur_; + this->slot_ = lhs.slot_; + this->locked_ = lhs.locked_; + if (this->locked_) { + this->cur_->rw_lock_.LockShared(); + } +} + +template +BPlusTree::ConstIterator::ConstIterator(BPlusTree::ConstIterator &&lhs) { + this->cur_ = lhs.cur_; + this->slot_ = lhs.slot_; + this->locked_ = lhs.locked_; + lhs.locked_ = false; + lhs.slot_ = 0; + lhs.cur_ = nullptr; +} + +template +typename BPlusTree::ConstIterator &BPlusTree::ConstIterator::operator=( + const BPlusTree::ConstIterator &lhs) { + if (*this != lhs) { + if (this->locked_) { + this->cur_->rw_lock_.Unlock(); + } + this->cur_ = lhs.cur_; + this->slot_ = lhs.slot_; + this->locked_ = lhs.locked_; + if (this->locked_) { + this->cur_->rw_lock_.LockShared(); + } + } + return *this; +} + +template +typename BPlusTree::ConstIterator &BPlusTree::ConstIterator::operator=( + BPlusTree::ConstIterator &&lhs) { + if (*this != lhs) { + if (this->locked_) { + this->cur_->rw_lock_.Unlock(); + } + this->cur_ = lhs.cur_; + this->slot_ = lhs.slot_; + this->locked_ = lhs.locked_; + lhs.locked_ = false; + lhs.slot_ = 0; + lhs.cur_ = nullptr; + } + return *this; +} + +template +typename BPlusTree::ConstIterator BPlusTree::Search(const key_type &key) const { + if (root_ != nullptr) { + LeafNode *leaf = nullptr; + slot_type slot; + RWLock *myLock = &this->rw_lock_; + // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. + myLock->LockShared(); + IndexRc rc = Locate(myLock, false, root_, key, &leaf, &slot); + if (rc == IndexRc::kOk) { + // All locks from the tree to the parent of leaf are all gone. We still have a S lock + // on the leaf. The unlock will be handled by the iterator when it goes out of scope. + return ConstIterator(leaf, slot, true); + } else { + MS_LOG(INFO) << "Key not found. rc = " << static_cast(rc) << "."; + return cend(); + } + } else { + return cend(); + } +} + +template +typename BPlusTree::Iterator BPlusTree::Search(const key_type &key) { if (root_ != nullptr) { LeafNode *leaf = nullptr; slot_type slot; - IndexRc rc = Locate(root_, key, &leaf, &slot); + RWLock *myLock = &this->rw_lock_; + // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. + myLock->LockShared(); + IndexRc rc = Locate(myLock, false, root_, key, &leaf, &slot); if (rc == IndexRc::kOk) { - return ConstIterator(leaf, slot); + // All locks from the tree to the parent of leaf are all gone. We still have a S lock + // on the leaf. The unlock will be handled by the iterator when it goes out of scope. + return Iterator(leaf, slot, true); } else { MS_LOG(INFO) << "Key not found. rc = " << static_cast(rc) << "."; return end(); @@ -146,39 +329,39 @@ typename BPlusTree::ConstIterator BPlusTree::Search(cons } } -template -typename BPlusTree::value_type BPlusTree::operator[](key_type key) { - ConstIterator it = Search(key); +template +typename BPlusTree::value_type BPlusTree::operator[](key_type key) { + Iterator it = Search(key); return it.value(); } -template -typename BPlusTree::Iterator BPlusTree::begin() { +template +typename BPlusTree::Iterator BPlusTree::begin() { return Iterator(this); } -template -typename BPlusTree::Iterator BPlusTree::end() { +template +typename BPlusTree::Iterator BPlusTree::end() { return Iterator(this->leaf_nodes_.tail, this->leaf_nodes_.tail ? this->leaf_nodes_.tail->slotuse_ : 0); } -template -typename BPlusTree::ConstIterator BPlusTree::begin() const { +template +typename BPlusTree::ConstIterator BPlusTree::begin() const { return ConstIterator(this); } -template -typename BPlusTree::ConstIterator BPlusTree::end() const { +template +typename BPlusTree::ConstIterator BPlusTree::end() const { return ConstIterator(this->leaf_nodes_.tail, this->leaf_nodes_.tail ? this->leaf_nodes_.tail->slotuse_ : 0); } -template -typename BPlusTree::ConstIterator BPlusTree::cbegin() const { +template +typename BPlusTree::ConstIterator BPlusTree::cbegin() const { return ConstIterator(this); } -template -typename BPlusTree::ConstIterator BPlusTree::cend() const { +template +typename BPlusTree::ConstIterator BPlusTree::cend() const { return ConstIterator(this->leaf_nodes_.tail, this->leaf_nodes_.tail ? this->leaf_nodes_.tail->slotuse_ : 0); } } // namespace dataset diff --git a/tests/ut/cpp/dataset/btree_test.cc b/tests/ut/cpp/dataset/btree_test.cc index 3e0a867fba8a1b1d6906b15caa9f64f5714baf95..2e40f4a6618cbb9f477b62f5e7047b8f53ac3243 100644 --- a/tests/ut/cpp/dataset/btree_test.cc +++ b/tests/ut/cpp/dataset/btree_test.cc @@ -50,7 +50,7 @@ class MindDataTestBPlusTree : public UT::Common { // Test serial insert. TEST_F(MindDataTestBPlusTree, Test1) { Allocator alloc(std::make_shared()); - BPlusTree, mytraits> btree(alloc); + BPlusTree, std::less, mytraits> btree(alloc); Status rc; for (int i = 0; i < 100; i++) { uint64_t key = 2 * i; @@ -92,16 +92,16 @@ TEST_F(MindDataTestBPlusTree, Test1) { } } - // Test nearch + // Test search { MS_LOG(INFO) << "Locate key " << 100 << " Expect found."; auto it = btree.Search(100); - EXPECT_FALSE(it == btree.cend()); + EXPECT_FALSE(it == btree.end()); EXPECT_EQ(it.key(), 100); EXPECT_EQ(it.value(), "Hello World. I am 100"); MS_LOG(INFO) << "Locate key " << 300 << " Expect not found."; it = btree.Search(300); - EXPECT_TRUE(it == btree.cend()); + EXPECT_TRUE(it == btree.end()); } // Test duplicate key @@ -114,7 +114,7 @@ TEST_F(MindDataTestBPlusTree, Test1) { // Test concurrent insert. TEST_F(MindDataTestBPlusTree, Test2) { Allocator alloc(std::make_shared()); - BPlusTree, mytraits> btree(alloc); + BPlusTree, std::less, mytraits> btree(alloc); TaskGroup vg; auto f = [&](int k) -> Status { TaskManager::FindMe()->Post(); @@ -127,10 +127,22 @@ TEST_F(MindDataTestBPlusTree, Test2) { } return Status::OK(); }; - // Spawn two threads. One insert the odd numbers and the other insert the even numbers just like Test1 + auto g = [&](int k) -> Status { + TaskManager::FindMe()->Post(); + for (int i = 0; i < 1000; i++) { + uint64_t key = rand() % 10000;; + auto it = btree.Search(key); + } + return Status::OK(); + }; + // Spawn multiple threads to do insert. for (int k = 0; k < 100; k++) { vg.CreateAsyncTask("Concurrent Insert", std::bind(f, k)); } + // Spawn a few threads to do random search. + for (int k = 0; k < 2; k++) { + vg.CreateAsyncTask("Concurrent search", std::bind(g, k)); + } vg.join_all(); EXPECT_EQ(btree.size(), 10000); @@ -158,7 +170,7 @@ TEST_F(MindDataTestBPlusTree, Test2) { MS_LOG(INFO) << "Locating key from 0 to 9999. Expect found."; for (int i = 0; i < 10000; i++) { auto it = btree.Search(i); - bool eoS = (it == btree.cend()); + bool eoS = (it == btree.end()); EXPECT_FALSE(eoS); if (!eoS) { EXPECT_EQ(it.key(), i); @@ -168,7 +180,7 @@ TEST_F(MindDataTestBPlusTree, Test2) { } MS_LOG(INFO) << "Locate key " << 10000 << ". Expect not found"; auto it = btree.Search(10000); - EXPECT_TRUE(it == btree.cend()); + EXPECT_TRUE(it == btree.end()); } // Test to retrieve key at certain position. @@ -182,11 +194,11 @@ TEST_F(MindDataTestBPlusTree, Test2) { TEST_F(MindDataTestBPlusTree, Test3) { Allocator alloc(std::make_shared()); - AutoIndexObj ai(alloc); + AutoIndexObj> ai(alloc); Status rc; rc = ai.insert("Hello World"); EXPECT_TRUE(rc.IsOk()); - ai.insert({"a", "b", "c"}); + rc = ai.insert({"a", "b", "c"}); EXPECT_TRUE(rc.IsOk()); uint64_t min = ai.min_key(); uint64_t max = ai.max_key(); @@ -199,3 +211,30 @@ TEST_F(MindDataTestBPlusTree, Test3) { MS_LOG(DEBUG) << ai[i] << std::endl; } } + +TEST_F(MindDataTestBPlusTree, Test4) { + Allocator alloc(std::make_shared()); + AutoIndexObj> ai(alloc); + Status rc; + for (int i = 0; i < 1000; i++) { + rc = ai.insert(std::make_unique(i)); + EXPECT_TRUE(rc.IsOk()); + } + // Test iterator + { + int cnt = 0; + auto it = ai.begin(); + uint64_t prev = it.key(); + ++it; + ++cnt; + while (it != ai.end()) { + uint64_t cur = it.key(); + EXPECT_TRUE(prev < cur); + EXPECT_EQ(it.value(), cnt); + prev = cur; + ++it; + ++cnt; + } + EXPECT_EQ(cnt, 1000); + } +}