diff --git a/akg b/akg index c460176523d039c8995f1d71089753725ebc0792..df57a6cf9450e347d1854687d1fe66a420ee3b35 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit c460176523d039c8995f1d71089753725ebc0792 +Subproject commit df57a6cf9450e347d1854687d1fe66a420ee3b35 diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 8109e608c5c1200f99b3f7a337287d583c662aeb..cc5845cbf15bcad0d79acbffff7f9ccd3fa9557d 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -277,10 +277,11 @@ endif () if (USE_GLOG) target_link_libraries(inference PRIVATE mindspore::glog) -else() - if (CMAKE_SYSTEM_NAME MATCHES "Linux") - target_link_options(inference PRIVATE -Wl,-init,mindspore_log_init) - elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin") - set_target_properties(inference PROPERTIES MACOSX_RPATH ON) - endif () endif() + +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + target_link_options(inference PRIVATE -Wl,-init,common_log_init) +elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin") + set_target_properties(inference PROPERTIES MACOSX_RPATH ON) +endif () + diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc index b820779ed1a70f9c21309d8468fd331232b502ca..fbaf2c9326d821e968967e35fb4da2f60565a7aa 100644 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc @@ -30,8 +30,7 @@ Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow BOUNDING_BOX_CHECK(input); CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of input is abnormal"); - (*output).push_back(nullptr); // init memory for return vector - (*output).push_back(nullptr); + output->resize(2); (*output)[1] = std::move(input[1]); // move boxes over to output size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox 
tensor diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc index 2be37f1da365c6740c6848bd274342f489112699..c873307afdd5ce08a323094c78ec0dde7e94b4f3 100644 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc @@ -36,8 +36,7 @@ Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) int32_t padded_image_h; int32_t padded_image_w; - (*output).push_back(nullptr); - (*output).push_back(nullptr); + output->resize(2); (*output)[1] = std::move(input[1]); // since some boxes may be removed bool crop_further = true; // Whether further cropping will be required or not, true unless required size matches diff --git a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc index c6aa8450a8dd92dc993f0df14333b4a029841b45..ffea851eac1f0856e3fad13296be24682e1958d4 100644 --- a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc @@ -45,8 +45,7 @@ Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow * RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y)); } - (*output).push_back(nullptr); - (*output).push_back(nullptr); + output->resize(2); (*output)[1] = std::move(input[1]); return VerticalFlip(input[0], &(*output)[0]); diff --git a/mindspore/ccsrc/dataset/util/CMakeLists.txt b/mindspore/ccsrc/dataset/util/CMakeLists.txt index b36d612435aba228500cda25ea239daaf5eb424a..96489add071f2d8072eadfc4eac26a67f5c3ed2a 100644 --- a/mindspore/ccsrc/dataset/util/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/util/CMakeLists.txt @@ -2,6 +2,8 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" set_property(SOURCE 
${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) add_library(utils OBJECT arena.cc + buddy.cc + cache_pool.cc circular_pool.cc memory_pool.cc cond_var.cc @@ -11,7 +13,11 @@ add_library(utils OBJECT service.cc services.cc lock.cc + semaphore.cc status.cc + storage_container.cc + storage_manager.cc + slice.cc path.cc wait_post.cc sig_handler.cc) diff --git a/mindspore/ccsrc/dataset/util/allocator.h b/mindspore/ccsrc/dataset/util/allocator.h index ba6c7786df50ca3a19fd8694735a58c63406df0c..50a9cadbe3fb58cfe5a7e9f10e3c6ba104bbcdb5 100644 --- a/mindspore/ccsrc/dataset/util/allocator.h +++ b/mindspore/ccsrc/dataset/util/allocator.h @@ -17,8 +17,10 @@ #define DATASET_UTIL_ALLOCATOR_H_ #include +#include #include #include +#include #include "dataset/util/memory_pool.h" namespace mindspore { @@ -84,6 +86,91 @@ class Allocator { private: std::shared_ptr pool_; }; +/// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will +/// be released when the object goes out of scope \tparam T The type of object to be allocated \tparam C Allocator. +/// Default to std::allocator +template > +class MemGuard { + public: + using allocator = C; + MemGuard() : n_(0) {} + explicit MemGuard(allocator a) : n_(0), alloc_(a) {} + // There is no copy constructor nor assignment operator because the memory is solely owned by this object. 
+ MemGuard(const MemGuard &) = delete; + MemGuard &operator=(const MemGuard &) = delete; + // On the other hand, We can support move constructor + MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {} + MemGuard &operator=(MemGuard &&lhs) noexcept { + if (this != &lhs) { + this->deallocate(); + n_ = lhs.n_; + alloc_ = std::move(lhs.alloc_); + ptr_ = std::move(lhs.ptr_); + } + return *this; + } + /// \brief Explicitly deallocate the memory if allocated + void deallocate() { + if (ptr_) { + auto *p = ptr_.release(); + if (!std::is_arithmetic::value && std::is_destructible::value) { + for (auto i = 0; i < n_; ++i) { + p[i].~T(); + } + } + alloc_.deallocate(p, n_); + n_ = 0; + } + } + /// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is + /// allocated. + /// \param n Number of objects of type T to be allocated + /// \tparam Args Extra arguments pass to the constructor of T + template + Status allocate(size_t n, Args &&... args) noexcept { + try { + deallocate(); + if (n > 0) { + T *data = alloc_.allocate(n); + if (!std::is_arithmetic::value) { + for (auto i = 0; i < n; i++) { + std::allocator_traits::construct(alloc_, &(data[i]), std::forward(args)...); + } + } + ptr_ = std::unique_ptr(data); + n_ = n; + } + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory); + } catch (std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } + return Status::OK(); + } + ~MemGuard() noexcept { deallocate(); } + /// \brief Getter function + /// \return The pointer to the memory allocated + T *GetPointer() const { return ptr_.get(); } + /// \brief Getter function + /// \return The pointer to the memory allocated + T *GetMutablePointer() { return ptr_.get(); } + /// \brief Overload [] operator to access a particular element + /// \param x index to the element. Must be less than number of element allocated. 
+ /// \return pointer to the x-th element + T *operator[](size_t x) { return GetMutablePointer() + x; } + /// \brief Overload [] operator to access a particular element + /// \param x index to the element. Must be less than number of element allocated. + /// \return pointer to the x-th element + T *operator[](size_t x) const { return GetPointer() + x; } + /// \brief Return how many bytes are allocated in total + /// \return Number of bytes allocated in total + size_t GetSizeInBytes() const { return n_ * sizeof(T); } + + private: + allocator alloc_; + std::unique_ptr> ptr_; + size_t n_; +}; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/auto_index.h b/mindspore/ccsrc/dataset/util/auto_index.h index 11a2e90b00d06ac15911e23f23860df22d460ee5..5c43ecfd80b776b2c5da52d08b75129918f297fd 100644 --- a/mindspore/ccsrc/dataset/util/auto_index.h +++ b/mindspore/ccsrc/dataset/util/auto_index.h @@ -91,7 +91,7 @@ class AutoIndexObj : public BPlusTree { } private: - static constexpr key_type kMinKey = 1; + static constexpr key_type kMinKey = 0; std::atomic inx_; }; } // namespace dataset diff --git a/mindspore/ccsrc/dataset/util/buddy.cc b/mindspore/ccsrc/dataset/util/buddy.cc new file mode 100644 index 0000000000000000000000000000000000000000..3a14258419a7225a52a61a4f60e2027126c693f9 --- /dev/null +++ b/mindspore/ccsrc/dataset/util/buddy.cc @@ -0,0 +1,388 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/util/buddy.h" +#include +#include +#include "dataset/util/de_error.h" +#include "dataset/util/memory_pool.h" +#include "dataset/util/system_pool.h" +#include "./securec.h" + +inline uint64_t BitLeftShift(uint64_t v, uint64_t n) { return (v << n); } + +inline uint64_t BitRightShift(uint64_t v, uint64_t n) { return (v >> n); } + +inline uint64_t BitOr(uint64_t rhs, uint64_t lhs) { return rhs | lhs; } + +inline uint64_t BitEx(uint64_t rhs, uint64_t lhs) { return rhs ^ lhs; } + +inline uint64_t BitAnd(uint64_t rhs, uint64_t lhs) { return rhs & lhs; } + +namespace mindspore { +namespace dataset { +Status BuddySpace::Init() { + if (log_min_ < 0) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "log_min must be positive : " + std::to_string(log_min_)); + } + if (num_lvl_ < 3 || num_lvl_ > 18) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "num_lvl must be between 3 and 18 : " + std::to_string(num_lvl_)); + } + min_ = BitLeftShift(1, log_min_); + max_ = BitLeftShift(1, log_min_ + num_lvl_ - 1); + size_t offset_1 = sizeof(rel_addr_t) * num_lvl_; + size_t offset_2 = sizeof(int) * num_lvl_ + offset_1; + size_t offset_3 = sizeof(char) * BitLeftShift(1, num_lvl_ - 3) + offset_2; + RETURN_IF_NOT_OK(DeMalloc(offset_3, &ptr_, true)); + hint_ = reinterpret_cast(ptr_); + count_ = reinterpret_cast((reinterpret_cast(ptr_) + offset_1)); + map_ = reinterpret_cast(ptr_) + offset_2; + count_[num_lvl_ - 1] = 1; + map_[0] = BitOr(MORE_BIT, num_lvl_ - 3); + return Status::OK(); +} + +Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) noexcept { + std::lock_guard lock(mutex_); + addr_t addr = AllocNoLock(sz, desc); + if (addr != NOSPACE) { + *p = addr; + return Status::OK(); + } else { + return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error. 
Please ignore."); + } +} + +addr_t BuddySpace::AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept { + DS_ASSERT(sz <= max_); + uint32_t reqSize = SizeToBlock(sz); + rel_addr_t rel_addr = AllocBuddySeg(reqSize); + if (rel_addr != static_cast(NOSPACE)) { + (void)memset_s(desc, sizeof(BSpaceDescriptor), 0, sizeof(BSpaceDescriptor)); + desc->sig = static_cast(0xDEADBEEF); + desc->addr = rel_addr; + desc->req_size = reqSize; + desc->blk_size = NextPowerOf2(reqSize); + return static_cast(rel_addr * min_); + } else { + return NOSPACE; + } +} + +void BuddySpace::FreeNoLock(const BSpaceDescriptor *desc) { + DS_ASSERT(desc->sig == 0XDEADBEEF); + rel_addr_t rel_addr = desc->addr; + size_t blk_size = desc->blk_size; + size_t req_size = desc->req_size; + FreeBuddySeg(rel_addr, blk_size, req_size); +} + +void BuddySpace::Free(const BSpaceDescriptor *desc) { + std::lock_guard lock(mutex_); + return FreeNoLock(desc); +} + +std::ostream &operator<<(std::ostream &os, const BuddySpace &s) { + os << "1 unit = " << s.GetMinSize() << "\n" + << "Size of buddy space = " << s.GetMaxSize() << "\n" + << "Number of levels = " << s.num_lvl_ << "\n\n" + << "Percent free = " << s.PercentFree() << "\n" + << "Dumping count array : " + << "\n"; + for (int i = 0; i < s.num_lvl_; i++) { + os << "[" << i << "] = " << s.count_[i] << " "; + if (((i + 1) % 4) == 0) { + os << "\n"; + } + } + os << "\n"; + os << "Dumping allocation info:" + << "\n"; + auto max_addr = static_cast(BitLeftShift(1, s.num_lvl_ - 1)); + rel_addr_t addr = 0; + while (addr < max_addr) { + size_t sz = 0; + BuddySpace::STATE st; + s.GetBuddySegState(addr, &sz, &st); + os << "Address : " << std::left << std::setw(8) << addr << " Size : " << std::setw(8) << sz << " State : " + << ((st == BuddySpace::STATE::kAlloc) ? "ALLOC" : ((st == BuddySpace::STATE::kFree) ? 
"FREE" : "Unkonwn")) + << "\n"; + addr += sz; + } + return os; +} + +void BuddySpace::GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const { + char byte; + int pos; + int offset; + uint64_t val = 0; + int shift; + pos = BitRightShift(rel_addr, 2); + offset = rel_addr % 4; + shift = offset * 2; + byte = map_[pos]; + switch (offset) { + case 0: + val = byte; + break; + case 1: + case 3: + if (offset == 1) { + val = BitLeftShift(BitAnd(byte, 0x30), shift); + } else { + val = BitLeftShift(BitAnd(byte, 0x03), shift); + } + break; + case 2: + val = BitLeftShift(BitAnd(byte, 0x0F), shift); + break; + } + if (BitAnd(val, ONE_BIT)) { + *rel_sz = 1; + } else if (BitAnd(val, TWO_BIT)) { + *rel_sz = 2; + } else if (BitAnd(val, MORE_BIT)) { + log_t lg = BitAnd(val, 0x0F); + *rel_sz = BitLeftShift(1, lg + 2); + } else { + *st = STATE::kEmpty; + return; + } + *st = BitAnd(val, ALLOC_BIT) ? STATE::kAlloc : STATE::kFree; +} + +void BuddySpace::SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st) { + int clr; + int mask; + int pos; + int offset; + int val = 0; + int shift; + auto log_sz = static_cast(Log2(rel_sz)); + pos = BitRightShift(rel_addr, 2); + offset = rel_addr % 4; + shift = offset * 2; + if (rel_sz == 1) { + val = ONE_BIT; + mask = 0xC0; + } else if (rel_sz == 2) { + val = TWO_BIT; + mask = 0xF0; + } else { + val = BitOr(log_sz - 2, MORE_BIT); + mask = 0xFF; + } + if (st == STATE::kAlloc) { + val = BitOr(val, ALLOC_BIT); + } else if (st == STATE::kFree) { + val = BitAnd(val, ~(static_cast(ALLOC_BIT))); + } else if (st == STATE::kEmpty) { + val = 0; + } + clr = static_cast(~(BitRightShift(mask, shift))); + map_[pos] = static_cast(BitAnd(map_[pos], clr)); + map_[pos] = static_cast(BitOr(map_[pos], BitRightShift(val, shift))); + if (st == STATE::kAlloc) { + count_[log_sz]--; + } else if (st == STATE::kFree) { + count_[log_sz]++; + if (rel_addr < hint_[log_sz]) { + hint_[log_sz] = rel_addr; + } + } +} + +void 
BuddySpace::JoinBuddySeg(rel_addr_t addr, size_t blk_sz) { + while (blk_sz < BitLeftShift(1, num_lvl_)) { + rel_addr_t buddy = BitEx(addr, blk_sz); + size_t sz = 0; + STATE st; + GetBuddySegState(buddy, &sz, &st); + if (st == STATE::kFree && sz == blk_sz) { + auto log_sz = static_cast(Log2(blk_sz)); + rel_addr_t left = (buddy < addr) ? buddy : addr; + rel_addr_t right = left + blk_sz; + DS_ASSERT(count_[log_sz] >= 2); + count_[log_sz] -= 2; + SetBuddySegState(right, blk_sz, STATE::kEmpty); + SetBuddySegState(left, BitLeftShift(blk_sz, 1), STATE::kFree); + for (int i = 0; i < log_sz; i++) { + if (hint_[i] == right) { + hint_[i] = left; + } + } + addr = left; + blk_sz <<= 1u; + } else { + break; + } + } +} + +void BuddySpace::TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) { + DS_ASSERT(ask_sz < blk_sz); + uint32_t inx = Log2(blk_sz); + size_t remaining_sz = ask_sz; + for (int i = inx; i > 0; i--) { + size_t b_size = BitLeftShift(1, i); + size_t half_sz = BitRightShift(b_size, 1); + count_[i]--; + SetBuddySegState(addr, half_sz, STATE::kFree); + SetBuddySegState(addr + half_sz, half_sz, STATE::kFree); + if (remaining_sz >= half_sz) { + SetBuddySegState(addr, half_sz, STATE::kAlloc); + remaining_sz -= half_sz; + if (remaining_sz == 0) { + break; + } + addr += half_sz; + } + } +} + +void BuddySpace::UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) { + DS_ASSERT(ask_sz < blk_sz); + uint32_t inx = Log2(blk_sz); + size_t remaining_sz = ask_sz; + for (int i = inx; i > 0; i--) { + size_t b_size = BitLeftShift(1, i); + size_t half_sz = BitRightShift(b_size, 1); + if (remaining_sz >= half_sz) { +#ifdef DEBUG + { + size_t sz = 0; + STATE st; + GetBuddySegState(addr, &sz, &st); + DS_ASSERT(sz == half_sz && st == STATE::kAlloc); + } +#endif + SetBuddySegState(addr, half_sz, STATE::kFree); + remaining_sz -= half_sz; + if (remaining_sz == 0) { + JoinBuddySeg(addr, half_sz); + break; + } + addr += half_sz; + } + } +} + +rel_addr_t 
BuddySpace::AllocBuddySeg(uint32_t req_size) noexcept { + uint32_t blk_size = NextPowerOf2(req_size); + int start_inx = static_cast(Log2(blk_size)); + bool found = false; + rel_addr_t ask_addr = 0; + auto max_addr = static_cast(BitLeftShift(1, num_lvl_ - 1)); + STATE st; + size_t sz = 0; + for (int i = start_inx; !found && i < num_lvl_; i++) { + DS_ASSERT(count_[i] >= 0); + if (count_[i] == 0) { + continue; + } + auto blk_sz = static_cast(BitLeftShift(1, i)); + ask_addr = hint_[i]; + while (ask_addr < max_addr && !found) { + GetBuddySegState(ask_addr, &sz, &st); + if (st == STATE::kFree && sz == blk_sz) { + found = true; + } else { + DS_ASSERT(st != STATE::kEmpty); + ask_addr += ((sz > blk_sz) ? sz : blk_sz); + } + } + } + if (found) { + if (sz > req_size) { + TrimBuddySeg(ask_addr, sz, req_size); + } else { + SetBuddySegState(ask_addr, sz, STATE::kAlloc); + hint_[start_inx] = ask_addr; + } + return ask_addr; + } else { + return static_cast(NOSPACE); + } +} + +void BuddySpace::FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size) { + if (req_size == blk_size) { +#ifdef DEBUG + { + size_t sz = 0; + STATE st; + GetBuddySegState(addr, &sz, &st); + } +#endif + SetBuddySegState(addr, blk_size, STATE::kFree); + JoinBuddySeg(addr, blk_size); + } else { + UnTrimBuddySeg(addr, blk_size, req_size); + } +} + +int BuddySpace::PercentFree() const { + uint64_t total_free_sz = 0; + uint64_t max_sz_in_unit = BitLeftShift(1, num_lvl_ - 1); + // Go through the count array without lock + for (int i = 0; i < num_lvl_; i++) { + int cnt = count_[i]; + if (cnt == 0) { + continue; + } + uint64_t blk_sz = BitLeftShift(1, i); + total_free_sz += (blk_sz * cnt); + } + return static_cast(static_cast(total_free_sz) / static_cast(max_sz_in_unit) * 100); +} + +BuddySpace::BuddySpace(int log_min, int num_lvl) + : hint_(nullptr), + count_(nullptr), + map_(nullptr), + log_min_(log_min), + num_lvl_(num_lvl), + min_(0), + max_(0), + ptr_(nullptr) {} + +BuddySpace::~BuddySpace() { + if (ptr_ 
!= nullptr) { + free(ptr_); + } + hint_ = nullptr; + count_ = nullptr; + map_ = nullptr; +} + +Status BuddySpace::CreateBuddySpace(std::unique_ptr *out_bs, int log_min, int num_lvl) { + Status rc; + auto bs = new (std::nothrow) BuddySpace(log_min, num_lvl); + if (bs == nullptr) { + return Status(StatusCode::kOutOfMemory); + } + rc = bs->Init(); + if (rc.IsOk()) { + (*out_bs).reset(bs); + } else { + delete bs; + } + return rc; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/buddy.h b/mindspore/ccsrc/dataset/util/buddy.h new file mode 100644 index 0000000000000000000000000000000000000000..08c05cbbdbe3808a944c95e9297ce9a3f78d185c --- /dev/null +++ b/mindspore/ccsrc/dataset/util/buddy.h @@ -0,0 +1,133 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_BUDDY_H_ +#define DATASET_UTIL_BUDDY_H_ + +#include +#include +#include +#include +#include +#include +#include "dataset/util/status.h" + +using addr_t = int64_t; +using rel_addr_t = int32_t; +using log_t = int; +#define ALLOC_BIT 0x80 +#define ONE_BIT 0x40 +#define TWO_BIT 0x20 +#define MORE_BIT 0x10 +#define NOSPACE ((addr_t)(-1)) +namespace mindspore { +namespace dataset { +struct BSpaceDescriptor { + int32_t sig; + rel_addr_t addr; + size_t req_size; + size_t blk_size; +}; + +class BuddySpace { + public: + // C++11 feature. Change STATE into a type safe class with + // the keyword. 
Don't take out the keyword 'class' + enum class STATE { kFree, kAlloc, kEmpty }; + + BuddySpace(const BuddySpace &) = delete; + + BuddySpace &operator=(const BuddySpace &) = delete; + + virtual ~BuddySpace(); + + Status Alloc(uint64_t sz, BSpaceDescriptor *desc, addr_t *) noexcept; + + void Free(const BSpaceDescriptor *desc); + + uint64_t GetMinSize() const { return min_; } + + uint64_t GetMaxSize() const { return max_; } + + int PercentFree() const; + + friend std::ostream &operator<<(std::ostream &os, const BuddySpace &s); + + static uint64_t NextPowerOf2(uint64_t n) { + if (n <= 1) { + return 1; + } + n = n - 1; + while (n & (n - 1)) { + n = n & (n - 1); + } + return n << 1; + } + + static uint32_t Log2(uint64_t n) { + uint32_t cnt = 0; + while (n >>= 1) { + cnt++; + } + return cnt; + } + + static Status CreateBuddySpace(std::unique_ptr *out_bs, int log_min = 15, int num_lvl = 18); + + private: + rel_addr_t *hint_; + int *count_; + char *map_; + int log_min_; + int num_lvl_; + uint64_t min_; + uint64_t max_; + void *ptr_; + std::mutex mutex_; + + explicit BuddySpace(int log_min = 15, int num_lvl = 18); + + Status Init(); + + addr_t AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept; + + void FreeNoLock(const BSpaceDescriptor *desc); + + uint32_t SizeToBlock(const uint64_t sz) const { + uint32_t reqSize = (sz / min_); + if (sz % min_) { + reqSize++; + } + return reqSize; + } + + void GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const; + + void SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st); + + void JoinBuddySeg(rel_addr_t addr, size_t blk_sz); + + void TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz); + + void UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz); + + rel_addr_t AllocBuddySeg(uint32_t req_size) noexcept; + + void FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_BUDDY_H_ diff 
--git a/mindspore/ccsrc/dataset/util/cache_pool.cc b/mindspore/ccsrc/dataset/util/cache_pool.cc new file mode 100644 index 0000000000000000000000000000000000000000..92504cd06344e1878cc888639089b6ddd904764d --- /dev/null +++ b/mindspore/ccsrc/dataset/util/cache_pool.cc @@ -0,0 +1,202 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "common/utils.h" +#include "dataset/util/cache_pool.h" +#include "dataset/util/services.h" + +namespace mindspore { +namespace dataset { +CachePool::CachePool(const value_allocator &alloc, const std::string &root) + : alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {} + +Status CachePool::DoServiceStart() { + tree_ = std::make_shared(); + // If we are given a disk path, set up the StorageManager + if (!root_.toString().empty()) { + Path spill = GetSpillPath(); + RETURN_IF_NOT_OK(spill.CreateDirectories()); + sm_ = std::make_shared(spill); + RETURN_IF_NOT_OK(sm_->ServiceStart()); + MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString()); + } + return Status::OK(); +} +Status CachePool::DoServiceStop() { + Status rc; + Status rc2; + if (sm_ != nullptr) { + rc = sm_->ServiceStop(); + if (rc.IsError()) { + rc2 = rc; + } + } + sm_.reset(); + for (auto &bl : *tree_) { + if (bl.ptr != nullptr) { + alloc_.deallocate(bl.ptr, bl.sz); + } + } + tree_.reset(); + if 
(!root_.toString().empty()) { + Path spill = GetSpillPath(); + auto it = Path::DirIterator::OpenDirectory(&spill); + while (it->hasNext()) { + rc = it->next().Remove(); + if (rc.IsError() && rc2.IsOk()) { + rc2 = rc; + } + } + rc = spill.Remove(); + if (rc.IsError() && rc2.IsOk()) { + rc2 = rc; + } + } + return rc2; +} +CachePool::~CachePool() noexcept { (void)ServiceStop(); } +Status CachePool::Insert(const std::vector &buf, CachePool::key_type *key) { + DataLocator bl; + Status rc; + size_t sz = 0; + // We will consolidate all the slices into one piece. + for (auto &v : buf) { + sz += v.GetSize(); + } + bl.sz = sz; + try { + bl.ptr = alloc_.allocate(sz); + // We will do a piecewise copy. + WritableSlice dest(bl.ptr, bl.sz); + size_t pos = 0; + for (auto &v : buf) { + WritableSlice out(dest, pos); + rc = WritableSlice::Copy(&out, v); + if (rc.IsError()) { + break; + } + pos += v.GetSize(); + } + if (rc.IsError()) { + alloc_.deallocate(bl.ptr, sz); + bl.ptr = nullptr; + return rc; + } + } catch (std::bad_alloc &e) { + if (sm_ != nullptr) { + RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); + // We have an assumption 0 is not a valid key from the design of AutoIndexObj. + // Make sure it is not 0. 
+ if (bl.storage_key == 0) { + RETURN_STATUS_UNEXPECTED("Key 0 is returned which is unexpected"); + } + } else { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } + } + rc = tree_->insert(bl, key); + if (rc.IsError() && bl.ptr != nullptr) { + alloc_.deallocate(bl.ptr, sz); + } + return rc; +} +Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const { + RETURN_UNEXPECTED_IF_NULL(dest); + auto r = tree_->Search(key); + if (r.second) { + auto &it = r.first; + if (it->ptr != nullptr) { + ReadableSlice src(it->ptr, it->sz); + RETURN_IF_NOT_OK(WritableSlice::Copy(dest, src)); + } else if (sm_ != nullptr) { + size_t expectedLength = 0; + RETURN_IF_NOT_OK(sm_->Read(it->storage_key, dest, &expectedLength)); + if (expectedLength != it->sz) { + MS_LOG(ERROR) << "Unexpected length. Read " << expectedLength << ". Expected " << it->sz << "." + << " Internal key: " << key << "\n"; + RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); + } + } + if (bytesRead != nullptr) { + *bytesRead = it->sz; + } + } else { + RETURN_STATUS_UNEXPECTED("Key not found"); + } + return Status::OK(); +} +const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; } +Path CachePool::GetSpillPath() const { + auto spill = Path(root_) / subfolder_; + return spill; +} +CachePool::CacheStat CachePool::GetStat() const { + CacheStat cs{0}; + for (auto &it : *tree_) { + if (it.ptr != nullptr) { + ++cs.num_mem_cached; + } else { + ++cs.num_disk_cached; + } + } + return cs; +} +Status CachePool::Spill(CachePool::DataLocator *dl) { + if (sm_ == nullptr) { + RETURN_STATUS_UNEXPECTED("No disk storage to spill"); + } + RETURN_UNEXPECTED_IF_NULL(dl); + RETURN_UNEXPECTED_IF_NULL(dl->ptr); + if (dl->storage_key == 0) { + ReadableSlice data(dl->ptr, dl->sz); + RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data})); + } + alloc_.deallocate(dl->ptr, dl->sz); + dl->ptr = nullptr; + return Status::OK(); +} +Status 
CachePool::Locate(CachePool::DataLocator *dl) { + RETURN_UNEXPECTED_IF_NULL(dl); + if (dl->ptr == nullptr) { + if (sm_ == nullptr) { + RETURN_STATUS_UNEXPECTED("No disk storage to locate the data"); + } + try { + dl->ptr = alloc_.allocate(dl->sz); + WritableSlice dest(dl->ptr, dl->sz); + Status rc = Read(dl->storage_key, &dest); + if (rc.IsError()) { + alloc_.deallocate(dl->ptr, dl->sz); + dl->ptr = nullptr; + return rc; + } + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } + } + return Status::OK(); +} +size_t CachePool::GetSize(CachePool::key_type key) const { + auto r = tree_->Search(key); + if (r.second) { + auto &it = r.first; + return it->sz; + } else { + return 0; + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/cache_pool.h b/mindspore/ccsrc/dataset/util/cache_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..d35617d0e4b679d50bc1e83cc2fe974590376368 --- /dev/null +++ b/mindspore/ccsrc/dataset/util/cache_pool.h @@ -0,0 +1,139 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef DATASET_UTIL_CACHE_POOL_H_ +#define DATASET_UTIL_CACHE_POOL_H_ + +#include +#include +#include +#include +#include "dataset/util/allocator.h" +#include "dataset/util/service.h" +#include "dataset/util/slice.h" +#include "dataset/util/storage_manager.h" +#include "dataset/util/auto_index.h" + +namespace mindspore { +namespace dataset { +/// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of +/// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to +/// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to +/// restore the buffer. +/// \see ReadableSlice +class CachePool : public Service { + public: + using base_type = uint8_t; + using pointer = base_type *; + using const_pointer = const base_type *; + using reference = base_type &; + using const_reference = const base_type &; + using value_allocator = Allocator; + + // An internal class to locate the whereabouts of a backed up buffer which can be either in + class DataLocator { + public: + DataLocator() : ptr(nullptr), sz(0), storage_key(0) {} + ~DataLocator() = default; + DataLocator(const DataLocator &other) = default; + DataLocator &operator=(const DataLocator &other) = default; + DataLocator(DataLocator &&other) noexcept { + ptr = other.ptr; + sz = other.sz; + storage_key = other.storage_key; + other.ptr = nullptr; + other.sz = 0; + other.storage_key = 0; + } + DataLocator &operator=(DataLocator &&other) noexcept { + if (&other != this) { + ptr = other.ptr; + sz = other.sz; + storage_key = other.storage_key; + other.ptr = nullptr; + other.sz = 0; + other.storage_key = 0; + } + return *this; + } + pointer ptr; + size_t sz; + StorageManager::key_type storage_key; + }; + + using data_index = AutoIndexObj; + using key_type = data_index::key_type; + using bl_alloc_type = typename value_allocator::template rebind::other; + + 
/// \brief Simple statistics returned from CachePool like how many elements are cached in memory and + /// how many elements are spilled to disk. + struct CacheStat { + int64_t num_mem_cached; + int64_t num_disk_cached; + }; + + /// \brief Constructor + /// \param alloc Allocator to allocate memory from + /// \param root Optional disk folder to spill + explicit CachePool(const value_allocator &alloc, const std::string &root = ""); + + CachePool(const CachePool &) = delete; + CachePool(CachePool &&) = delete; + CachePool &operator=(const CachePool &) = delete; + CachePool &operator=(CachePool &&) = delete; + ~CachePool() noexcept; + + Status DoServiceStart() override; + Status DoServiceStop() override; + + Path GetSpillPath() const; + + /// \brief Insert a sequence of ReadableSlice objects into the pool. + /// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk. + /// \param[in] buf A sequence of ReadableSlice objects. + /// \param[out] key Generated key + /// \return Error code + Status Insert(const std::vector &buf, key_type *key); + /// \brief Restore a cached buffer (from memory or disk) + /// \param[in] key A previous key returned from Insert + /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice + /// \param[out] bytesRead Optional. Number of bytes read. + /// \return Error code + Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const; + + Status Spill(DataLocator *dl); + + Status Locate(DataLocator *dl); + + size_t GetSize(key_type key) const; + + /// \brief Get statistics. 
+ /// \return CacheStat object + CacheStat GetStat() const; + + const value_allocator &get_allocator() const; + + std::string MyName() const { return subfolder_; } + + private: + value_allocator alloc_; + Path root_; + const std::string subfolder_; + std::shared_ptr sm_; + std::shared_ptr tree_; +}; +} // namespace dataset +} // namespace mindspore +#endif diff --git a/mindspore/ccsrc/dataset/util/list.h b/mindspore/ccsrc/dataset/util/list.h index 5a08f4514e5d9f8afd2007cc359306efb10d16ad..a4c15daa0e44ed9cc40844f83f3458182695c0cd 100644 --- a/mindspore/ccsrc/dataset/util/list.h +++ b/mindspore/ccsrc/dataset/util/list.h @@ -106,6 +106,24 @@ struct List { ++count; } + // Insert elem2 before elem1 in the list. + virtual void InsertBefore(pointer elem1, pointer elem2) { + DS_ASSERT(elem1 != elem2); + Node &elem1_node = elem1->*node; + Node &elem2_node = elem2->*node; + elem2_node.next = elem1; + elem2_node.prev = elem1_node.prev; + if (elem1_node.prev != nullptr) { + Node &prev_node = elem1_node.prev->*node; + prev_node.next = elem2; + } + elem1_node.prev = elem2; + if (head == elem1) { + head = elem2; + } + ++count; + } + // Remove an element in the list virtual void Remove(pointer elem) noexcept { Node &elem_node = elem->*node; diff --git a/mindspore/ccsrc/dataset/util/memory_pool.h b/mindspore/ccsrc/dataset/util/memory_pool.h index 70876a81417141a3a517a62c0a2b30bb3e21ba80..ee1da3bda151b735d97d044d1db3977028dd9165 100644 --- a/mindspore/ccsrc/dataset/util/memory_pool.h +++ b/mindspore/ccsrc/dataset/util/memory_pool.h @@ -44,20 +44,6 @@ class MemoryPool { virtual ~MemoryPool() {} }; -// Used by unique_ptr -template -class Deleter { - public: - explicit Deleter(std::shared_ptr &mp) : mp_(mp) {} - - ~Deleter() = default; - - void operator()(T *ptr) const { mp_->Deallocate(ptr); } - - private: - std::shared_ptr mp_; -}; - Status DeMalloc(std::size_t s, void **p, bool); } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/path.cc 
b/mindspore/ccsrc/dataset/util/path.cc index c37fdc17f1d4b4fa5d4c347b07acdd544fee1ec6..59e5e5232c5e29dd87b2a6c3f3d8b0a51b27a74f 100644 --- a/mindspore/ccsrc/dataset/util/path.cc +++ b/mindspore/ccsrc/dataset/util/path.cc @@ -16,6 +16,8 @@ #include "dataset/util/path.h" #include +#include +#include #include #include #include @@ -26,7 +28,7 @@ namespace mindspore { namespace dataset { -#ifdef _WIN32 +#if defined(_WIN32) || defined(_WIN64) char Path::separator_ = '\\'; #else char Path::separator_ = '/'; @@ -132,7 +134,7 @@ Status Path::CreateDirectory() { #if defined(_WIN32) || defined(_WIN64) int rc = mkdir(common::SafeCStr(path_)); #else - int rc = mkdir(common::SafeCStr(path_), 0700); + int rc = mkdir(common::SafeCStr(path_), S_IRUSR | S_IWUSR | S_IXUSR); #endif if (rc) { std::ostringstream oss; @@ -182,6 +184,111 @@ Status Path::CreateDirectories() { return Status::OK(); } +Status Path::Remove() { + if (Exists()) { + if (IsDirectory()) { + errno_t err = rmdir(common::SafeCStr(path_)); + if (err == -1) { + std::ostringstream oss; + oss << "Unable to delete directory " << path_ << ". Errno = " << errno; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + } else { + errno_t err = unlink(common::SafeCStr(path_)); + if (err == -1) { + std::ostringstream oss; + oss << "Unable to delete file " << path_ << ". Errno = " << errno; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + } + } + return Status::OK(); +} + +Status Path::CreateFile(int *file_descriptor) { return OpenFile(file_descriptor, true); } + +Status Path::OpenFile(int *file_descriptor, bool create) { + int fd; + if (file_descriptor == nullptr) { + RETURN_STATUS_UNEXPECTED("null pointer"); + } + if (IsDirectory()) { + std::ostringstream oss; + oss << "Unable to create file " << path_ << " which is a directory."; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + // Convert to canonical form. 
+ if (strlen(common::SafeCStr(path_)) > PATH_MAX) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + char canonical_path[PATH_MAX + 1] = {0x00}; +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(canonical_path, common::SafeCStr(path_), PATH_MAX) == nullptr) { +#else + if (realpath(common::SafeCStr(path_), canonical_path) == nullptr) { +#endif + if (errno == ENOENT && create) { + // File doesn't exist and we are to create it. Let's break it down. + auto file_part = Basename(); + auto parent_part = ParentPath(); +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(canonical_path, common::SafeCStr(parent_part), PATH_MAX) == nullptr) { +#else + if (realpath(common::SafeCStr(parent_part), canonical_path) == nullptr) { +#endif + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + auto cur_inx = strlen(canonical_path); + if ((cur_inx + file_part.length() + 1) > PATH_MAX) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + canonical_path[cur_inx++] = separator_; + if (strncpy_s(canonical_path + cur_inx, PATH_MAX - cur_inx, common::SafeCStr(file_part), file_part.length()) != + EOK) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + } else { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + } + if (create) { + fd = open(canonical_path, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP); + } else { + fd = open(canonical_path, O_RDWR); + } + if (fd == -1) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + *file_descriptor = fd; + return Status::OK(); +} + +Status Path::CloseFile(int fd) const { + if (close(fd) < 0) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + return Status::OK(); +} + +Status Path::TruncateFile(int fd) const { + int rc; + rc = ftruncate(fd, 0); + if (rc == 0) { + return Status::OK(); + } else { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } +} + +std::string Path::Basename() { + std::size_t found = path_.find_last_of(separator_); + if (found != std::string::npos) { + return path_.substr(found + 1); + } else { + return path_; + } 
+} + std::shared_ptr Path::DirIterator::OpenDirectory(Path *f) { auto it = new (std::nothrow) DirIterator(f); @@ -208,7 +315,7 @@ Path::DirIterator::~DirIterator() { Path::DirIterator::DirIterator(Path *f) : dir_(f), dp_(nullptr), entry_(nullptr) { MS_LOG(DEBUG) << "Open directory " << f->toString() << "."; - dp_ = opendir(common::SafeCStr(f->toString())); + dp_ = opendir(f->toString().c_str()); } bool Path::DirIterator::hasNext() { @@ -225,5 +332,10 @@ bool Path::DirIterator::hasNext() { } Path Path::DirIterator::next() { return (*(this->dir_) / Path(entry_->d_name)); } + +std::ostream &operator<<(std::ostream &os, const Path &s) { + os << s.path_; + return os; +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/path.h b/mindspore/ccsrc/dataset/util/path.h index efe01a7d16cbebff7facc8c63da8e08734683e45..fbf65b8c236164533eecfa7aa2fa5cf4c087ff7d 100644 --- a/mindspore/ccsrc/dataset/util/path.h +++ b/mindspore/ccsrc/dataset/util/path.h @@ -90,6 +90,20 @@ class Path { std::string ParentPath(); + Status Remove(); + + Status CreateFile(int *fd); + + Status OpenFile(int *fd, bool create = false); + + Status CloseFile(int fd) const; + + Status TruncateFile(int fd) const; + + std::string Basename(); + + friend std::ostream &operator<<(std::ostream &os, const Path &s); + private: static char separator_; std::string path_; diff --git a/mindspore/ccsrc/dataset/util/semaphore.cc b/mindspore/ccsrc/dataset/util/semaphore.cc new file mode 100644 index 0000000000000000000000000000000000000000..36ddf5511d9316a6b822a0d05fb56dfd1f69d813 --- /dev/null +++ b/mindspore/ccsrc/dataset/util/semaphore.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/util/semaphore.h" +#include "dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +Status Semaphore::P() { + std::unique_lock lck(mutex_); + RETURN_IF_NOT_OK(wait_cond_.Wait(&lck, [this]() { return value_ > 0; })); + --value_; + return Status::OK(); +} +void Semaphore::V() { + std::unique_lock lck(mutex_); + ++value_; + wait_cond_.NotifyOne(); +} +int Semaphore::Peek() { + std::unique_lock lck(mutex_); + return value_; +} +Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); } +Status Semaphore::Deregister() { return (wait_cond_.Deregister()); } +void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); } + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/semaphore.h b/mindspore/ccsrc/dataset/util/semaphore.h new file mode 100644 index 0000000000000000000000000000000000000000..07b9e83e7fbc803510bcc77594cea75402049995 --- /dev/null +++ b/mindspore/ccsrc/dataset/util/semaphore.h @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_SEMAPHORE_H_ +#define DATASET_UTIL_SEMAPHORE_H_ + +#include "dataset/util/cond_var.h" + +namespace mindspore { +namespace dataset { +class TaskGroup; + +/// \brief A counting semaphore. There are two external functions P and V. P decrements the internal count and will be +/// blocked if the count is 0 (zero). V increments the internal count and wake up one of the waiters. +class Semaphore { + public: + /// \brief Constructor + /// \param init Initial value of the internal counter. + explicit Semaphore(int init) : value_(init) {} + + virtual ~Semaphore() {} + /// \brief Decrement the internal counter. Will be blocked if the value is 0. + /// \return Error code. Can get interrupt. + Status P(); + /// \brief Increment the internal counter. Wakeup on of the watiers if any. + void V(); + /// \brief Peek the internal value + /// \return The internal value + int Peek(); + Status Register(TaskGroup *vg); + Status Deregister(); + void ResetIntrpState(); + + private: + int value_; + + std::mutex mutex_; + CondVar wait_cond_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_SEMAPHORE_H_ diff --git a/mindspore/ccsrc/dataset/util/slice.cc b/mindspore/ccsrc/dataset/util/slice.cc new file mode 100644 index 0000000000000000000000000000000000000000..f1798b4f44a7b0153e386162f2064422e8e62e45 --- /dev/null +++ b/mindspore/ccsrc/dataset/util/slice.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "dataset/util/slice.h" + +namespace mindspore { +namespace dataset { +WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset, size_t len) : ReadableSlice(src, offset, len) { + mutable_data_ = static_cast(src.mutable_data_) + offset; +} +WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset) + : WritableSlice(src, offset, src.GetSize() - offset) {} +Status WritableSlice::Copy(WritableSlice *dest, const ReadableSlice &src) { + RETURN_UNEXPECTED_IF_NULL(dest); + RETURN_UNEXPECTED_IF_NULL(dest->GetMutablePointer()); + if (dest->GetSize() <= 0) { + RETURN_STATUS_UNEXPECTED("Destination length is non-positive"); + } + auto err = memcpy_s(dest->GetMutablePointer(), dest->GetSize(), src.GetPointer(), src.GetSize()); + if (err) { + RETURN_STATUS_UNEXPECTED(std::to_string(err)); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/slice.h b/mindspore/ccsrc/dataset/util/slice.h new file mode 100644 index 0000000000000000000000000000000000000000..127df23cfabaffaa650294bde65095c770779220 --- /dev/null +++ b/mindspore/ccsrc/dataset/util/slice.h @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_SLICE_H_ +#define DATASET_UTIL_SLICE_H_ + +#include +#include +#include +#include "./securec.h" +#include "dataset/util/allocator.h" +#include "dataset/util/status.h" +namespace mindspore { +namespace dataset { +/// \brief A ReadableSlice wraps a const pointer in memory and its size. +/// \see WritableSlice for a non-const version +/// +class ReadableSlice { + public: + ReadableSlice() : ptr_(nullptr), sz_(0) {} + ReadableSlice(const void *ptr, size_t sz) : ptr_(ptr), sz_(sz) {} + ReadableSlice(const ReadableSlice &src, off64_t offset, size_t len) { + ptr_ = static_cast(src.GetPointer()) + offset; + sz_ = len; + } + ReadableSlice(const ReadableSlice &src, off64_t offset) : ReadableSlice(src, offset, src.sz_ - offset) {} + ReadableSlice(const ReadableSlice &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + } + ReadableSlice &operator=(const ReadableSlice &lhs) { + if (this != &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + } + return *this; + } + ReadableSlice(ReadableSlice &&lhs) noexcept { + if (this != &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + lhs.ptr_ = nullptr; + lhs.sz_ = 0; + } + } + ReadableSlice &operator=(ReadableSlice &&lhs) noexcept { + if (this != &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + lhs.ptr_ = nullptr; + lhs.sz_ = 0; + } + return *this; + } + /// \brief Getter function + /// \return Const version of the pointer + const void *GetPointer() const { return ptr_; } + /// \brief Getter function + /// \return Size of the slice + size_t GetSize() const { return sz_; } + bool empty() const { return ptr_ == 
nullptr; } + + private: + const void *ptr_; + size_t sz_; +}; +/// \brief A WritableSlice inherits from ReadableSlice to allow +/// one to write to the address pointed to by the pointer. +/// +class WritableSlice : public ReadableSlice { + public: + friend class StorageContainer; + /// \brief Default constructor + WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {} + /// \brief This form of a constructor takes a pointer and its size. + WritableSlice(void *ptr, size_t sz) : ReadableSlice(ptr, sz), mutable_data_(ptr) {} + WritableSlice(const WritableSlice &src, off64_t offset, size_t len); + WritableSlice(const WritableSlice &src, off64_t offset); + WritableSlice(const WritableSlice &lhs) : ReadableSlice(lhs) { mutable_data_ = lhs.mutable_data_; } + WritableSlice &operator=(const WritableSlice &lhs) { + if (this != &lhs) { + mutable_data_ = lhs.mutable_data_; + ReadableSlice::operator=(lhs); + } + return *this; + } + WritableSlice(WritableSlice &&lhs) noexcept : ReadableSlice(std::move(lhs)) { + if (this != &lhs) { + mutable_data_ = lhs.mutable_data_; + lhs.mutable_data_ = nullptr; + } + } + WritableSlice &operator=(WritableSlice &&lhs) noexcept { + if (this != &lhs) { + mutable_data_ = lhs.mutable_data_; + lhs.mutable_data_ = nullptr; + ReadableSlice::operator=(std::move(lhs)); + } + return *this; + } + /// \brief Copy the content from one slice onto another. 
+ static Status Copy(WritableSlice *dest, const ReadableSlice &src); + + private: + void *mutable_data_; + void *GetMutablePointer() { return mutable_data_; } +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_SLICE_H_ diff --git a/mindspore/ccsrc/dataset/util/storage_container.cc b/mindspore/ccsrc/dataset/util/storage_container.cc new file mode 100644 index 0000000000000000000000000000000000000000..96f5b45d0cc35815cb37beafaeab189ffc672425 --- /dev/null +++ b/mindspore/ccsrc/dataset/util/storage_container.cc @@ -0,0 +1,164 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "dataset/util/storage_container.h" + +#include +#include +#include +#include +#include "common/utils.h" +#include "dataset/util/de_error.h" +#include "dataset/util/path.h" +#include "dataset/util/status.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +Status StorageContainer::Create() { + RETURN_IF_NOT_OK(BuddySpace::CreateBuddySpace(&bs_)); + RETURN_IF_NOT_OK(cont_.CreateFile(&fd_)); + is_open_ = true; + MS_LOG(INFO) << "Container " << cont_ << " created"; + return Status::OK(); +} + +Status StorageContainer::Open() noexcept { + std::lock_guard lck(mutex_); + // Check again + if (!is_open_) { + RETURN_IF_NOT_OK(cont_.OpenFile(&fd_)); + is_open_ = true; + } + return Status::OK(); +} + +Status StorageContainer::Close() noexcept { + if (is_open_) { + std::lock_guard lck(mutex_); + // Check again + if (is_open_) { + RETURN_IF_NOT_OK(cont_.CloseFile(fd_)); + is_open_ = false; + fd_ = -1; + } + } + return Status::OK(); +} + +Status StorageContainer::Read(WritableSlice *dest, off64_t offset) const noexcept { + DS_ASSERT(is_open_); + RETURN_UNEXPECTED_IF_NULL(dest); + auto sz = dest->GetSize(); +#if defined(_WIN32) || defined(_WIN64) + // Doesn't seem there is any pread64 on mingw. + // So we will do a seek and then a read under + // a protection of mutex. + std::lock_guard lck(mutex_); + auto seek_err = lseek(fd_, offset, SEEK_SET); + if (seek_err < 0) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + auto r_sz = read(fd_, dest->GetMutablePointer(), sz); +#else + auto r_sz = pread64(fd_, dest->GetMutablePointer(), sz, offset); +#endif + if (r_sz != sz) { + errno_t err = (r_sz == 0) ? EOF : errno; + RETURN_STATUS_UNEXPECTED(strerror(err)); + } + return Status::OK(); +} + +Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const noexcept { + DS_ASSERT(is_open_); + auto sz = dest.GetSize(); +#if defined(_WIN32) || defined(_WIN64) + // Doesn't seem there is any pwrite64 on mingw. 
+ // So we will do a seek and then a read under + // a protection of mutex. + std::lock_guard lck(mutex_); + auto seek_err = lseek(fd_, offset, SEEK_SET); + if (seek_err < 0) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + auto r_sz = write(fd_, dest.GetPointer(), sz); +#else + auto r_sz = pwrite64(fd_, dest.GetPointer(), sz, offset); +#endif + if (r_sz != sz) { + errno_t err = (r_sz == 0) ? EOF : errno; + RETURN_STATUS_UNEXPECTED(strerror(err)); + } + return Status::OK(); +} + +Status StorageContainer::Insert(const std::vector &buf, off64_t *offset) noexcept { + size_t sz = 0; + for (auto &v : buf) { + sz += v.GetSize(); + } + if (sz == 0) { + RETURN_STATUS_UNEXPECTED("Unexpected 0 length"); + } + if (sz > bs_->GetMaxSize()) { + RETURN_STATUS_UNEXPECTED("Request size too big"); + } + BSpaceDescriptor bspd{0}; + addr_t addr = 0; + RETURN_IF_NOT_OK(bs_->Alloc(sz, &bspd, &addr)); + *offset = static_cast(addr); + // We will do piecewise copy of the data to disk. + for (auto &v : buf) { + RETURN_IF_NOT_OK(Write(v, addr)); + addr += v.GetSize(); + } + return Status::OK(); +} + +Status StorageContainer::Truncate() const noexcept { + if (is_open_) { + RETURN_IF_NOT_OK(cont_.TruncateFile(fd_)); + MS_LOG(INFO) << "Container " << cont_ << " truncated"; + } + return Status::OK(); +} + +StorageContainer::~StorageContainer() noexcept { + (void)Truncate(); + (void)Close(); +} + +std::ostream &operator<<(std::ostream &os, const StorageContainer &s) { + os << "File path : " << s.cont_ << "\n" << *(s.bs_.get()); + return os; +} + +Status StorageContainer::CreateStorageContainer(std::shared_ptr *out_sc, const std::string &path) { + Status rc; + auto sc = new (std::nothrow) StorageContainer(path); + if (sc == nullptr) { + return Status(StatusCode::kOutOfMemory); + } + rc = sc->Create(); + if (rc.IsOk()) { + (*out_sc).reset(sc); + } else { + delete sc; + } + return rc; +} +} // namespace dataset +} // namespace mindspore diff --git 
a/mindspore/ccsrc/dataset/util/storage_container.h b/mindspore/ccsrc/dataset/util/storage_container.h new file mode 100644 index 0000000000000000000000000000000000000000..07e41bd66a7ba7fc3080847acb4f8021cdd500c7 --- /dev/null +++ b/mindspore/ccsrc/dataset/util/storage_container.h @@ -0,0 +1,79 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_STORAGE_CONTAINER_H_ +#define DATASET_UTIL_STORAGE_CONTAINER_H_ + +#include +#include +#include +#include +#include +#include +#include "dataset/util/system_pool.h" +#include "dataset/util/buddy.h" +#include "dataset/util/path.h" +#include "dataset/util/slice.h" +#include "dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class StorageManager; + +class StorageContainer { + public: + friend class StorageManager; + + ~StorageContainer() noexcept; + + StorageContainer(const StorageContainer &) = delete; + + StorageContainer &operator=(const StorageContainer &) = delete; + + friend std::ostream &operator<<(std::ostream &os, const StorageContainer &s); + + Status Open() noexcept; + + Status Close() noexcept; + + Status Insert(const std::vector &buf, off64_t *offset) noexcept; + + Status Write(const ReadableSlice &dest, off64_t offset) const noexcept; + + Status Read(WritableSlice *dest, off64_t offset) const noexcept; + + Status Truncate() const noexcept; + + bool IsOpen() const { return is_open_; } + + static Status 
CreateStorageContainer(std::shared_ptr *out_sc, const std::string &path); + + private: + mutable std::mutex mutex_; + Path cont_; + int fd_; + bool is_open_; + std::unique_ptr bs_; + + // Use the default value of BuddySpace + // which can map upto 4G of space. + explicit StorageContainer(const std::string &path) : cont_(path), fd_(-1), is_open_(false), bs_(nullptr) {} + + Status Create(); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_STORAGE_CONTAINER_H_ diff --git a/mindspore/ccsrc/dataset/util/storage_manager.cc b/mindspore/ccsrc/dataset/util/storage_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..8b7a6044e93dd44e285e50071fc9a907a93a1153 --- /dev/null +++ b/mindspore/ccsrc/dataset/util/storage_manager.cc @@ -0,0 +1,167 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "dataset/util/storage_manager.h" + +#include +#include +#include +#include +#include "common/utils.h" +#include "dataset/util/path.h" +#include "dataset/util/services.h" +#include "dataset/util//de_error.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +std::string StorageManager::GetBaseName(const std::string &prefix, int32_t file_id) { + std::ostringstream oss; + oss << prefix << std::setfill('0') << std::setw(5) << file_id; + return oss.str(); +} + +std::string StorageManager::ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix) { + std::string base_name = GetBaseName(prefix, file_id); + return (base_name + "." + suffix); +} + +Status StorageManager::AddOneContainer() { + const std::string kPrefix = "IMG"; + const std::string kSuffix = "LB"; + Path container_name = root_ / ConstructFileName(kPrefix, file_id_, kSuffix); + std::shared_ptr sc; + RETURN_IF_NOT_OK(StorageContainer::CreateStorageContainer(&sc, container_name.toString())); + containers_.push_back(sc); + file_id_++; + return Status::OK(); +} + +Status StorageManager::DoServiceStart() { + containers_.reserve(1000); + if (root_.IsDirectory()) { + RETURN_IF_NOT_OK(AddOneContainer()); + } else { + RETURN_STATUS_UNEXPECTED("Not a directory"); + } + return Status::OK(); +} + +Status StorageManager::Write(key_type *key, const std::vector &buf) { + RETURN_UNEXPECTED_IF_NULL(key); + size_t sz = 0; + for (auto &v : buf) { + sz += v.GetSize(); + } + if (sz == 0) { + RETURN_STATUS_UNEXPECTED("Unexpected 0 length"); + } + std::shared_ptr cont; + key_type out_key; + value_type out_value; + bool create_new_container = false; + do { + SharedLock lock_s(&rw_lock_); + size_t num_containers = containers_.size(); + if (create_new_container) { + // Upgrade to exclusvie lock. 
+ lock_s.Upgrade(); + create_new_container = false; + // Check again if someone has already added a + // new container after we got the x lock + if (containers_.size() == num_containers) { + RETURN_IF_NOT_OK(AddOneContainer()); + } + // Refresh how many containers there are. + num_containers = containers_.size(); + // Downgrade back to shared lock + lock_s.Downgrade(); + } + if (num_containers == 0) { + RETURN_STATUS_UNEXPECTED("num_containers is zero"); + } + // Go to the last container to insert. + cont = containers_.at(num_containers - 1); + off64_t offset; + Status rc = cont->Insert(buf, &offset); + if (rc.IsNoSpace()) { + create_new_container = true; + } else if (rc.IsOk()) { + out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz)); + RETURN_IF_NOT_OK(index_.insert(out_value, &out_key)); + *key = out_key; + break; + } else { + return rc; + } + } while (true); + return Status::OK(); +} + +Status StorageManager::Read(StorageManager::key_type key, WritableSlice *dest, size_t *bytesRead) const { + RETURN_UNEXPECTED_IF_NULL(dest); + auto r = index_.Search(key); + if (r.second) { + auto &it = r.first; + value_type v = *it; + int container_inx = v.first; + off_t offset = v.second.first; + size_t sz = v.second.second; + if (dest->GetSize() < sz) { + std::string errMsg = "Destination buffer too small. Expect at least " + std::to_string(sz) + + " but length = " + std::to_string(dest->GetSize()); + RETURN_STATUS_UNEXPECTED(errMsg); + } + if (bytesRead != nullptr) { + *bytesRead = sz; + } + auto cont = containers_.at(container_inx); + RETURN_IF_NOT_OK(cont->Read(dest, offset)); + } else { + RETURN_STATUS_UNEXPECTED("Key not found"); + } + return Status::OK(); +} + +Status StorageManager::DoServiceStop() noexcept { + Status rc; + Status rc1; + for (auto const &p : containers_) { + // The destructor of StorageContainer is not called automatically until the use + // count drops to 0. But it is not always the case. We will do it ourselves. 
+ rc = p.get()->Truncate(); + if (rc.IsError()) { + rc1 = rc; + } + } + containers_.clear(); + file_id_ = 0; + return rc1; +} + +StorageManager::StorageManager(const Path &root) : root_(root), file_id_(0), index_() {} + +StorageManager::~StorageManager() { (void)StorageManager::DoServiceStop(); } + +std::ostream &operator<<(std::ostream &os, const StorageManager &s) { + os << "Dumping all containers ..." + << "\n"; + for (auto const &p : s.containers_) { + os << *(p.get()); + } + return os; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/storage_manager.h b/mindspore/ccsrc/dataset/util/storage_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..075ac713d2c98c4f1e281738c3a1fda35edda4ea --- /dev/null +++ b/mindspore/ccsrc/dataset/util/storage_manager.h @@ -0,0 +1,76 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef DATASET_UTIL_STORAGE_MANAGER_H_ +#define DATASET_UTIL_STORAGE_MANAGER_H_ + +#include +#include +#include +#include +#include +#include "dataset/util/allocator.h" +#include "dataset/util/auto_index.h" +#include "dataset/util/lock.h" +#include "dataset/util/memory_pool.h" +#include "dataset/util/path.h" +#include "dataset/util/service.h" +#include "dataset/util/slice.h" +#include "dataset/util/storage_container.h" + +using ListOfContainers = std::vector>; +namespace mindspore { +namespace dataset { +class StorageManager : public Service { + public: + using storage_index = AutoIndexObj>>; + using key_type = storage_index::key_type; + using value_type = storage_index::value_type; + + explicit StorageManager(const Path &); + + ~StorageManager() override; + + StorageManager(const StorageManager &) = delete; + + StorageManager &operator=(const StorageManager &) = delete; + + Status Write(key_type *out_key, const std::vector &buf); + + Status Read(key_type key, WritableSlice *dest, size_t *bytesRead) const; + + Status DoServiceStart() override; + + Status DoServiceStop() noexcept override; + + friend std::ostream &operator<<(std::ostream &os, const StorageManager &s); + + private: + Path root_; + ListOfContainers containers_; + int file_id_; + RWLock rw_lock_; + storage_index index_; + + std::string GetBaseName(const std::string &prefix, int32_t file_id); + + std::string ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix); + + Status AddOneContainer(); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_STORAGE_MANAGER_H_ diff --git a/mindspore/ccsrc/dataset/util/system_pool.h b/mindspore/ccsrc/dataset/util/system_pool.h index bd15ad11ddf977ae30318c12904974c3aedd6e90..286e30a615815121f3c2983c8150fc1f3b177a0b 100644 --- a/mindspore/ccsrc/dataset/util/system_pool.h +++ b/mindspore/ccsrc/dataset/util/system_pool.h @@ -19,8 +19,10 @@ #include #include #include +#include #include #include 
"./securec.h" +#include "dataset/util/allocator.h" #include "dataset/util/memory_pool.h" namespace mindspore { @@ -61,6 +63,11 @@ class SystemPool : public MemoryPool { uint64_t get_max_size() const override { return std::numeric_limits::max(); } int PercentFree() const override { return 100; } + + template + static Allocator GetAllocator() { + return Allocator(std::make_shared()); + } }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc index 4581141790a8d58b58831e0ed16bb6619d76e168..dce11afcea60a9312c060591d9935d500caaaa1f 100644 --- a/mindspore/ccsrc/device/kernel_runtime.cc +++ b/mindspore/ccsrc/device/kernel_runtime.cc @@ -30,6 +30,7 @@ #include "kernel/common_utils.h" #include "kernel/oplib/oplib.h" #include "ir/value.h" +#include "pre_activate/common/helper.h" using mindspore::kernel::Address; using mindspore::kernel::AddressPtr; @@ -632,7 +633,7 @@ void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) { } } -void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, +void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces, AddressPtrList *kernel_outputs) { MS_EXCEPTION_IF_NULL(kernel); @@ -644,9 +645,15 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) { return GenAddrCleanLaunchArgs(cnode, kernel_inputs); } + auto is_all_nop_node = opt::IsAllNopNode(&graph); for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { auto real_input = AnfAlgo::GetRealInputIndex(kernel, i); - auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input); + DeviceAddressPtr device_address; + if (is_all_nop_node) { + device_address = 
AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, false); + } else { + device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, true); + } MS_EXCEPTION_IF_NULL(device_address); kernel::AddressPtr input = std::make_shared(); MS_EXCEPTION_IF_NULL(input); @@ -656,8 +663,16 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod kernel_inputs->emplace_back(input); } - for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) { - auto device_address = AnfAlgo::GetOutputAddr(kernel, i); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) { + DeviceAddressPtr device_address; + if (is_all_nop_node) { + device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); + } else { + device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, true); + } + MS_EXCEPTION_IF_NULL(device_address); kernel::AddressPtr output = std::make_shared(); MS_EXCEPTION_IF_NULL(output); output->addr = device_address->ptr_; @@ -666,7 +681,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod kernel_outputs->emplace_back(output); } - for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) { + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); kernel::AddressPtr workspace = std::make_shared(); MS_EXCEPTION_IF_NULL(workspace); @@ -721,7 +736,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { AddressPtrList kernel_inputs; AddressPtrList kernel_workspaces; AddressPtrList kernel_outputs; - GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + GenLaunchArgs(graph, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); if (!ret) { MS_LOG(ERROR) << "Launch 
kernel failed."; diff --git a/mindspore/ccsrc/device/kernel_runtime.h b/mindspore/ccsrc/device/kernel_runtime.h index 8442342e32200a2ea1868c3b57624bbc45f3e28e..c69487c6f175a1d807ea9a23b82917838cc96c1a 100644 --- a/mindspore/ccsrc/device/kernel_runtime.h +++ b/mindspore/ccsrc/device/kernel_runtime.h @@ -96,8 +96,8 @@ class KernelRuntime { private: void AssignStaticMemoryOutput(const session::KernelGraph *graph); - void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, - AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); + void GenLaunchArgs(const session::KernelGraph &graph, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs, + AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); bool LaunchKernelMod(const session::KernelGraph &graph); void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index); diff --git a/mindspore/ccsrc/ir/optimizer_caller.h b/mindspore/ccsrc/ir/optimizer_caller.h index bd3045414737fd970e932729c305938b1269795f..036f4ab5109a4e645d0aad37bd280198c7fcdbe1 100644 --- a/mindspore/ccsrc/ir/optimizer_caller.h +++ b/mindspore/ccsrc/ir/optimizer_caller.h @@ -17,13 +17,23 @@ #ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ #define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ +#include + #include "ir/anf.h" -#include "optimizer/opt.h" namespace mindspore { +namespace opt { +class Optimizer; +using OptimizerPtr = std::shared_ptr; +using OptimizerWeakPtr = std::weak_ptr; + +using PredicateFuncType = std::function; +} // namespace opt + class OptimizerCaller { public: virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; } }; +using OptimizerCallerPtr = std::shared_ptr; } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ diff --git a/mindspore/ccsrc/kernel/kernel_query.cc 
b/mindspore/ccsrc/kernel/kernel_query.cc index 5eda8479170a494756ca37fd28c2995935f4707b..4a8ae81afa4824cf36bdc1be5ce06c9df850608e 100755 --- a/mindspore/ccsrc/kernel/kernel_query.cc +++ b/mindspore/ccsrc/kernel/kernel_query.cc @@ -23,6 +23,7 @@ #include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" #include "kernel/akg/akg_kernel_metadata.h" #include "session/anf_runtime_algorithm.h" +#include "utils/context/ms_context.h" namespace mindspore { namespace kernel { @@ -97,6 +98,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vectorenable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) { + kernel_type = KernelType::AKG_KERNEL; + } + switch (kernel_type) { case KernelType::AKG_KERNEL: AkgMetadataInfo(kernel_node, kernel_info_list); diff --git a/mindspore/ccsrc/optimizer/cse.cc b/mindspore/ccsrc/optimizer/cse.cc index 1af08ea3e127f781781f798946f3b9a8c52ee1be..0b675cca7211687289ca4962057782c226c39308 100644 --- a/mindspore/ccsrc/optimizer/cse.cc +++ b/mindspore/ccsrc/optimizer/cse.cc @@ -89,15 +89,28 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { return changed; } - +// The op like print, summary, or the op do not has true output, and always as a depend node input. +static bool HasSideEffect(const AnfNodePtr &node) { + auto prim = GetCNodePrimitive(node); + if (prim == nullptr) { + return false; + } + auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT); + if (side_effect_v != nullptr && side_effect_v->isa()) { + return GetValue(side_effect_v); + } + return false; +} +// If true do not merge the node. bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const { bool has_random_effect = false; auto prim_main = GetCNodePrimitive(main); auto prim_node = GetCNodePrimitive(node); - if (prim_main == prim_node) { - return false; - } + // if has random effect, when generate by different op (not same object), do not merge. 
if (prim_main != nullptr) { + if (prim_main == prim_node) { + return false; + } auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT); if (effect_val != nullptr && effect_val->isa()) { has_random_effect = GetValue(effect_val); @@ -106,45 +119,58 @@ bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) cons return has_random_effect; } -bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { +bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const { MS_EXCEPTION_IF_NULL(main); MS_EXCEPTION_IF_NULL(node); - bool replace = false; if (main->isa() && node->isa()) { auto main_value = GetValueNode(main); auto node_value = GetValueNode(node); - replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); + return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); } else if (main->isa() && node->isa()) { auto c_main = main->cast(); auto c_node = node->cast(); + // When appsame is true, check if has side effect, do not merge. 
+ if (check_side_effect && HasSideEffect(main)) { + return false; + } const auto &inp1 = c_main->inputs(); const auto &inp2 = c_node->inputs(); - if (inp1.size() == inp2.size()) { - bool appsame = true; - for (size_t j = 0; j < inp1.size(); j++) { - MS_EXCEPTION_IF_NULL(inp1[j]); - MS_EXCEPTION_IF_NULL(inp2[j]); - if (!(*inp1[j] == *inp2[j])) { - // Handle the case of two different Tensor, but with the same value - if (IsValueNode(inp1[j]) && IsValueNode(inp2[j])) { - auto tensor1 = GetValueNode(inp1[j]); - auto tensor2 = GetValueNode(inp2[j]); - if (tensor1->ValueEqual(*tensor2)) { - continue; - } + if (inp1.size() != inp2.size()) { + return false; + } + for (size_t j = 0; j < inp1.size(); j++) { + auto inp1_j = inp1[j]; + auto inp2_j = inp2[j]; + MS_EXCEPTION_IF_NULL(inp1_j); + MS_EXCEPTION_IF_NULL(inp2_j); + if (!(*inp1_j == *inp2_j)) { + // Handle the case of two different Tensor, but with the same value + if (IsValueNode(inp1_j) && IsValueNode(inp2_j)) { + auto tensor1 = GetValueNode(inp1_j); + auto tensor2 = GetValueNode(inp2_j); + if (tensor1->ValueEqual(*tensor2)) { + continue; + } + } else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) { + // When the same side effect node as another two nodes' inputs, we still merge the node. + // Because the node only can be the inputs of `depend`, when the `depend` is duplicated merge the depend the + // node. + if (CheckReplace(inp1_j, inp2_j, false)) { + continue; } - appsame = false; - break; } + return false; } - if (CheckRandomEffect(c_main, c_node)) { - appsame = false; - } - replace = appsame; } + // When appsame is true, check if has random effect do not merge + if (CheckRandomEffect(c_main, c_node)) { + return false; + } + return true; } - return replace; + // a parameter node. 
+ return false; } bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector &order_group, diff --git a/mindspore/ccsrc/optimizer/cse.h b/mindspore/ccsrc/optimizer/cse.h index fd90f61eebc634258f82ca861cf329466d91211a..57163cc5c9d8430dd2e077e763d830997f074eab 100644 --- a/mindspore/ccsrc/optimizer/cse.h +++ b/mindspore/ccsrc/optimizer/cse.h @@ -41,7 +41,7 @@ class CSE { return chg && report_changes_; } - virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const; + virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const; virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const; diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index 0033e386d8a2b2a488473d0d25ca37e18b1419ac..0996abee2c2612d4423b532e36dcd5c5caa17143 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -14,140 +14,154 @@ * limitations under the License. 
*/ -#include "optimizer/irpass.h" - #include -#include "optimizer/irpass/symbol_resolver.h" +#include "optimizer/irpass.h" #include "optimizer/irpass/arithmetic_simplify.h" -#include "optimizer/irpass/special_op_eliminate.h" -#include "optimizer/irpass/item_tuple_eliminate.h" -#include "optimizer/irpass/env_item_eliminate.h" -#include "optimizer/irpass/tile_eliminate.h" -#include "optimizer/irpass/cast_eliminate.h" -#include "optimizer/irpass/reshape_eliminate.h" -#include "optimizer/irpass/transpose_eliminate.h" -#include "optimizer/irpass/reduce_eliminate.h" -#include "optimizer/irpass/partial_eliminate.h" -#include "optimizer/irpass/ref_eliminate.h" -#include "optimizer/irpass/merge_addn.h" #include "optimizer/irpass/branch_culling.h" +#include "optimizer/irpass/cast_eliminate.h" +#include "optimizer/irpass/convert.h" +#include "optimizer/irpass/env_item_eliminate.h" +#include "optimizer/irpass/grad_var_prepare.h" #include "optimizer/irpass/gradient_eliminate.h" -#include "optimizer/irpass/minmax_grad.h" #include "optimizer/irpass/inline.h" -#include "optimizer/irpass/convert.h" -#include "optimizer/irpass/specialize_transform.h" -#include "optimizer/irpass/incorporate_getitem.h" #include "optimizer/irpass/incorporate_call.h" -#include "optimizer/irpass/grad_var_prepare.h" -#include "optimizer/irpass/param_replace.h" +#include "optimizer/irpass/incorporate_getitem.h" +#include "optimizer/irpass/item_tuple_eliminate.h" #include "optimizer/irpass/mark_interface_fusion.h" +#include "optimizer/irpass/merge_addn.h" +#include "optimizer/irpass/minmax_grad.h" +#include "optimizer/irpass/param_replace.h" +#include "optimizer/irpass/partial_eliminate.h" +#include "optimizer/irpass/reduce_eliminate.h" +#include "optimizer/irpass/ref_eliminate.h" +#include "optimizer/irpass/reshape_eliminate.h" +#include "optimizer/irpass/special_op_eliminate.h" +#include "optimizer/irpass/specialize_transform.h" +#include "optimizer/irpass/symbol_resolver.h" +#include 
"optimizer/irpass/tile_eliminate.h" +#include "optimizer/irpass/transpose_eliminate.h" #include "optimizer/opt.h" namespace mindspore { namespace opt { namespace irpass { OptimizeIRPassLib::OptimizeIRPassLib() { - arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", + arithmetic_simplify_ = MakeSubstitution(std::make_shared(), "arithmetic_simplify", {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); - arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul}); + arithmetic_simplify2_ = + MakeSubstitution(std::make_shared(), "arithmetic_simplify2", {prim::kPrimMul}); special_op_eliminate_ = - MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", + MakeSubstitution(std::make_shared(), "special_op_eliminate", {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); - zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike); - adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN); + zero_like_fill_zero_ = + MakeSubstitution(std::make_shared(), "zero_like_fill_zero", prim::kPrimZerosLike); + adjust_all_reduce_mul_add_ = + MakeSubstitution(std::make_shared(), "adjust_all_reduce_mul_add", prim::kPrimAddN); // ops eliminate - item_tuple_eliminate_ = - MakeSubstitution(ItemTupleEliminater(), "item_tuple_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); - tile_eliminate_ = MakeSubstitution(TileMultiplyByOne(), "tile_eliminate", prim::kPrimTile); - cast_eliminate_ = MakeSubstitution(CastEliminater(), "cast_eliminate", prim::kPrimCast); - reshape_eliminate_ = MakeSubstitution(ReshapeEliminater(), "reshape_eliminate", prim::kPrimReshape); - transpose_eliminate_ = 
MakeSubstitution(TransposeSameIOEliminater(), "transpose_eliminate", prim::kPrimTranspose); + item_tuple_eliminate_ = MakeSubstitution(std::make_shared(), "item_tuple_eliminate", + {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); + tile_eliminate_ = MakeSubstitution(std::make_shared(), "tile_eliminate", prim::kPrimTile); + cast_eliminate_ = MakeSubstitution(std::make_shared(), "cast_eliminate", prim::kPrimCast); + reshape_eliminate_ = MakeSubstitution(std::make_shared(), "reshape_eliminate", prim::kPrimReshape); + transpose_eliminate_ = + MakeSubstitution(std::make_shared(), "transpose_eliminate", prim::kPrimTranspose); reduce_eliminate_ = MakeSubstitution( - ReduceOneEliminater(), "reduce_eliminate", + std::make_shared(), "reduce_eliminate", {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); - partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup); - same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape); - check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop); - reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode); - depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend); + partial_eliminate_ = MakeSubstitution(std::make_shared(), "partial_eliminate", IsCNodeDup); + same_eliminate_ = MakeSubstitution(std::make_shared(), "same_eliminate", prim::kPrimSameTypeShape); + check_bprop_eliminate_ = + MakeSubstitution(std::make_shared(), "check_bprop_eliminate", prim::kPrimCheckBprop); + reset_defer_inline_ = + MakeSubstitution(std::make_shared(), "reset_defer_inline", IsValueNode); + depend_value_elim_ = MakeSubstitution(std::make_shared(), "depend_value_elim", prim::kPrimDepend); // Env Item Eliminate - env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), 
"env_get_item_eliminate", prim::kPrimEnvGetItem); - new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem); + env_get_item_eliminate_ = + MakeSubstitution(std::make_shared(), "env_get_item_eliminate", prim::kPrimEnvGetItem); + new_env_get_item_ = MakeSubstitution(std::make_shared(), "new_env_get_item", prim::kPrimEnvGetItem); incorporate_env_getitem_ = - MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem); - incorporate_env_getitem_switch_ = - MakeSubstitution(IncorporateEnvGetitemSwitch(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); + MakeSubstitution(std::make_shared(), "incorporate_env_get_item", prim::kPrimEnvGetItem); + incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared(), + "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); // Ref eliminate - make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); - get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate", + make_ref_eliminate_ = + MakeSubstitution(std::make_shared(), "make_ref_eliminate", prim::kPrimMakeRef); + get_ref_param_eliminate_ = MakeSubstitution(std::make_shared(), "get_ref_param_eliminate", {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); - get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", + get_make_ref_eliminate_ = MakeSubstitution(std::make_shared(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); - replace_refkey_by_param_ = - MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode, opt::FORCE_RENORM); - replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); + replace_refkey_by_param_ = MakeSubstitution(std::make_shared(), "replace_refkey_by_param", + IsValueNode, opt::FORCE_RENORM); + replace_old_param_ = 
MakeSubstitution(std::make_shared(), "replace_old_param", IsParam); // Gradient transforms - expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); - minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem); + expand_jprim_ = MakeSubstitution(std::make_shared(), "expand_jprim", prim::kPrimJ); + minmaximum_grad_ = MakeSubstitution(std::make_shared(), "minmaximum_grad", prim::kPrimTupleGetItem); // branch culling - switch_simplify_ = MakeSubstitution(SwitchSimplify(), "switch_simplify", prim::kPrimSwitch); - float_tuple_getitem_switch_ = - MakeSubstitution(FloatTupleGetItemSwitch(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem); + switch_simplify_ = MakeSubstitution(std::make_shared(), "switch_simplify", prim::kPrimSwitch); + float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared(), + "float_tuple_getitem_switch", prim::kPrimTupleGetItem); float_env_getitem_switch_ = - MakeSubstitution(FloatEnvGetItemSwitch(), "float_env_getitem_switch", prim::kPrimEnvGetItem); - convert_switch_replacement_ = MakeSubstitution(ConvertSwitchReplacement(), "convert_switch_replacement", IsCNodeDup); + MakeSubstitution(std::make_shared(), "float_env_getitem_switch", prim::kPrimEnvGetItem); + convert_switch_replacement_ = + MakeSubstitution(std::make_shared(), "convert_switch_replacement", IsCNodeDup); // Addn - merge_addn_ = MakeSubstitution(MergeAddN(), "merge_addn", prim::kPrimAddN); - addn_zero_filter_ = MakeSubstitution(AddNZeroFilter(), "addn_zero_filter", prim::kPrimAddN); + merge_addn_ = MakeSubstitution(std::make_shared(), "merge_addn", prim::kPrimAddN); + addn_zero_filter_ = MakeSubstitution(std::make_shared(), "addn_zero_filter", prim::kPrimAddN); // inline - inline_ = MakeSubstitution(Inliner(), "inline", IsCNodeGraph); - replace_applicator_ = MakeSubstitution(ReplaceApplicator(), "replace_applicator", IsValueNode); - specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), 
"specialize_transform", IsCNodeGraph); + inline_ = MakeSubstitution(std::make_shared(), "inline", IsCNodeGraph); + replace_applicator_ = + MakeSubstitution(std::make_shared(), "replace_applicator", IsValueNode); + specialize_transform_ = + MakeSubstitution(std::make_shared(), "specialize_transform", IsCNodeGraph); // Incorporation incorporate_getitem_set_ = - MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem); - incorporate_getitem_from_param_ = - MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel); - incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); - incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup); + MakeSubstitution(std::make_shared(), "incorporate_getitem_set", prim::kPrimTupleGetItem); + incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared(), + "incorporate_getitem_from_param", IsCNodeGraphKernel); + incorporate_call_ = MakeSubstitution(std::make_shared(), "incorporate_call", IsCNodeDup); + incorporate_call_switch_ = + MakeSubstitution(std::make_shared(), "incorporate_call_switch", IsCNodeDup); // Virtual Dataset - virtual_dataset_eliminate_ = - MakeSubstitution(VirtualDatasetEliminater(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset); + virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared(), + "virtual_dataset_eliminate", prim::kPrimVirtualDataset); // Convert - print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint); + print_tuple_wrapper_ = + MakeSubstitution(std::make_shared(), "print_tuple_wrapper", prim::kPrimPrint); // Unused parameter eliminate unused_parameter_eliminate_ = - MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel); - unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel); 
+ MakeSubstitution(std::make_shared(), "unused_parameter_eliminate", IsCNodeGraphKernel); + unused_output_eliminate_ = + MakeSubstitution(std::make_shared(), "unused_output_eliminate", IsCNodeGraphKernel); // AddN eliminate - addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel); + addn_eliminate_ = MakeSubstitution(std::make_shared(), "addn_eliminate", IsCNodeGraphKernel); // Mark interface fusion - mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect); + mark_interface_fusion_ = + MakeSubstitution(std::make_shared(), "mark_interface_fusion", prim::kPrimSelect); } ResolveIRPassLib::ResolveIRPassLib() { - resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve); - resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr); + resolver_resolve_ = MakeSubstitution(std::make_shared(), "resolver_resolve", prim::kPrimResolve); + resolver_getattr_ = MakeSubstitution(std::make_shared(), "resolver_getattr", prim::kPrimGetAttr); } InferenceOptPrepareLib::InferenceOptPrepareLib() { - grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode); + grad_var_prepare_ = MakeSubstitution(std::make_shared(), "grad_var_prepare", IsCNode); } } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h index 270db8305f65c0ce3acb8268eba9dfd071e337a0..a26b81e95298ea6651bcd3b0fd26050eb46be467 100644 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h @@ -17,15 +17,16 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ -#include -#include #include +#include +#include -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include 
"optimizer/irpass/prim_eliminate.h" +#include "ir/optimizer_caller.h" #include "ir/visitor.h" #include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/irpass/prim_eliminate.h" +#include "optimizer/optimizer.h" namespace mindspore { namespace opt { @@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor { FuncGraphPtr all_reduce_fg_{nullptr}; }; -class ArithmeticSimplify { +class ArithmeticSimplify : public OptimizerCaller { public: ArithmeticSimplify() - : multiply_by_zero_or_one_(), - tensor_multiply_by_one_(), - add_by_zero_(), - tensor_add_by_zero_(), - identity_(prim::kPrimIdentity), - opt_update_zero_tensor_(), - constant_duplicate_mul_(), - power_one_() { + : multiply_by_zero_or_one_(std::make_shared()), + tensor_multiply_by_one_(std::make_shared()), + add_by_zero_(std::make_shared()), + tensor_add_by_zero_(std::make_shared()), + identity_(std::make_shared(prim::kPrimIdentity)), + opt_update_zero_tensor_(std::make_shared()), + constant_duplicate_mul_(std::make_shared()), + power_one_(std::make_shared()) { eliminaters_.emplace_back(multiply_by_zero_or_one_); eliminaters_.emplace_back(tensor_multiply_by_one_); eliminaters_.emplace_back(add_by_zero_); @@ -761,10 +762,10 @@ class ArithmeticSimplify { } ~ArithmeticSimplify() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -773,15 +774,9 @@ class ArithmeticSimplify { } private: - MultiplyByZeroOrOne multiply_by_zero_or_one_; - TensorMultiplyByOne tensor_multiply_by_one_; - AddByZero add_by_zero_; - TensorAddByZero tensor_add_by_zero_; - PrimEliminater identity_; - OptUpdateZeroTensor opt_update_zero_tensor_; - ConstantDuplicateMul constant_duplicate_mul_; 
- PowerOneEliminate power_one_; - std::vector eliminaters_{}; + OptimizerCallerPtr multiply_by_zero_or_one_, tensor_multiply_by_one_, add_by_zero_, tensor_add_by_zero_, identity_, + opt_update_zero_tensor_, constant_duplicate_mul_, power_one_; + std::vector eliminaters_{}; }; // Arithmetic Simplifications should be done after step_parallel. @@ -789,15 +784,17 @@ class ArithmeticSimplify { // with shape(weight), but after step_parallel, shape of weight may be changed, so the // shape of the constant tensor should also be changed. So this pass is seperated from // ArithmeticSimplify and deferred until step_parallel. -class ArithmeticSimplify2 { +class ArithmeticSimplify2 : public OptimizerCaller { public: - ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); } + ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared()) { + eliminaters_.emplace_back(tensor_multiply_by_zero_); + } ~ArithmeticSimplify2() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -806,8 +803,8 @@ class ArithmeticSimplify2 { } private: - TensorMultiplyByZero tensor_multiply_by_zero_; - std::vector eliminaters_{}; + OptimizerCallerPtr tensor_multiply_by_zero_; + std::vector eliminaters_{}; }; } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h b/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h index 734d88cb10f5a39906ec0493f220479a630fb8a6..d98d0b677b3c0fc2e63d73363c6f6fdab12e71a4 100644 --- a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h @@ -17,9 +17,9 @@ #ifndef 
MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ +#include "ir/visitor.h" #include "optimizer/irpass.h" #include "optimizer/optimizer.h" -#include "ir/visitor.h" namespace mindspore { namespace opt { @@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor { AnfNodePtr x_{nullptr}, t_{nullptr}; }; -class CastEliminater { +class CastEliminater : public OptimizerCaller { public: CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {} ~CastEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { auto new_node = cast_same_type_eliminater_(optimizer, node); if (new_node != nullptr) { return new_node; diff --git a/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h index 0f59c69fef8a6c8bf855f2cec463d5af19b34cf5..3f100dcaec3c330de99e161faa58dfea5d4f6acf 100644 --- a/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h @@ -17,18 +17,19 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ -#include -#include #include -#include #include +#include +#include +#include -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" #include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" #include "utils/symbolic.h" namespace mindspore { @@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor { bool is_match_{false}; }; -class EnvGetItemEliminater { +class EnvGetItemEliminater : public OptimizerCaller { public: - EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), 
env_get_set_item_() { + EnvGetItemEliminater() + : new_env_get_item_(std::make_shared()), + add_env_get_item_(std::make_shared()), + env_get_set_item_(std::make_shared()) { eliminaters_.emplace_back(new_env_get_item_); eliminaters_.emplace_back(add_env_get_item_); eliminaters_.emplace_back(env_get_set_item_); } ~EnvGetItemEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -246,10 +250,8 @@ class EnvGetItemEliminater { } private: - NewEnvGetItem new_env_get_item_; - AddEnvGetItem add_env_get_item_; - EnvGetSetItem env_get_set_item_; - std::vector eliminaters_{}; + OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_; + std::vector eliminaters_{}; }; // {prim::kPrimEnvGetItem, {G, Xs}, C, Y} diff --git a/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h index 5afee45e95f3d92f2f2ae607d876aea70c0d08fa..b6c8fb0e18e9b545189496dfa7db9c68a1b2352f 100644 --- a/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h +++ b/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h @@ -17,18 +17,20 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ -#include #include -#include #include +#include #include +#include -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" #include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" + namespace mindspore { namespace opt { namespace irpass { @@ -383,18 +385,20 
@@ class IncorporateGetitemSwitch : public AnfVisitor { internal::GetitemTransform getitem_transform_; }; -class IncorporateGetitemSet { +class IncorporateGetitemSet : public OptimizerCaller { public: - IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() { + IncorporateGetitemSet() + : incorporate_getitem_(std::make_shared()), + incorporate_getitem_switch_(std::make_shared()) { eliminaters_.emplace_back(incorporate_getitem_); eliminaters_.emplace_back(incorporate_getitem_switch_); } ~IncorporateGetitemSet() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -403,9 +407,8 @@ class IncorporateGetitemSet { } private: - IncorporateGetitem incorporate_getitem_; - IncorporateGetitemSwitch incorporate_getitem_switch_; - std::vector eliminaters_{}; + OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_; + std::vector eliminaters_{}; }; } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h b/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h index 21cdff51ad02d687d68b7d7e5c18e5ef6e161538..202951a254113d586eae46b4b4ad89509d76c827 100644 --- a/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h @@ -17,13 +17,15 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ -#include #include +#include +#include -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" +#include "ir/optimizer_caller.h" #include "ir/visitor.h" #include "operator/ops.h" +#include "optimizer/irpass.h" +#include 
"optimizer/optimizer.h" namespace mindspore { namespace opt { @@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor { AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; }; -class ItemTupleEliminater { +class ItemTupleEliminater : public OptimizerCaller { public: ItemTupleEliminater() - : get_item_eliminater_(), - get_item_const_eliminater_(), - set_item_eliminater_(), - get_set_item_eliminater_(), - get_item_depend_reorder_() { + : get_item_eliminater_(std::make_shared()), + get_item_const_eliminater_(std::make_shared()), + set_item_eliminater_(std::make_shared()), + get_set_item_eliminater_(std::make_shared()), + get_item_depend_reorder_(std::make_shared()) { eliminaters_.emplace_back(get_item_eliminater_); eliminaters_.emplace_back(get_item_const_eliminater_); eliminaters_.emplace_back(set_item_eliminater_); @@ -277,10 +279,10 @@ class ItemTupleEliminater { } ~ItemTupleEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -289,12 +291,9 @@ class ItemTupleEliminater { } private: - GetitemEliminater get_item_eliminater_; - GetitemConstEliminater get_item_const_eliminater_; - SetitemEliminater set_item_eliminater_; - GetSetitemEliminater get_set_item_eliminater_; - GetitemDependReorder get_item_depend_reorder_; - std::vector eliminaters_{}; + OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_, + get_item_depend_reorder_; + std::vector eliminaters_{}; }; } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h b/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h index 
41f379221c61aae8a2de07f229cf09612aa5920b..6d81b401c3c928b87f698a7ce955239a7d59f1a3 100644 --- a/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h @@ -19,9 +19,9 @@ #include -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" #include "ir/pattern_matcher.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" namespace mindspore { namespace opt { diff --git a/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h b/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h index fb43f6ffd8acd6eaef363789e9ba21591638e1e1..cafc8b796c4a08cf021a638bc9cebf7b10cbee38 100644 --- a/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h @@ -19,11 +19,12 @@ #include -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" #include "ir/func_graph.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" #include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" #include "pipeline/static_analysis/dshape.h" namespace mindspore { @@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor { AnfNodePtr x_{nullptr}, shape_{nullptr}; }; -class ReshapeEliminater { +class ReshapeEliminater : public OptimizerCaller { public: ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {} ~ReshapeEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { auto new_node = reshape_same_shape_eliminater_(optimizer, node); if (new_node != nullptr) { return new_node; diff --git a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h index dcba80431ade397dbbb45203ce1647cab091caa8..b6a4e1c85238a09d13f14c2a63432ce64ddbdb1c 100644 --- 
a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h @@ -18,31 +18,31 @@ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ #include -#include -#include #include +#include +#include -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" #include "ir/optimizer_caller.h" -#include "optimizer/irpass/prim_eliminate.h" +#include "ir/pattern_matcher.h" #include "ir/visitor.h" #include "operator/ops.h" -#include "ir/pattern_matcher.h" +#include "optimizer/irpass.h" +#include "optimizer/irpass/prim_eliminate.h" +#include "optimizer/optimizer.h" namespace mindspore { namespace opt { namespace irpass { -class SpecialOpEliminater { +class SpecialOpEliminater : public OptimizerCaller { public: SpecialOpEliminater() - : insert_gradient_of_(prim::kPrimInsertGradientOf), - stop_gradient_(prim::kPrimStopGradient), - hook_backward_(prim::kPrimHookBackward), - print_shape_type_(prim::kPrimPrintShapeType), - get_ref_value_(prim::kPrimGetRefValue), - mirror_(prim::kPrimMirror), - virtual_div_(prim::kPrimVirtualDiv) { + : insert_gradient_of_(std::make_shared(prim::kPrimInsertGradientOf)), + stop_gradient_(std::make_shared(prim::kPrimStopGradient)), + hook_backward_(std::make_shared(prim::kPrimHookBackward)), + print_shape_type_(std::make_shared(prim::kPrimPrintShapeType)), + get_ref_value_(std::make_shared(prim::kPrimGetRefValue)), + mirror_(std::make_shared(prim::kPrimMirror)), + virtual_div_(std::make_shared(prim::kPrimVirtualDiv)) { eliminaters_.emplace_back(insert_gradient_of_); eliminaters_.emplace_back(stop_gradient_); eliminaters_.emplace_back(hook_backward_); @@ -53,10 +53,10 @@ class SpecialOpEliminater { } ~SpecialOpEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = 
eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -65,9 +65,9 @@ class SpecialOpEliminater { } private: - PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, + OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, virtual_div_; - std::vector eliminaters_{}; + std::vector eliminaters_{}; }; // {PrimVirtualDataset, X} -> X diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc index 82fbcc2036bbe60a60a5db4af13275a81286bf34..4c2e85157f02d314180bd54b42dfd87c1f44959f 100644 --- a/mindspore/ccsrc/optimizer/opt.cc +++ b/mindspore/ccsrc/optimizer/opt.cc @@ -16,28 +16,27 @@ #include "optimizer/opt.h" +#include +#include #include #include -#include -#include #include "ir/anf.h" #include "ir/manager.h" -#include "utils/ordered_set.h" - -#include "utils/log_adapter.h" #include "optimizer/optimizer.h" +#include "utils/log_adapter.h" +#include "utils/ordered_set.h" namespace mindspore { /* namespace to support opt */ namespace opt { -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, const RenormAction &renorm_action) { auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; return std::make_shared(transform, name, fn, renorm_action); } -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const std::vector &prims, const RenormAction &renorm_action) { auto fn = [prims](const AnfNodePtr &node) -> bool { if (!node->isa()) { @@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType 
&transform, const std:: return std::make_shared(transform, name, fn, renorm_action); } -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, const RenormAction &renorm_action) { return std::make_shared(transform, name, predicate, renorm_action); } -AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const { +AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { #ifdef ENABLE_PROFILE double t = GetTime(); #endif - AnfNodePtr result = transform_(optimizer, node); + AnfNodePtr result = (*transform_)(optimizer, node); #ifdef ENABLE_PROFILE if (optimizer != nullptr) { auto time = GetTime(); diff --git a/mindspore/ccsrc/optimizer/opt.h b/mindspore/ccsrc/optimizer/opt.h index fb0bdc58be9be63c316cae5179284c4c8c084300..6601d969d28f714eef25777a08f59e2c0bf109d5 100644 --- a/mindspore/ccsrc/optimizer/opt.h +++ b/mindspore/ccsrc/optimizer/opt.h @@ -17,24 +17,18 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ #define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ -#include -#include #include +#include +#include #include "ir/anf.h" #include "ir/func_graph.h" +#include "ir/optimizer_caller.h" #include "operator/ops.h" namespace mindspore { /* namespace to support opt */ namespace opt { -class Optimizer; - -using OptimizerPtr = std::shared_ptr; -using OptimizerWeakPtr = std::weak_ptr; - -using PredicateFuncType = std::function; -using TransformFuncType = std::function; // Define the interaction mode between an Optimize pass and Renormalize pass // FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed @@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM }; class Substitution { public: - TransformFuncType transform_{nullptr}; + OptimizerCallerPtr transform_; std::string name_; PredicateFuncType 
predicate_{nullptr}; // an enum to mark this Substitution relation to renormalize pass RenormAction renorm_action_; - Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, + Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, const RenormAction &renorm_action) : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} ~Substitution() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node); }; using SubstitutionPtr = std::shared_ptr; -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, const RenormAction &action_renorm = CHECK_RENORM); -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const std::vector &prims, const RenormAction &action_renorm = CHECK_RENORM); -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); class SubstitutionList { diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc index cbf5bc40c8fc8923218b09e341b1933131ffd830..8ebfdb7d1309f34b506f38b9a1cf51cf5ce3fca1 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc @@ -465,7 +465,7 @@ double 
ReshapeCost::GetForwardCommCost(const std::vector &inputs, co CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); - TensorRedistribution tensor_redistribution; + TensorRedistribution tensor_redistribution(false, true); if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; } @@ -503,7 +503,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector &inp CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); - TensorRedistribution tensor_redistribution; + TensorRedistribution tensor_redistribution(false, true); if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; } diff --git a/mindspore/ccsrc/parallel/context.cc b/mindspore/ccsrc/parallel/context.cc index 8957dc842c4a0c15f77cbb8a8b6d4c970097da42..062d814aa040b887f38b5f5579b261907f8c1565 100644 --- a/mindspore/ccsrc/parallel/context.cc +++ b/mindspore/ccsrc/parallel/context.cc @@ -62,6 +62,7 @@ void ParallelContext::Reset() { enable_all_reduce_fusion_ = false; strategy_ckpt_load_file_ = ""; strategy_ckpt_save_file_ = ""; + enable_parallel_optimizer_ = false; } void ParallelContext::set_device_num(int32_t device_num) { diff --git a/mindspore/ccsrc/parallel/context.h b/mindspore/ccsrc/parallel/context.h index efa528d1793a12395c70d2a5eb9b1d78bebe0ba4..6a503ca7eda609ab820cd64eaa6975890b1d66c0 100644 --- a/mindspore/ccsrc/parallel/context.h +++ b/mindspore/ccsrc/parallel/context.h @@ -100,6 +100,11 @@ class ParallelContext { void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file); std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; } + void 
set_enable_parallel_optimizer(bool enable_parallel_optimizer) { + enable_parallel_optimizer_ = enable_parallel_optimizer; + } + bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; } + void Reset(); private: @@ -123,6 +128,7 @@ class ParallelContext { std::map> all_reduce_fusion_split_sizes_; std::string strategy_ckpt_load_file_; std::string strategy_ckpt_save_file_; + bool enable_parallel_optimizer_; }; void ParallelParameterContextInit(const FuncGraphPtr &func_graph); diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index 7025447a29c0138fe5bd9e42343fbd18893efb12..dc309808d9ac3f69d0630c0fc25002380011550d 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -205,6 +205,10 @@ PYBIND11_MODULE(_c_expression, m) { .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.") .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.") .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.") + .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer, + "Set enable/disable parallel optimizer.") + .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer, + "Get enable/disable parallel optimizer.") .def("reset", &ParallelContext::Reset, "Reset auto parallel context."); (void)py::class_>(m, "CostModelContext") diff --git a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc index 9af50eac330910316d797057cc982c9364d24386..297a167aa8edeed4a6c96cbb9f9ceff1dbed870e 100644 --- a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc +++ b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc @@ -35,7 +35,7 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr 
&node) { } } // namespace -bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { +bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const { MS_EXCEPTION_IF_NULL(main); MS_EXCEPTION_IF_NULL(node); diff --git a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h index 8e1768ea99c6c92bfa3df77af02b3d830bad36f1..18f433ab955378cf148c55a8ade47b255007f506 100644 --- a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h +++ b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h @@ -31,7 +31,7 @@ class BackendCSE : public CSE { public: BackendCSE() = default; ~BackendCSE() override = default; - bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const override; + bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pybind_api/export_flags.cc b/mindspore/ccsrc/pybind_api/export_flags.cc index 83392784f3e1d059f52b09cd0b37e522087a282d..253e271e52595b80009611ded37b1f33d384d9e4 100644 --- a/mindspore/ccsrc/pybind_api/export_flags.cc +++ b/mindspore/ccsrc/pybind_api/export_flags.cc @@ -33,5 +33,6 @@ const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll"; const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect"; const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order"; const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect"; +const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect"; } // namespace mindspore diff --git a/mindspore/ccsrc/pybind_api/export_flags.h b/mindspore/ccsrc/pybind_api/export_flags.h index 74c27ff35d20cecf6227db8fa4cb48be6d6c4213..6ea584e66d87a2c4ecf1fcedd2cf32c5506f6756 100644 --- a/mindspore/ccsrc/pybind_api/export_flags.h +++ b/mindspore/ccsrc/pybind_api/export_flags.h @@ -34,7 +34,7 @@ extern const char 
GRAPH_FLAG_LOOP_CAN_UNROLL[]; extern const char GRAPH_FLAG_HAS_EFFECT[]; extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[]; extern const char GRAPH_FLAG_RANDOM_EFFECT[]; - +extern const char GRAPH_FLAG_SIDE_EFFECT[]; } // namespace mindspore #endif // PYBIND_API_EXPORT_FLAGS_H_ diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc index 868b968d9e4cee625c29c3635c14cd3b9e3ddd76..561073b2c22f1967deeb843d03d42e22243f6be6 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/session/ascend_control_parser.cc @@ -33,6 +33,21 @@ static constexpr size_t kCNodeSwitchLayerLength = 3; namespace mindspore { namespace session { +static CNodePtr GetJumpNode(NotNull parent_graph, NotNull child_graph) { + auto &nodes = parent_graph->execution_order(); + for (auto &node : nodes) { + if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) { + return node; + } else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) && + (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) || + child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) { + return node; + } + } + MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString(); + return nullptr; +} + static void InitUnionFindSet(NotNull kg, const NotNull *> union_find_set, const NotNull *> memo) { if (memo->find(kg.get()) != memo->end()) { @@ -200,7 +215,8 @@ void AscendControlParser::ChildGraphDataAssign(const std::mapsecond), NOT_NULL(arg), NOT_NULL(parameter)); + InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg), + NOT_NULL(parameter)); } } } @@ -263,7 +279,7 @@ NotNull AscendControlParser::ProcessKernelGraph(NotNullSetExecOrderByDefault(); MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString(); return NOT_NULL(start_label); } @@ -433,7 +449,8 @@ std::tuple 
AscendControlParser::ParsePartial(NotNull kg, NotNull from, +void AscendControlParser::InsertMultipleAssignToGraph(NotNull from_graph, + NotNull to_graph, NotNull from, NotNull to) { std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); @@ -443,18 +460,24 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull kg << to_outputs.size() << "]"; } for (size_t i = 0; i < from_outputs.size(); i++) { - InsertAssignToGraph(kg, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); + auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); + if (assign_node != nullptr) { + auto jump_node = GetJumpNode(from_graph, to_graph); + if (jump_node != nullptr) { + InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); + } + } } } -void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, - NotNull to) { +AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, + NotNull to) { if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { - return; + return nullptr; } if (from.get() == to.get()) { - return; + return nullptr; } MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " << to->DebugString(); @@ -466,6 +489,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNul assign_node->set_abstract(to->abstract()); // append the assign at the end of from graph InsertDependToGraph(kg, NOT_NULL(assign_node)); + return assign_node; } std::vector AscendControlParser::RecurseGraph(NotNull graph, diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h index 73d68449b31f003dbf2ac57bf27f245af3319c8c..0cf7069046d49e153e592a99792273a13fe7d3db 100644 --- 
a/mindspore/ccsrc/session/ascend_control_parser.h +++ b/mindspore/ccsrc/session/ascend_control_parser.h @@ -52,8 +52,9 @@ class AscendControlParser { const CNodePtr &last_label); static std::tuple ParsePartial(NotNull node); - static void InsertMultipleAssignToGraph(NotNull kg, NotNull from, NotNull to); - static void InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); + static void InsertMultipleAssignToGraph(NotNull from_graph, NotNull to_graph, + NotNull from, NotNull to); + static AnfNodePtr InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); // root graph order static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 8a4982cd4f3346025b763d5f0ebc656ea71332a3..f9132ff2d09b2a74f2d99f38594a296c67db5275 100644 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -521,6 +521,47 @@ std::vector KernelGraph::GetOutputNodes(const AnfNodePtr &node) { return output_nodes; } +// Find control_depend real input nodes. 
+void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector *result, std::set *visited) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(result); + MS_EXCEPTION_IF_NULL(visited); + if (visited->find(anf_node) != visited->end()) { + MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited"; + return; + } + visited->insert(anf_node); + if (AnfAlgo::IsRealKernel(anf_node)) { + result->emplace_back(anf_node); + return; + } + if (!anf_node->isa()) { + return; + } + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().empty()) { + MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString(); + } + auto input0 = cnode->input(0); + if (IsPrimitive(input0, prim::kPrimMakeTuple)) { + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + GetAllFatherRealNode(cnode->input(i), result, visited); + } + } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { + if (cnode->inputs().size() != kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; + } + GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited); + } else if (IsPrimitive(input0, prim::kPrimDepend)) { + if (cnode->inputs().size() != kDependInputSize) { + MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!"; + } + GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited); + GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited); + } +} + // update the depend relations of control depend void KernelGraph::UpdateControlDependRelations(const std::vector &depends) { for (const auto &node : depends) { @@ -551,11 +592,24 @@ void KernelGraph::UpdateControlDependRelations(const std::vector &de if (depend_node->isa() && depend_mode == 1) { depend_nodes = GetOutputNodes(depend_node); } - for (auto &first_node : prior_nodes) { + + std::vector real_prior_nodes; + std::set prior_visited; + for (const auto &tmp : 
prior_nodes) { + GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); + } + + std::vector real_depend_nodes; + std::set depend_visited; + for (const auto &tmp : depend_nodes) { + GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); + } + + for (auto &first_node : real_prior_nodes) { if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { continue; } - for (auto &second_node : depend_nodes) { + for (auto &second_node : real_depend_nodes) { if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { continue; } diff --git a/mindspore/ccsrc/session/session.cc b/mindspore/ccsrc/session/session.cc index 90e02b37ff18a566d50fd87a78e425969677e384..ae70fc77aa5324d653a670260d8bd46f72ae0a6f 100644 --- a/mindspore/ccsrc/session/session.cc +++ b/mindspore/ccsrc/session/session.cc @@ -33,9 +33,14 @@ namespace py = pybind11; namespace mindspore::inference { std::shared_ptr LoadModel(const char *model_buf, size_t size, const std::string &device) { - inference::Session::RegAllOp(); - auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); - return anf_graph; + try { + inference::Session::RegAllOp(); + auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); + return anf_graph; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference LoadModel failed"; + return nullptr; + } } void ExitInference() { @@ -51,12 +56,17 @@ void ExitInference() { } std::shared_ptr MSSession::CreateSession(const std::string &device, uint32_t device_id) { - auto session = std::make_shared(); - auto ret = session->Init(device, device_id); - if (ret != 0) { + try { + auto session = std::make_shared(); + auto ret = session->Init(device, device_id); + if (ret != 0) { + return nullptr; + } + return session; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference CreatSession failed"; return nullptr; } - return session; } void Session::RegAllOp() { @@ -113,47 +123,71 @@ void Session::RegAllOp() { uint32_t 
Session::CompileGraph(std::shared_ptr funcGraphPtr) { MS_ASSERT(session_impl_ != nullptr); - auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); - py::gil_scoped_release gil_release; - return graph_id; + try { + auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); + py::gil_scoped_release gil_release; + return graph_id; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference CompileGraph failed"; + return static_cast(-1); + } } MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector> &inputs) { - std::vector inTensors; - inTensors.resize(inputs.size()); - bool has_error = false; - std::transform(inputs.begin(), inputs.end(), inTensors.begin(), - [&has_error](const std::shared_ptr &tensor_ptr) -> tensor::TensorPtr { - if (tensor_ptr == nullptr) { - MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr"; - has_error = true; - return nullptr; - } - auto tensor = static_cast(tensor_ptr.get()); - if (tensor == nullptr) { - MS_LOG(ERROR) << "Can not cast input MSTensor to tensor"; - has_error = true; - return nullptr; - } - return tensor->tensor(); - }); - if (has_error) { - MS_LOG(ERROR) << "Init Tensor failed, returning empty result"; - std::vector> multiTensor; - return multiTensor; - } - VectorRef outputs; - session_impl_->RunGraph(graph_id, inTensors, &outputs); + try { + std::vector inTensors; + inTensors.resize(inputs.size()); + bool has_error = false; + std::transform(inputs.begin(), inputs.end(), inTensors.begin(), + [&has_error](const std::shared_ptr &tensor_ptr) -> tensor::TensorPtr { + if (tensor_ptr == nullptr) { + MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr"; + has_error = true; + return nullptr; + } + auto tensor = static_cast(tensor_ptr.get()); + if (tensor == nullptr) { + MS_LOG(ERROR) << "Can not cast input MSTensor to tensor"; + has_error = true; + return nullptr; + } + return tensor->tensor(); + }); + if (has_error) { + MS_LOG(ERROR) << "Init Tensor failed, returning empty 
result"; + std::vector> multiTensor; + return multiTensor; + } + VectorRef outputs; + session_impl_->RunGraph(graph_id, inTensors, &outputs); - return TransformVectorRefToMultiTensor(outputs); + return TransformVectorRefToMultiTensor(outputs); + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference Rungraph failed"; + return MultiTensor(); + } } - +namespace { +string AjustTargetName(const std::string &device) { + if (device == kAscendDevice) { + return std::string(kAscendDevice) + "Inference"; + } else { + MS_LOG(ERROR) << "Only support device Ascend right now"; + return ""; + } +} +} // namespace int Session::Init(const std::string &device, uint32_t device_id) { RegAllOp(); auto ms_context = MsContext::GetInstance(); ms_context->set_execution_mode(kGraphMode); - ms_context->set_device_target(kAscendDevice); - session_impl_ = session::SessionFactory::Get().Create(device); + ms_context->set_device_id(device_id); + auto ajust_device = AjustTargetName(device); + if (ajust_device == "") { + return -1; + } + ms_context->set_device_target(device); + session_impl_ = session::SessionFactory::Get().Create(ajust_device); if (session_impl_ == nullptr) { MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available."; return -1; diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index e080c862853f35e2bc05815410fccd67a3955229..a01d4f205b58e191164b410b8788c03b95f61760 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -81,7 +81,15 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne } } // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode) - auto address = AnfAlgo::GetOutputAddr(node, output_index); + DeviceAddressPtr address; + auto is_all_nop_node = opt::IsAllNopNode(&graph); + if (is_all_nop_node) { + // The graph does not remove the nop node. 
+ address = AnfAlgo::GetMutableOutputAddr(node, output_index, false); + } else { + // The graph removes the nop node. + address = AnfAlgo::GetMutableOutputAddr(node, output_index, true); + } MS_EXCEPTION_IF_NULL(address); auto shape = AnfAlgo::GetOutputInferShape(node, output_index); TypeId type_id = kNumberTypeFloat32; @@ -93,7 +101,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) { - tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index)); + tensor->set_device_address(address); tensor->set_dirty(false); } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) { diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc index 32333a06ae58ea5ab1547c6cbfc90d9ac03ba930..3f6b31303c39cf261c3f4848d9bc2c0d0c4d67b9 100644 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -1646,7 +1646,7 @@ bool DfGraphConvertor::GetControlDependList(const CNodePtr &node, dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); } if (src_ops_list->empty() || dst_ops_list->empty()) { - MS_LOG(WARNING) << "Control depend node's src or dest node is not a apply node, ignore it"; + MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it"; error_ = SUCCESS; } return true; @@ -1690,6 +1690,8 @@ void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) { }); } else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) { control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]}); + } else if (src_ops_list->empty() || dst_ops_list->empty()) { + MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it"; } 
else { MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size() << " -> dst:" << dst_ops_list->size(); diff --git a/mindspore/ccsrc/utils/log_adapter.cc b/mindspore/ccsrc/utils/log_adapter.cc index d16fbead9bc1ddd6ef1ab5acc1a1ade99050e400..3588754dae18d0ebef730463aef1ed3b1192abc1 100644 --- a/mindspore/ccsrc/utils/log_adapter.cc +++ b/mindspore/ccsrc/utils/log_adapter.cc @@ -463,7 +463,7 @@ void InitSubModulesLogLevel() { // set submodule's log level auto submodule = GetEnv("MS_SUBMODULE_LOG_v"); - MS_LOG(INFO) << "MS_SUBMODULE_LOG_v=`" << submodule << "`"; + MS_LOG(DEBUG) << "MS_SUBMODULE_LOG_v=`" << submodule << "`"; LogConfigParser parser(submodule); auto configs = parser.Parse(); for (const auto &cfg : configs) { @@ -489,22 +489,14 @@ void InitSubModulesLogLevel() { } // namespace mindspore extern "C" { -// shared lib init hook #if defined(_WIN32) || defined(_WIN64) -__attribute__((constructor)) void mindspore_log_init(void) { +__attribute__((constructor)) void common_log_init(void) { #else -void mindspore_log_init(void) { +void common_log_init(void) { #endif #ifdef USE_GLOG // do not use glog predefined log prefix FLAGS_log_prefix = false; - static bool is_glog_initialzed = false; - if (!is_glog_initialzed) { -#if !defined(_WIN32) && !defined(_WIN64) - google::InitGoogleLogging("mindspore"); -#endif - is_glog_initialzed = true; - } // set default log level to WARNING if (mindspore::GetEnv("GLOG_v").empty()) { FLAGS_v = mindspore::WARNING; @@ -525,4 +517,22 @@ void mindspore_log_init(void) { #endif mindspore::InitSubModulesLogLevel(); } + +// shared lib init hook +#if defined(_WIN32) || defined(_WIN64) +__attribute__((constructor)) void mindspore_log_init(void) { +#else +void mindspore_log_init(void) { +#endif +#ifdef USE_GLOG + static bool is_glog_initialzed = false; + if (!is_glog_initialzed) { +#if !defined(_WIN32) && !defined(_WIN64) + google::InitGoogleLogging("mindspore"); +#endif + is_glog_initialzed = true; 
+ } +#endif + common_log_init(); +} } diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 390c210095c021d6b8c230d75b4441493a5aaf94..8d0f729e50cafea04d32811285be4d3e05bb1728 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -252,6 +252,7 @@ constexpr auto kControlDependMode = "depend_mode"; // index define of depend constexpr auto kRealInputIndexInDepend = 1; constexpr auto kDependAttachNodeIndex = 2; +constexpr auto kDependInputSize = 3; // format constexpr auto kOpFormat_DEFAULT = "DefaultFormat"; constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0"; diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 728e10f3bb412722b1c797671c476fc48a21a669..92c600520f9176096e17d362e8ef14260eb51c76 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -22,6 +22,10 @@ from . import dtype as mstype from ._register_for_tensor import tensor_operator_registry __all__ = ['Tensor', 'MetaTensor'] +np_types = (np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, np.float16, + np.float32, np.float64, np.bool_) + class Tensor(Tensor_): @@ -54,6 +58,10 @@ class Tensor(Tensor_): """ def __init__(self, input_data, dtype=None): + # If input data is numpy number, convert it to np array + if isinstance(input_data, np_types): + input_data = np.array(input_data) + # If input_data is tuple/list/numpy.ndarray, it's support in check_type method. check_type('tensor input_data', input_data, (Tensor_, float, int)) if dtype is not None: diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index ca6f7ca33e5bf3577abe8d867f125e31b0053f3e..360cdb1860e9f0e0257c2991af5b01c096b4fb47 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1040,7 +1040,7 @@ class Dataset: Args: columns (list[str], optional): List of columns to be used to specify the order of columns - (defaults=None, means all columns). 
+ (default=None, means all columns). Returns: Iterator, list of ndarray. @@ -3382,7 +3382,7 @@ class ManifestDataset(MappableDataset): class_indexing (dict, optional): A str-to-int mapping from label name to index (default=None, the folder names will be sorted alphabetically and each class will be given a unique index starting from 0). - decode (bool, optional): decode the images after reading (defaults=False). + decode (bool, optional): decode the images after reading (default=False). num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This @@ -4760,7 +4760,7 @@ class _NumpySlicesDataset: def process_dict(self, input_data): """ - Convert the dict like data into tuple format, when input is a tuple of dict then compose it into a dict first. + Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first. """ # Convert pandas like dict(has "values" column) into General dict data_keys = list(input_data.keys()) diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index aef714953f06e030ba0f5815224a0ffa80677ce0..3fdf7795d0f3758267d8c68c5fe1e22c9ab6d4f6 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ b/mindspore/dataset/transforms/vision/c_transforms.py @@ -202,7 +202,7 @@ class RandomHorizontalFlip(cde.RandomHorizontalFlipOp): Flip the input image horizontally, randomly with a given probability. Args: - prob (float): Probability of the image being flipped (default=0.5). + prob (float, optional): Probability of the image being flipped (default=0.5). """ @check_prob @@ -217,7 +217,7 @@ class RandomHorizontalFlipWithBBox(cde.RandomHorizontalFlipWithBBoxOp): Maintains data integrity by also flipping bounding boxes in an object detection pipeline. Args: - prob (float): Probability of the image being flipped (default=0.5). 
+ prob (float, optional): Probability of the image being flipped (default=0.5). """ @check_prob @@ -231,7 +231,7 @@ class RandomVerticalFlip(cde.RandomVerticalFlipOp): Flip the input image vertically, randomly with a given probability. Args: - prob (float): Probability of the image being flipped (default=0.5). + prob (float, optional): Probability of the image being flipped (default=0.5). """ @check_prob diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index ba56af6219378cd76144e4c1dfb9de5e157cdbef..b3d4bc8a0ed0b74e369cee4d1892244d0d979784 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -29,8 +29,9 @@ from .optimizer import Optimizer _adam_opt = C.MultitypeFuncGraph("adam_opt") -@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") -def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): +@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Bool", "Bool") +def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag, optim_filter): """ Update parameters. @@ -44,38 +45,44 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad m (Tensor): m value of parameters. v (Tensor): v value of parameters. gradient (Tensor): Gradient of parameters. + decay_flag (bool): Applies weight decay or not. + optim_filter (bool): Applies parameter update or not. Returns: Tensor, the new value of v after updating. 
""" - op_mul = P.Mul() - op_square = P.Square() - op_sqrt = P.Sqrt() - op_cast = P.Cast() - op_reshape = P.Reshape() - op_shape = P.Shape() + if optim_filter: + op_mul = P.Mul() + op_square = P.Square() + op_sqrt = P.Sqrt() + op_cast = P.Cast() + op_reshape = P.Reshape() + op_shape = P.Shape() - param_fp32 = op_cast(param, mstype.float32) - m_fp32 = op_cast(m, mstype.float32) - v_fp32 = op_cast(v, mstype.float32) - gradient_fp32 = op_cast(gradient, mstype.float32) + param_fp32 = op_cast(param, mstype.float32) + m_fp32 = op_cast(m, mstype.float32) + v_fp32 = op_cast(v, mstype.float32) + gradient_fp32 = op_cast(gradient, mstype.float32) - next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32) + next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) + - beta1, gradient_fp32) - next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - - beta2, op_square(gradient_fp32)) + next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) + - beta2, op_square(gradient_fp32)) - update = next_m / (eps + op_sqrt(next_v)) - if decay_flag: - update = op_mul(weight_decay_tensor, param_fp32) + update - update_with_lr = op_mul(lr, update) - next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) + update = next_m / (eps + op_sqrt(next_v)) + if decay_flag: + update = op_mul(weight_decay_tensor, param_fp32) + update - next_v = F.depend(next_v, F.assign(param, op_cast(next_param, F.dtype(param)))) - next_v = F.depend(next_v, F.assign(m, op_cast(next_m, F.dtype(m)))) - next_v = F.depend(next_v, F.assign(v, op_cast(next_v, F.dtype(v)))) - return next_v + update_with_lr = op_mul(lr, update) + next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) + + next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) + next_param = F.depend(next_param, F.assign(m, 
op_cast(next_m, F.dtype(m)))) + next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v)))) + return next_param + return gradient def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): @@ -300,7 +307,7 @@ class AdamWeightDecay(Optimizer): - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. Outputs: - tuple[Parameter], the updated velocity value, the shape is the same as `params`. + tuple[bool], all elements are True. Examples: >>> net = Net() @@ -328,11 +335,13 @@ class AdamWeightDecay(Optimizer): def construct(self, gradients): lr = self.get_lr() - updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor), - self.params, self.moments1, self.moments2, gradients, self.decay_flag) - - return updated_velocity + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, + self.weight_decay_tensor), + self.params, self.moments1, self.moments2, gradients, + self.decay_flag, self.optim_filter) + if self.use_parallel: + optim_result = self.broadcast_params(optim_result) + return optim_result class AdamWeightDecayDynamicLR(Optimizer): @@ -363,7 +372,7 @@ class AdamWeightDecayDynamicLR(Optimizer): - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. Outputs: - tuple[Parameter], the updated velocity value, the shape is the same as `params`. + tuple[bool], all elements are True. 
Examples: >>> net = Net() @@ -424,12 +433,14 @@ class AdamWeightDecayDynamicLR(Optimizer): warmup_lr = self.start_learning_rate * warmup_percent is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32) lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr - updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor), - self.params, self.moments1, self.moments2, gradients, self.decay_flag) - + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, + self.weight_decay_tensor), + self.params, self.moments1, self.moments2, gradients, + self.decay_flag, self.optim_filter) + if self.use_parallel: + optim_result = self.broadcast_params(optim_result) added_global_step = self.global_step + self.one F.control_depend(lr, added_global_step) self.global_step = added_global_step - return updated_velocity + return optim_result diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py index 832b35d66f15a0b209973539d98b40ddfc59013e..93c7edbce844829a907451e7983e35ae19a8e7f9 100755 --- a/mindspore/nn/optim/lamb.py +++ b/mindspore/nn/optim/lamb.py @@ -32,11 +32,10 @@ num_one = Tensor(np.ones([1]), mstype.float32) _lamb_opt = C.MultitypeFuncGraph("lamb_opt") - -@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", - "Tensor", "Tensor", "Tensor", "Tensor", "Bool") +@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Bool", "Bool") def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v, - gradient, decay_flag): + gradient, decay_flag, optim_filter): """ Update parameters. @@ -52,66 +51,66 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para v (Tensor): v value of parameters. gradient (Tensor): Gradient of parameters. decay_flag (bool): Specifies whether param update with weight decay. 
+ optim_filter(bool): Applies parameter update or not. Returns: Tensor, the new value of v after updating. """ - op_mul = P.Mul() - op_sqrt = P.Sqrt() - op_rsqrt = P.Rsqrt() - op_square = P.Square() - op_cast = P.Cast() - op_reshape = P.Reshape() - op_shape = P.Shape() - op_pow = P.Pow() - op_norm = layer.Norm() - op_select = P.Select() - op_greater = P.Greater() - op_fill = P.Fill() - op_dtype = P.DType() - - param_fp32 = op_cast(param, mstype.float32) - m_fp32 = op_cast(m, mstype.float32) - v_fp32 = op_cast(v, mstype.float32) - gradient_fp32 = op_cast(gradient, mstype.float32) - - next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, - mstype.float32) - beta1, gradient_fp32) - - next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, - mstype.float32) - beta2, op_square(gradient_fp32)) - - next_mm = next_m / (op_cast(num_one, mstype.float32) - - op_pow(beta1, op_cast(global_step + num_one, mstype.float32))) - next_vv = next_v / (op_cast(num_one, mstype.float32) - - op_pow(beta2, op_cast(global_step + num_one, mstype.float32))) - w_norm = op_norm(param_fp32) - g_norm = op_norm(gradient_fp32) - - g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt( - next_vv + eps)) + weight_decay_tensor * param_fp32) - zeros = F.zeros_like(w_norm) - ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0) - trust_ratio = op_select( - op_greater(w_norm, zeros), - op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones), - ones) - tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0) - trust_ratio = C.clip_by_value(trust_ratio, zeros, tens) - update = next_mm / (op_sqrt(next_vv) + eps) - - if decay_flag: - update = update + op_mul(weight_decay_tensor, param_fp32) - - update_with_lr = op_mul(op_mul(trust_ratio, lr), update) - - next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) - - next_v = F.depend(next_v, F.assign(param, next_param)) - next_v = F.depend(next_v, F.assign(m, next_m)) - next_v = F.depend(next_v, F.assign(v, next_v)) - - return 
next_v + if optim_filter: + op_mul = P.Mul() + op_sqrt = P.Sqrt() + op_rsqrt = P.Rsqrt() + op_square = P.Square() + op_cast = P.Cast() + op_reshape = P.Reshape() + op_shape = P.Shape() + op_pow = P.Pow() + op_norm = layer.Norm() + op_select = P.Select() + op_greater = P.Greater() + op_fill = P.Fill() + op_dtype = P.DType() + + param_fp32 = op_cast(param, mstype.float32) + m_fp32 = op_cast(m, mstype.float32) + v_fp32 = op_cast(v, mstype.float32) + gradient_fp32 = op_cast(gradient, mstype.float32) + + next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32) + + next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32)) + + next_mm = next_m / (op_cast(num_one, mstype.float32) + - op_pow(beta1, op_cast(global_step + num_one, mstype.float32))) + next_vv = next_v / (op_cast(num_one, mstype.float32) - + op_pow(beta2, op_cast(global_step + num_one, mstype.float32))) + w_norm = op_norm(param_fp32) + g_norm = op_norm(gradient_fp32) + + g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32) + zeros = F.zeros_like(w_norm) + ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0) + trust_ratio = op_select( + op_greater(w_norm, zeros), + op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones), + ones) + tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0) + trust_ratio = C.clip_by_value(trust_ratio, zeros, tens) + update = next_mm / (op_sqrt(next_vv) + eps) + + if decay_flag: + update = update + op_mul(weight_decay_tensor, param_fp32) + + update_with_lr = op_mul(op_mul(trust_ratio, lr), update) + + next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) + + next_param = F.depend(next_param, F.assign(param, next_param)) + next_param = F.depend(next_param, F.assign(m, next_m)) + next_param = F.depend(next_param, F.assign(v, next_v)) + + return next_param + return gradient lamb_opt_graph_kernel = 
C.MultitypeFuncGraph("lamb_opt_graph_kernel") @@ -238,7 +237,7 @@ class Lamb(Optimizer): - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. Outputs: - tuple[Parameter], the updated velocity value, the shape is the same as `params`. + tuple[bool], all elements are True. Examples: >>> net = Net() @@ -311,18 +310,21 @@ class Lamb(Optimizer): self.warmup_steps, self.global_step), mstype.float32) lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr if self.enable_graph_kernel: - updated_velocity = self.hyper_map(F.partial(lamb_opt_graph_kernel, - self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor, self.global_step), - self.params, self.moments1, self.moments2, gradients, self.decay_flag) + optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, + self.beta1, self.beta2, self.eps, lr, + self.weight_decay_tensor, self.global_step), + self.params, self.moments1, self.moments2, gradients, self.decay_flag) else: - updated_velocity = self.hyper_map(F.partial(_lamb_opt, - self.beta1, self.beta2, self.eps, lr, - self.weight_decay_tensor, self.global_step), - self.params, self.moments1, self.moments2, gradients, self.decay_flag) + optim_result = self.hyper_map(F.partial(_lamb_opt, + self.beta1, self.beta2, self.eps, lr, + self.weight_decay_tensor, self.global_step), + self.params, self.moments1, self.moments2, gradients, + self.decay_flag, self.optim_filter) + if self.use_parallel: + optim_result = self.broadcast_params(optim_result) added_global_step = self.global_step + self.one F.control_depend(lr, added_global_step) self.global_step = added_global_step - return updated_velocity + return optim_result diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 9bfc3a284b61a75b490e7537992ce7f964eb3d76..a811edcabc54ecbecb29a066dcddcc77f5de0e79 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -22,11 +22,14 @@ from mindspore.ops import functional as 
F, composite as C, operations as P from mindspore.nn.cell import Cell from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.initializer import initializer +from mindspore.common.tensor import Tensor import mindspore.common.dtype as mstype from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel -from mindspore.common.tensor import Tensor from mindspore import log as logger +from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.train.parallel_utils import ParallelMode __all__ = ['Optimizer'] @@ -155,6 +158,27 @@ class Optimizer(Cell): self.param_length = len(self.parameters) self.map_ = C.Map() + use_parallel = auto_parallel_context().get_enable_parallel_optimizer() + self.use_parallel = use_parallel + if use_parallel: + if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]: + raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name)) + if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL, + ParallelMode.AUTO_PARALLEL]: + raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format + (_get_parallel_mode())) + self.dev_num = _get_device_num() + if self.dev_num > self.param_length: + raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is" + " less than the number of devices {}".format(self.param_length, self.dev_num)) + self.param_rank = self._get_parameter_group_id() + self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank)) + self.param_names = [] + for param in self.parameters: + self.param_names.append(param.name) + else: + self.optim_filter = (True,) * self.param_length + def decay_weight(self, gradients): """ Weight decay. 
@@ -219,8 +243,32 @@ class Optimizer(Cell): raise TypeError("Learning rate should be float, Tensor or Iterable.") return lr + def _check_group_params(self, parameters): + """Check group params.""" + parse_keys = ['params', 'lr', 'weight_decay', 'order_params'] + for group_param in parameters: + invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys())) + if invalid_key: + raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.') + + if 'order_params' in group_param.keys(): + if len(group_param.keys()) > 1: + raise ValueError("The order params dict in group parameters should " + "only include the 'order_params' key.") + if not isinstance(group_param['order_params'], Iterable): + raise TypeError("The value of 'order_params' should be an Iterable type.") + continue + + if not group_param['params']: + raise ValueError("Optimizer got an empty group parameter list.") + + for param in group_param['params']: + if not isinstance(param, Parameter): + raise TypeError("The group param should be an iterator of Parameter type.") + def _parse_group_params(self, parameters, learning_rate): """Parse group params.""" + self._check_group_params(parameters) if self.dynamic_lr: dynamic_lr_length = learning_rate.size() else: @@ -250,9 +298,6 @@ class Optimizer(Cell): if dynamic_lr_length not in (lr_length, 0): raise ValueError("The dynamic learning rate in group should be the same size.") - if not group_param['params']: - raise ValueError("Optimizer got an empty group parameter list.") - dynamic_lr_length = lr_length self.dynamic_lr_length = dynamic_lr_length @@ -384,6 +429,51 @@ class Optimizer(Cell): lr = self.learning_rate return lr + def _get_parameter_group_id(self): + """ + Get the parameter partition group id, which is less than the number of devices. + + Returns: + tuple, the group id tuple of parameters. 
+ """ + rank_list = () + count = 0 + for _ in range(self.param_length): + rank_list = rank_list + (count,) + count = count + 1 + if count == self.dev_num: + count = 0 + return rank_list + + def broadcast_params(self, optim_result): + """ + Apply Broadcast operations in the sequential order of parameter groups. + + Returns: + bool, the status flag. + """ + param_group = [] + key_group = [] + for _ in range(self.dev_num): + param_group.append(F.make_tuple()) + key_group.append(F.make_tuple()) + for i in range(self.param_length): + param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],) + key = P.MakeRefKey(self.param_names[i])() + key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,) + new_param_group = [] + for root in range(self.dev_num): + ops = P.Broadcast(root) + next_params = ops(param_group[root]) + new_param_group.append(next_params) + for i in range(F.tuple_len(next_params)): + F.assign(key_group[root][i], next_params[i]) + status = True + for i in range(self.dev_num - 1): + status = F.control_depend(new_param_group[i][0], new_param_group[i+1]) + + return status + def construct(self, *hyper_params): raise NotImplementedError diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index f0d920f51fa430c7b86cb3afbe40fa865794a472..9e3d00cc9591d4b13ef1e280952b07e3e3c9cefc 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -220,7 +220,9 @@ class DataWrapper(Cell): def __init__(self, network, dataset_types, dataset_shapes, queue_name): super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) - + # Also copy the flag in `network` construct + flags = getattr(network.__class__.construct, "_mindspore_flags", {}) + self.add_flags(**flags) self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) self.network = network diff --git a/mindspore/ops/_op_impl/akg/__init__.py 
b/mindspore/ops/_op_impl/akg/__init__.py index f38b99f5e4f02f75bcff0c0a147761e77a013383..fd86dbf999160ecbced337b9f2427caaece7fd28 100644 --- a/mindspore/ops/_op_impl/akg/__init__.py +++ b/mindspore/ops/_op_impl/akg/__init__.py @@ -47,6 +47,7 @@ from .gather_v2 import _gather_v2_akg from .less import _less_akg from .log import _log_akg from .matmul import _matmul_akg +from .batchmatmul import _batchmatmul_akg from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_akg from .max_pool_with_argmax import _max_pool_with_argmax_akg from .max import _max_akg diff --git a/mindspore/ops/_op_impl/akg/batchmatmul.py b/mindspore/ops/_op_impl/akg/batchmatmul.py new file mode 100644 index 0000000000000000000000000000000000000000..f5da71aa25e7634bb6515ef7516d73d21daf371e --- /dev/null +++ b/mindspore/ops/_op_impl/akg/batchmatmul.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""BatchMatMul op""" +from mindspore.ops.op_info_register import op_info_register + + +@op_info_register("""{ + "op_name": "BatchMatMul", + "imply_type": "AutoDiff", + "fusion_type": "OPAQUE", + "attr": [ + { + "name": "transpose_a", + "param_type": "optional", + "type": "bool" + }, + { + "name": "transpose_b", + "param_type": "optional", + "type": "bool" + } + ], + "inputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "x1" + }, + { + "index": 1, + "dtype": [ + "float16" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "x2" + } + ], + "outputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "output" + } + ] +}""") +def _batchmatmul_akg(): + """BatchMatMul AKG register""" + return diff --git a/mindspore/ops/_op_impl/tbe/confusion_transpose_d.py b/mindspore/ops/_op_impl/tbe/confusion_transpose_d.py index e52ae01520b11cd7247eaa04da629691ceb4f0c0..f84e9d4292d33811c53f9c3c29ca769cfda0fafb 100644 --- a/mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +++ b/mindspore/ops/_op_impl/tbe/confusion_transpose_d.py @@ -28,26 +28,8 @@ confusion_transpose_d_op_info = TBERegOp("ConfusionTransposeD") \ .attr("transpose_first", "required", "bool", "all") \ .input(0, "x", False, "required", "all") \ .output(0, "y", False, "required", "all") \ - .dtype_format(DataType.I8_FracNZ, DataType.I8_FracNZ) \ - .dtype_format(DataType.I8_Default, DataType.I8_Default) \ - .dtype_format(DataType.U8_FracNZ, DataType.U8_FracNZ) \ - .dtype_format(DataType.U8_Default, DataType.U8_Default) \ - .dtype_format(DataType.I16_FracNZ, DataType.I16_FracNZ) \ - .dtype_format(DataType.I16_Default, DataType.I16_Default) \ - .dtype_format(DataType.U16_FracNZ, DataType.U16_FracNZ) \ - .dtype_format(DataType.U16_Default, DataType.U16_Default) \ - .dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ) \ - 
.dtype_format(DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.U32_FracNZ, DataType.U32_FracNZ) \ - .dtype_format(DataType.U32_Default, DataType.U32_Default) \ - .dtype_format(DataType.I64_FracNZ, DataType.I64_FracNZ) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default) \ - .dtype_format(DataType.U64_FracNZ, DataType.U64_FracNZ) \ - .dtype_format(DataType.U64_Default, DataType.U64_Default) \ - .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \ - .dtype_format(DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .op_pattern("dynamicFormat") \ + .dtype_format(DataType.None_None, DataType.None_None) \ .get_op_info() diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py index 38cf0141f09c9b9d058db049b3b51ead579cde2a..30c943b69f9e5b77af1ff17d756acf9abec52aa6 100644 --- a/mindspore/ops/composite/multitype_ops/setitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -85,6 +85,22 @@ def _list_setitem_with_List(data, number_index, value): return F.list_setitem(data, number_index, value) +@setitem.register("List", "Number", "Tuple") +def _list_setitem_with_Tuple(data, number_index, value): + """ + Assigns value to list. + + Inputs: + data (list): Data of type lis. + number_index (Number): Index of data. + value (list): Value given. + + Outputs: + list, type is same as the element type of data. 
+ """ + return F.list_setitem(data, number_index, value) + + @setitem.register("Dictionary", "String", "Tensor") def _dict_setitem_with_tensor(data, key, value): """ diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 1b212d161a08916a09e95079b4dd89a87a342727..f8b47a28c3fbce09ff97af78454ffeba0356689d 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -98,6 +98,7 @@ class AllReduce(PrimitiveWithInfer): self.op = op self.add_prim_attr('group', _get_group(group)) self.add_prim_attr('fusion', 0) + self.add_prim_attr('index', 0) def vm_impl(self, x): """Implement by vm mode.""" diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index bafc72897e6b881933288bc05405688913573e0a..066791d4df5afe13562ba86870bf8e4ddc1558ad 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -309,12 +309,6 @@ class Print(PrimitiveWithInfer): Output tensor or string to stdout. Note: - The print operation cannot support the following cases currently. - - 1. The type of tensor is float64 or bool. - - 2. The data of tensor is a scalar type. - In pynative mode, please use python print function. 
Inputs: @@ -334,7 +328,7 @@ class Print(PrimitiveWithInfer): @prim_attr_register def __init__(self): - pass + self.add_prim_attr("_side_effect", True) def __call__(self, *args): for arg in args: diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index af0deacc16f9f9729a5122cef96848af994dd994..b6aed4d79c34f61f5eb9748a6a946d23759aafa4 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -888,7 +888,8 @@ class Neg(PrimitiveWithInfer): def infer_value(self, input_x): if input_x is not None: input_x = input_x.asnumpy() - return Tensor(-input_x) + out = np.array(-input_x, input_x.dtype) + return Tensor(out) return None @@ -1667,7 +1668,8 @@ class Div(_MathBinaryOp): if x is not None and y is not None: x = x.asnumpy() y = y.asnumpy() - return Tensor(x / y) + out = np.array(x / y, x.dtype) + return Tensor(out) return None diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 74c6080ab4132ebc8f730db348a1e2e5a699b0f8..b6b938d800bc11410b3146edf6926104fc6bb1e1 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -59,8 +59,7 @@ class Assign(PrimitiveWithInfer): return variable def infer_dtype(self, variable, value): - args = {"variable": variable, "value": value} - validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) + # Add a type validation later when we don't have to assign a value to RefKey. 
return variable diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index 74250f12e5ad03cc3b3e193af8134d482a1e384f..93fe23385575096d58bb243605692c93806cec60 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -400,6 +400,23 @@ class _AutoParallelContext: self.check_context_handle() return self._context_handle.get_global_rank_is_set() + def set_enable_parallel_optimizer(self, enable_parallel_optimizer): + """ + Set enable/disable parallel optimizer. + + Args: + enable_parallel_optimizer (bool): Enable/disable parallel optimizer. + """ + self.check_context_handle() + if not isinstance(enable_parallel_optimizer, bool): + raise TypeError('enable_parallel_optimizer is invalid type') + self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer) + + def get_enable_parallel_optimizer(self): + """Get parallel optimizer flag.""" + self.check_context_handle() + return self._context_handle.get_enable_parallel_optimizer() + def reset(self): """Reset all settings.""" self.check_context_handle() @@ -433,7 +450,8 @@ _set_auto_parallel_context_func_map = { "parameter_broadcast": auto_parallel_context().set_parameter_broadcast, "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file, "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, - "full_batch": auto_parallel_context().set_full_batch} + "full_batch": auto_parallel_context().set_full_batch, + "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer} _get_auto_parallel_context_func_map = { @@ -447,13 +465,15 @@ _get_auto_parallel_context_func_map = { "parameter_broadcast": auto_parallel_context().get_parameter_broadcast, "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file, "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, - "full_batch": 
auto_parallel_context().get_full_batch} + "full_batch": auto_parallel_context().get_full_batch, + "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer} @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, - strategy_ckpt_save_file=str, full_batch=bool) + strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool) + def _set_auto_parallel_context(**kwargs): """ Set auto parallel context. @@ -493,6 +513,7 @@ def _set_auto_parallel_context(**kwargs): strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' full_batch (bool): Whether to load the whole batch on each device. Default: False. + enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False. Raises: ValueError: If input key is not attribute in auto parallel context. @@ -535,5 +556,6 @@ def _reset_auto_parallel_context(): - parameter_broadcast: False. 
- strategy_ckpt_load_file: "" - strategy_ckpt_save_file: "" + - enable_parallel_optimizer: False """ auto_parallel_context().reset() diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py index 6d5ec45d5bfff418580bcc084994cb995d09fba7..cff03ca398f2be708ded48d943d902a52932b990 100644 --- a/mindspore/train/callback/_summary_collector.py +++ b/mindspore/train/callback/_summary_collector.py @@ -166,8 +166,11 @@ class SummaryCollector(Callback): self._has_saved_custom_data = False self._is_parse_loss_success = True self._first_step = True + self._dataset_sink_mode = True def __enter__(self): + self._first_step = True + self._dataset_sink_mode = True self._record = SummaryRecord(log_dir=self._summary_dir) return self @@ -279,15 +282,15 @@ class SummaryCollector(Callback): def step_end(self, run_context): cb_params = run_context.original_args() + if self._first_step: + # Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario + self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num) if cb_params.mode == ModeEnum.TRAIN.value: - # Make sure the first step data is recorded - if not self._first_step and cb_params.cur_step_num % self._collect_freq: + if not self._is_collect_this_step(cb_params): return - self._first_step = False - if not self._has_saved_train_network: self._collect_graphs(cb_params) @@ -295,6 +298,7 @@ class SummaryCollector(Callback): self._collect_metric(cb_params) self._collect_histogram(cb_params) + self._first_step = False self._record.record(cb_params.cur_step_num) def end(self, run_context): @@ -320,6 +324,18 @@ class SummaryCollector(Callback): raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list," f"but expected only one {self.__class__.__name__} instance.") + def _is_collect_this_step(self, cb_params): + """Decide whether to collect data for the current step.""" + # Make sure the first step 
data is recorded + if not self._first_step: + if self._dataset_sink_mode: + if cb_params.cur_epoch_num % self._collect_freq: + return False + else: + if cb_params.cur_step_num % self._collect_freq: + return False + return True + @staticmethod def _package_custom_lineage_data(custom_lineage_data): """ diff --git a/model_zoo/faster_rcnn/src/dataset.py b/model_zoo/faster_rcnn/src/dataset.py index d64de0939190cc951f4e605ff39429ddddc1c570..133824dd247fd96dc9edd2fd73e3e4427f8116b9 100644 --- a/model_zoo/faster_rcnn/src/dataset.py +++ b/model_zoo/faster_rcnn/src/dataset.py @@ -318,10 +318,6 @@ def preprocess_fn(image, box, is_training): else: input_data = resize_column(*input_data) - photo = (np.random.rand() < config.photo_ratio) - if photo: - input_data = photo_crop_column(*input_data) - input_data = image_bgr_rgb(*input_data) output_data = input_data @@ -432,19 +428,19 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fast writer.write_raw_data([row]) writer.commit() + def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0, - is_training=True, num_parallel_workers=8): + is_training=True, num_parallel_workers=4): """Creatr FasterRcnn dataset with MindDataset.""" ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id, - num_parallel_workers=num_parallel_workers, shuffle=is_training) + num_parallel_workers=1, shuffle=is_training) decode = C.Decode() - ds = ds.map(input_columns=["image"], operations=decode) + ds = ds.map(input_columns=["image"], operations=decode, num_parallel_workers=1) compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) hwc_to_chw = C.HWC2CHW() normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375)) horizontally_op = C.RandomHorizontalFlip(1) - type_cast0 = CC.TypeCast(mstype.float32) type_cast1 = CC.TypeCast(mstype.float16) type_cast2 = 
CC.TypeCast(mstype.int32) type_cast3 = CC.TypeCast(mstype.bool_) @@ -453,17 +449,18 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi ds = ds.map(input_columns=["image", "annotation"], output_columns=["image", "image_shape", "box", "label", "valid_num"], columns_order=["image", "image_shape", "box", "label", "valid_num"], - operations=compose_map_func, num_parallel_workers=4) - - ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0], - num_parallel_workers=num_parallel_workers) + operations=compose_map_func, num_parallel_workers=num_parallel_workers) flip = (np.random.rand() < config.flip_ratio) if flip: - ds = ds.map(input_columns=["image"], operations=[horizontally_op], - num_parallel_workers=num_parallel_workers) + ds = ds.map(input_columns=["image"], operations=[normalize_op, horizontally_op, hwc_to_chw, type_cast1], + num_parallel_workers=24) ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"], - operations=flipped_generation, num_parallel_workers=4) + operations=flipped_generation, num_parallel_workers=num_parallel_workers) + else: + ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1], + num_parallel_workers=24) + else: ds = ds.map(input_columns=["image", "annotation"], output_columns=["image", "image_shape", "box", "label", "valid_num"], @@ -471,11 +468,10 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi operations=compose_map_func, num_parallel_workers=num_parallel_workers) - ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0], - num_parallel_workers=num_parallel_workers) + ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1], + num_parallel_workers=24) # transpose_column from python to c - ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1]) ds = ds.map(input_columns=["image_shape"], operations=[type_cast1]) ds = 
ds.map(input_columns=["box"], operations=[type_cast1]) ds = ds.map(input_columns=["label"], operations=[type_cast2]) diff --git a/model_zoo/vgg16/src/config.py b/model_zoo/vgg16/src/config.py index 8c6ffee98b49df67572fabcb107441ef2b3f0f1e..a34cf7a1d3ee61a13dd71e4bed4d42d9982a04b9 100644 --- a/model_zoo/vgg16/src/config.py +++ b/model_zoo/vgg16/src/config.py @@ -19,7 +19,9 @@ from easydict import EasyDict as edict cifar_cfg = edict({ 'num_classes': 10, - 'lr_init': 0.05, + 'lr_init': 0.01, + 'lr_max': 0.1, + 'warmup_epochs': 5, 'batch_size': 64, 'epoch_size': 70, 'momentum': 0.9, diff --git a/model_zoo/vgg16/train.py b/model_zoo/vgg16/train.py index c582cdd679dfefe7bcb644d3431d0aeca10adedd..33a4f0310c9d09c02a764bbe8d7ef5ef79f1ddfc 100644 --- a/model_zoo/vgg16/train.py +++ b/model_zoo/vgg16/train.py @@ -38,20 +38,25 @@ random.seed(1) np.random.seed(1) -def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): +def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch): """Set learning rate.""" lr_each_step = [] total_steps = steps_per_epoch * total_epochs - decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps] + warmup_steps = steps_per_epoch * warmup_epochs + if warmup_steps != 0: + inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) + else: + inc_each_step = 0 for i in range(total_steps): - if i < decay_epoch_index[0]: - lr_each_step.append(lr_max) - elif i < decay_epoch_index[1]: - lr_each_step.append(lr_max * 0.1) - elif i < decay_epoch_index[2]: - lr_each_step.append(lr_max * 0.01) + if i < warmup_steps: + lr_value = float(lr_init) + inc_each_step * float(i) else: - lr_each_step.append(lr_max * 0.001) + base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) + lr_value = float(lr_max) * base * base + if lr_value < 0.0: + lr_value = 0.0 + lr_each_step.append(lr_value) + current_step = global_step lr_each_step = 
np.array(lr_each_step).astype(np.float32) learning_rate = lr_each_step[current_step:] @@ -86,7 +91,8 @@ if __name__ == '__main__': if args_opt.pre_trained: load_param_into_net(net, load_checkpoint(args_opt.pre_trained)) - lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) + lr = lr_steps(0, lr_init=cfg.lr_init, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs, + total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) diff --git a/serving/core/server.cc b/serving/core/server.cc index add9d16bee557c4304e39bb304b2d50fc6084a0c..c07558a5c2d0d61a1eea926d77cfa866732c4e5f 100644 --- a/serving/core/server.cc +++ b/serving/core/server.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include "mindspore/ccsrc/utils/log_adapter.h" #include "serving/ms_service.grpc.pb.h" @@ -40,7 +41,7 @@ namespace serving { using MSTensorPtr = std::shared_ptr; Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) { - session_ = inference::MSSession::CreateSession(device + "Inference", device_id); + session_ = inference::MSSession::CreateSession(device, device_id); if (session_ == nullptr) { MS_LOG(ERROR) << "Creat Session Failed"; return FAILED; @@ -67,6 +68,7 @@ Status Session::Predict(const std::vector &inputs, inference::Multi MS_LOG(INFO) << "run Predict"; *outputs = session_->RunGraph(graph_id_, inputs); + MS_LOG(INFO) << "run Predict finished"; return SUCCESS; } @@ -80,12 +82,16 @@ Status Session::Warmup(const MindSporeModelPtr model) { std::string file_name = model->GetModelPath() + '/' + model->GetModelName(); char *graphBuf = ReadFile(file_name.c_str(), &size); if (graphBuf == nullptr) { - MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); + MS_LOG(ERROR) << "Read 
model file failed, file name is " << file_name.c_str(); return FAILED; } last_graph_ = inference::LoadModel(graphBuf, size, device_type_); + if (last_graph_ == nullptr) { + MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); + return FAILED; + } graph_id_ = session_->CompileGraph(last_graph_); - MS_LOG(INFO) << "Session Warmup"; + MS_LOG(INFO) << "Session Warmup finished"; return SUCCESS; } @@ -95,6 +101,9 @@ Status Session::Clear() { } namespace { +static const uint32_t uint32max = 0x7FFFFFFF; +std::promise exit_requested; + const std::map type2id_map{ {ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool}, {ms_serving::MS_INT8, TypeId::kNumberTypeInt8}, {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8}, @@ -141,7 +150,7 @@ MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) { } TypeId type = iter->second; auto ms_tensor = std::shared_ptr(inference::MSTensor::CreateTensor(type, shape)); - memcpy_s(ms_tensor->MutableData(), tensor.data().size(), tensor.data().data(), tensor.data().size()); + memcpy_s(ms_tensor->MutableData(), ms_tensor->Size(), tensor.data().data(), tensor.data().size()); return ms_tensor; } @@ -166,10 +175,7 @@ void ClearEnv() { Session::Instance().Clear(); inference::ExitInference(); } -void HandleSignal(int sig) { - ClearEnv(); - exit(0); -} +void HandleSignal(int sig) { exit_requested.set_value(); } #ifdef ENABLE_D static rtContext_t g_ctx = nullptr; @@ -247,6 +253,7 @@ Status Server::BuildAndStart() { rtError_t rt_ret = rtCtxGetCurrent(&ctx); if (rt_ret != RT_ERROR_NONE || ctx == nullptr) { MS_LOG(ERROR) << "the ascend device context is null"; + ClearEnv(); return FAILED; } g_ctx = ctx; @@ -258,6 +265,7 @@ Status Server::BuildAndStart() { auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0); grpc::ServerBuilder builder; builder.SetOption(std::move(option)); + builder.SetMaxMessageSize(uint32max); // Listen on the given address without 
any authentication mechanism. builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); // Register "service" as the instance through which we'll communicate with @@ -265,13 +273,20 @@ Status Server::BuildAndStart() { builder.RegisterService(&service); // Finally assemble the server. std::unique_ptr server(builder.BuildAndStart()); + if (server == nullptr) { + MS_LOG(ERROR) << "The serving server create failed"; + ClearEnv(); + return FAILED; + } + auto grpc_server_run = [&server]() { server->Wait(); }; + std::thread serving_thread(grpc_server_run); MS_LOG(INFO) << "Server listening on " << server_address << std::endl; - - // Wait for the server to shutdown. Note that some other thread must be - // responsible for shutting down the server for this call to ever return. - server->Wait(); + auto exit_future = exit_requested.get_future(); + exit_future.wait(); + ClearEnv(); + server->Shutdown(); + serving_thread.join(); return SUCCESS; } - } // namespace serving } // namespace mindspore diff --git a/serving/core/util/file_system_operation.cc b/serving/core/util/file_system_operation.cc index a5143995dec7143c44a89591f37fe6439c0f59fa..1af512a54c02a947e0c9ee904661e02864407b3c 100644 --- a/serving/core/util/file_system_operation.cc +++ b/serving/core/util/file_system_operation.cc @@ -29,7 +29,6 @@ namespace mindspore { namespace serving { - char *ReadFile(const char *file, size_t *size) { if (file == nullptr) { MS_LOG(ERROR) << "file is nullptr"; @@ -70,8 +69,8 @@ bool DirOrFileExist(const std::string &file_path) { } std::vector GetAllSubDirs(const std::string &dir_path) { - DIR *dir; - struct dirent *ptr; + DIR *dir = nullptr; + struct dirent *ptr = nullptr; std::vector SubDirs; if ((dir = opendir(dir_path.c_str())) == NULL) { diff --git a/serving/core/util/option_parser.cc b/serving/core/util/option_parser.cc index 9cbd7eaee8f01bc7b68888ae088cad5a4a117692..c7f00e37338ccdc7ec61ba9e4f3456c015d70cac 100644 --- a/serving/core/util/option_parser.cc +++ 
b/serving/core/util/option_parser.cc @@ -36,17 +36,16 @@ bool RemovePrefix(std::string *str, const std::string &prefix) { bool Option::ParseInt32(std::string *arg) { if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) { - char extra; int32_t parsed_value; - if (sscanf(arg->data(), "%d%c", &parsed_value, &extra) != 1) { - std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl; + try { + parsed_value = std::stoi(arg->data()); + } catch (std::invalid_argument) { + std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl; return false; - } else { - *int32_default_ = parsed_value; } + *int32_default_ = parsed_value; return true; } - return false; } @@ -76,17 +75,16 @@ bool Option::ParseString(std::string *arg) { bool Option::ParseFloat(std::string *arg) { if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) { - char extra; float parsed_value; - if (sscanf(arg->data(), "%f%c", &parsed_value, &extra) != 1) { - std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl; + try { + parsed_value = std::stof(arg->data()); + } catch (std::invalid_argument) { + std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl; return false; - } else { - *float_default_ = parsed_value; } + *float_default_ = parsed_value; return true; } - return false; } @@ -159,10 +157,11 @@ Options::Options() : args_(nullptr) { CreateOptions(); } void Options::CreateOptions() { args_ = std::make_shared(); std::vector