提交 5a886794 编写于 作者: Y yanghaitao

Merge branch 'master' of gitee.com:mindspore/mindspore

akg @ df57a6cf
Subproject commit c460176523d039c8995f1d71089753725ebc0792 Subproject commit df57a6cf9450e347d1854687d1fe66a420ee3b35
...@@ -277,10 +277,11 @@ endif () ...@@ -277,10 +277,11 @@ endif ()
if (USE_GLOG) if (USE_GLOG)
target_link_libraries(inference PRIVATE mindspore::glog) target_link_libraries(inference PRIVATE mindspore::glog)
else()
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
target_link_options(inference PRIVATE -Wl,-init,mindspore_log_init)
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
endif ()
endif() endif()
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
target_link_options(inference PRIVATE -Wl,-init,common_log_init)
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
endif ()
...@@ -30,8 +30,7 @@ Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow ...@@ -30,8 +30,7 @@ Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow
BOUNDING_BOX_CHECK(input); BOUNDING_BOX_CHECK(input);
CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of input is abnormal"); CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of input is abnormal");
(*output).push_back(nullptr); // init memory for return vector output->resize(2);
(*output).push_back(nullptr);
(*output)[1] = std::move(input[1]); // move boxes over to output (*output)[1] = std::move(input[1]); // move boxes over to output
size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor
......
...@@ -36,8 +36,7 @@ Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) ...@@ -36,8 +36,7 @@ Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output)
int32_t padded_image_h; int32_t padded_image_h;
int32_t padded_image_w; int32_t padded_image_w;
(*output).push_back(nullptr); output->resize(2);
(*output).push_back(nullptr);
(*output)[1] = std::move(input[1]); // since some boxes may be removed (*output)[1] = std::move(input[1]); // since some boxes may be removed
bool crop_further = true; // Whether further cropping will be required or not, true unless required size matches bool crop_further = true; // Whether further cropping will be required or not, true unless required size matches
......
...@@ -45,8 +45,7 @@ Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow * ...@@ -45,8 +45,7 @@ Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow *
RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y)); RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y));
} }
(*output).push_back(nullptr); output->resize(2);
(*output).push_back(nullptr);
(*output)[1] = std::move(input[1]); (*output)[1] = std::move(input[1]);
return VerticalFlip(input[0], &(*output)[0]); return VerticalFlip(input[0], &(*output)[0]);
......
...@@ -2,6 +2,8 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" ...@@ -2,6 +2,8 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(utils OBJECT add_library(utils OBJECT
arena.cc arena.cc
buddy.cc
cache_pool.cc
circular_pool.cc circular_pool.cc
memory_pool.cc memory_pool.cc
cond_var.cc cond_var.cc
...@@ -11,7 +13,11 @@ add_library(utils OBJECT ...@@ -11,7 +13,11 @@ add_library(utils OBJECT
service.cc service.cc
services.cc services.cc
lock.cc lock.cc
semaphore.cc
status.cc status.cc
storage_container.cc
storage_manager.cc
slice.cc
path.cc path.cc
wait_post.cc wait_post.cc
sig_handler.cc) sig_handler.cc)
...@@ -17,8 +17,10 @@ ...@@ -17,8 +17,10 @@
#define DATASET_UTIL_ALLOCATOR_H_ #define DATASET_UTIL_ALLOCATOR_H_
#include <cstdlib> #include <cstdlib>
#include <functional>
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility>
#include "dataset/util/memory_pool.h" #include "dataset/util/memory_pool.h"
namespace mindspore { namespace mindspore {
...@@ -84,6 +86,91 @@ class Allocator { ...@@ -84,6 +86,91 @@ class Allocator {
private: private:
std::shared_ptr<MemoryPool> pool_; std::shared_ptr<MemoryPool> pool_;
}; };
/// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory
/// will be released when the object goes out of scope.
/// \tparam T The type of object to be allocated
/// \tparam C Allocator. Default to std::allocator
template <typename T, typename C = std::allocator<T>>
class MemGuard {
 public:
  using allocator = C;

  MemGuard() : n_(0) {}

  explicit MemGuard(allocator a) : n_(0), alloc_(a) {}

  // The memory is solely owned by this object, so copying is forbidden.
  MemGuard(const MemGuard &) = delete;
  MemGuard &operator=(const MemGuard &) = delete;

  // Move is supported: ownership transfers and the source is left empty.
  MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {
    lhs.n_ = 0;  // keep the moved-from object consistent (GetSizeInBytes() == 0)
  }

  MemGuard &operator=(MemGuard &&lhs) noexcept {
    if (this != &lhs) {
      this->deallocate();
      n_ = std::exchange(lhs.n_, 0);
      alloc_ = std::move(lhs.alloc_);
      ptr_ = std::move(lhs.ptr_);
    }
    return *this;
  }

  /// \brief Explicitly deallocate the memory if allocated
  void deallocate() {
    if (ptr_) {
      auto *p = ptr_.release();
      // Arithmetic types are never constructed in allocate(), so skip the destructor loop for them.
      if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) {
        for (size_t i = 0; i < n_; ++i) {
          p[i].~T();
        }
      }
      alloc_.deallocate(p, n_);
      n_ = 0;
    }
  }

  /// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is
  /// allocated.
  /// \param n Number of objects of type T to be allocated
  /// \tparam Args Extra arguments pass to the constructor of T
  /// \return Status::OK, kOutOfMemory on allocation failure, or an unexpected-error Status if a constructor throws.
  template <typename... Args>
  Status allocate(size_t n, Args &&... args) noexcept {
    try {
      deallocate();
      if (n > 0) {
        T *data = alloc_.allocate(n);
        if (!std::is_arithmetic<T>::value) {
          size_t constructed = 0;
          try {
            for (; constructed < n; ++constructed) {
              std::allocator_traits<C>::construct(alloc_, &(data[constructed]), std::forward<Args>(args)...);
            }
          } catch (...) {
            // Roll back: destroy what was constructed and return the raw buffer to the allocator
            // before rethrowing, otherwise a throwing constructor leaks the whole allocation.
            for (size_t i = 0; i < constructed; ++i) {
              data[i].~T();
            }
            alloc_.deallocate(data, n);
            throw;
          }
        }
        // The deleter is a deliberate no-op: deallocate() is the only release path and must go
        // through alloc_. Never let unique_ptr call delete[] on allocator-owned memory.
        ptr_ = std::unique_ptr<T[], std::function<void(T *)>>(data, [](T *) {});
        n_ = n;
      }
    } catch (const std::bad_alloc &e) {
      return Status(StatusCode::kOutOfMemory);
    } catch (std::exception &e) {
      RETURN_STATUS_UNEXPECTED(e.what());
    }
    return Status::OK();
  }

  ~MemGuard() noexcept { deallocate(); }

  /// \brief Getter function
  /// \return The pointer to the memory allocated
  T *GetPointer() const { return ptr_.get(); }

  /// \brief Getter function
  /// \return The pointer to the memory allocated
  T *GetMutablePointer() { return ptr_.get(); }

  /// \brief Overload [] operator to access a particular element
  /// \param x index to the element. Must be less than number of element allocated.
  /// \return pointer to the x-th element
  T *operator[](size_t x) { return GetMutablePointer() + x; }

  /// \brief Overload [] operator to access a particular element
  /// \param x index to the element. Must be less than number of element allocated.
  /// \return pointer to the x-th element
  T *operator[](size_t x) const { return GetPointer() + x; }

  /// \brief Return how many bytes are allocated in total
  /// \return Number of bytes allocated in total
  size_t GetSizeInBytes() const { return n_ * sizeof(T); }

 private:
  allocator alloc_;
  std::unique_ptr<T[], std::function<void(T *)>> ptr_;
  size_t n_;
};
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
......
...@@ -91,7 +91,7 @@ class AutoIndexObj : public BPlusTree<int64_t, T, A> { ...@@ -91,7 +91,7 @@ class AutoIndexObj : public BPlusTree<int64_t, T, A> {
} }
private: private:
static constexpr key_type kMinKey = 1; static constexpr key_type kMinKey = 0;
std::atomic<key_type> inx_; std::atomic<key_type> inx_;
}; };
} // namespace dataset } // namespace dataset
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/buddy.h"
#include <iomanip>
#include <stdexcept>
#include "dataset/util/de_error.h"
#include "dataset/util/memory_pool.h"
#include "dataset/util/system_pool.h"
#include "./securec.h"
// Bit-manipulation helpers for the buddy allocator below.
// static: the originals were plain `inline` at global namespace scope in a .cc file,
// which gives them external linkage and risks ODR collisions with identically named
// helpers in other translation units; internal linkage removes that hazard.
static inline uint64_t BitLeftShift(uint64_t v, uint64_t n) { return (v << n); }
static inline uint64_t BitRightShift(uint64_t v, uint64_t n) { return (v >> n); }
static inline uint64_t BitOr(uint64_t rhs, uint64_t lhs) { return rhs | lhs; }
// Exclusive-or of the two operands.
static inline uint64_t BitEx(uint64_t rhs, uint64_t lhs) { return rhs ^ lhs; }
static inline uint64_t BitAnd(uint64_t rhs, uint64_t lhs) { return rhs & lhs; }
namespace mindspore {
namespace dataset {
// Set up the buddy space. The three bookkeeping arrays live in one contiguous
// allocation (ptr_):
//   [0, offset_1)        : hint_  — per-level scan-start hints (rel_addr_t[num_lvl_])
//   [offset_1, offset_2) : count_ — per-level free-block counters (int[num_lvl_])
//   [offset_2, offset_3) : map_   — packed per-unit state map (1 << (num_lvl_-3) bytes)
// Initially the whole space is a single free block at the top level.
Status BuddySpace::Init() {
  if (log_min_ < 0) {
    // NOTE(review): the message says "positive" but the check permits 0 —
    // confirm whether log_min == 0 is intended to be valid.
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
                  "log_min must be positive : " + std::to_string(log_min_));
  }
  if (num_lvl_ < 3 || num_lvl_ > 18) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
                  "num_lvl must be between 3 and 18 : " + std::to_string(num_lvl_));
  }
  // Smallest and largest allocatable sizes in bytes.
  min_ = BitLeftShift(1, log_min_);
  max_ = BitLeftShift(1, log_min_ + num_lvl_ - 1);
  size_t offset_1 = sizeof(rel_addr_t) * num_lvl_;
  size_t offset_2 = sizeof(int) * num_lvl_ + offset_1;
  size_t offset_3 = sizeof(char) * BitLeftShift(1, num_lvl_ - 3) + offset_2;
  RETURN_IF_NOT_OK(DeMalloc(offset_3, &ptr_, true));
  hint_ = reinterpret_cast<rel_addr_t *>(ptr_);
  count_ = reinterpret_cast<int *>((reinterpret_cast<char *>(ptr_) + offset_1));
  map_ = reinterpret_cast<char *>(ptr_) + offset_2;
  // One free block at the highest level covering the entire space.
  count_[num_lvl_ - 1] = 1;
  map_[0] = BitOr(MORE_BIT, num_lvl_ - 3);
  return Status::OK();
}
// Thread-safe allocation entry point. On success writes the byte offset of the
// allocation to *p and fills *desc (required later by Free). A full space is
// reported with a kNoSpace status, which callers are expected to treat as a
// normal condition rather than a hard error.
Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) noexcept {
  std::lock_guard<std::mutex> lock(mutex_);
  addr_t addr = AllocNoLock(sz, desc);
  if (addr != NOSPACE) {
    *p = addr;
    return Status::OK();
  } else {
    return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error. Please ignore.");
  }
}
// Allocate without taking the lock (caller must hold mutex_).
// Fills *desc with the placement record Free() needs and returns the byte
// offset of the allocation, or NOSPACE when the request cannot be satisfied.
addr_t BuddySpace::AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept {
  DS_ASSERT(sz <= max_);
  // Round the byte size up to a whole number of minimum-size units.
  uint32_t reqSize = SizeToBlock(sz);
  rel_addr_t rel_addr = AllocBuddySeg(reqSize);
  if (rel_addr != static_cast<rel_addr_t>(NOSPACE)) {
    (void)memset_s(desc, sizeof(BSpaceDescriptor), 0, sizeof(BSpaceDescriptor));
    desc->sig = static_cast<int>(0xDEADBEEF);  // signature checked by FreeNoLock to catch bad descriptors
    desc->addr = rel_addr;
    desc->req_size = reqSize;
    desc->blk_size = NextPowerOf2(reqSize);
    // rel_addr is in minimum-size units; convert back to a byte offset.
    return static_cast<addr_t>(rel_addr * min_);
  } else {
    return NOSPACE;
  }
}
// Free without taking the lock (caller must hold mutex_).
// The descriptor must be one previously filled in by AllocNoLock.
void BuddySpace::FreeNoLock(const BSpaceDescriptor *desc) {
  DS_ASSERT(desc->sig == 0XDEADBEEF);  // reject corrupted or foreign descriptors
  rel_addr_t rel_addr = desc->addr;
  size_t blk_size = desc->blk_size;
  size_t req_size = desc->req_size;
  FreeBuddySeg(rel_addr, blk_size, req_size);
}
// Thread-safe free of a segment previously returned by Alloc.
void BuddySpace::Free(const BSpaceDescriptor *desc) {
  std::lock_guard<std::mutex> lock(mutex_);
  return FreeNoLock(desc);
}
// Dump a human-readable snapshot of the buddy space: configuration, the
// per-level free counters, and a walk of every segment with its size/state.
// Reads shared state without taking the lock — intended for debugging only.
// Fix: the unknown-state label was misspelled "Unkonwn".
std::ostream &operator<<(std::ostream &os, const BuddySpace &s) {
  os << "1 unit = " << s.GetMinSize() << "\n"
     << "Size of buddy space = " << s.GetMaxSize() << "\n"
     << "Number of levels = " << s.num_lvl_ << "\n\n"
     << "Percent free = " << s.PercentFree() << "\n"
     << "Dumping count array : "
     << "\n";
  for (int i = 0; i < s.num_lvl_; i++) {
    os << "[" << i << "] = " << s.count_[i] << " ";
    // Wrap the counter dump every four entries for readability.
    if (((i + 1) % 4) == 0) {
      os << "\n";
    }
  }
  os << "\n";
  os << "Dumping allocation info:"
     << "\n";
  auto max_addr = static_cast<rel_addr_t>(BitLeftShift(1, s.num_lvl_ - 1));
  rel_addr_t addr = 0;
  // Walk segment by segment; each segment reports its own size, which advances the cursor.
  while (addr < max_addr) {
    size_t sz = 0;
    BuddySpace::STATE st;
    s.GetBuddySegState(addr, &sz, &st);
    os << "Address : " << std::left << std::setw(8) << addr << " Size : " << std::setw(8) << sz << " State : "
       << ((st == BuddySpace::STATE::kAlloc) ? "ALLOC" : ((st == BuddySpace::STATE::kFree) ? "FREE" : "Unknown"))
       << "\n";
    addr += sz;
  }
  return os;
}
// Decode the state of the segment starting at rel_addr from the packed map.
// Each unit has a 2-bit slot (rel_addr % 4 selects the slot within byte
// rel_addr / 4); the switch shifts the relevant bits up to where the *_BIT
// masks apply. MORE_BIT entries additionally carry log2(size) - 2 in the low
// nibble of the shifted value — NOTE(review): encoding inferred from
// SetBuddySegState; confirm against original design notes if available.
// Outputs:
//   *rel_sz — segment size in units (1, 2, or a power of two >= 4)
//   *st     — kAlloc / kFree; kEmpty when no size bit is set (then *rel_sz is
//             left untouched).
void BuddySpace::GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const {
  char byte;
  int pos;
  int offset;
  uint64_t val = 0;
  int shift;
  // Locate the byte and the slot within it.
  pos = BitRightShift(rel_addr, 2);
  offset = rel_addr % 4;
  shift = offset * 2;
  byte = map_[pos];
  switch (offset) {
    case 0:
      val = byte;
      break;
    case 1:
    case 3:
      if (offset == 1) {
        val = BitLeftShift(BitAnd(byte, 0x30), shift);
      } else {
        val = BitLeftShift(BitAnd(byte, 0x03), shift);
      }
      break;
    case 2:
      val = BitLeftShift(BitAnd(byte, 0x0F), shift);
      break;
  }
  if (BitAnd(val, ONE_BIT)) {
    *rel_sz = 1;
  } else if (BitAnd(val, TWO_BIT)) {
    *rel_sz = 2;
  } else if (BitAnd(val, MORE_BIT)) {
    // Low nibble holds log2(size) - 2.
    log_t lg = BitAnd(val, 0x0F);
    *rel_sz = BitLeftShift(1, lg + 2);
  } else {
    *st = STATE::kEmpty;
    return;
  }
  *st = BitAnd(val, ALLOC_BIT) ? STATE::kAlloc : STATE::kFree;
}
// Encode (rel_addr, rel_sz, st) into the packed map and keep the per-level
// free counters and scan hints in sync. Inverse of GetBuddySegState.
void BuddySpace::SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st) {
  int clr;
  int mask;
  int pos;
  int offset;
  int val = 0;
  int shift;
  auto log_sz = static_cast<log_t>(Log2(rel_sz));
  // Locate the byte and the bit offset of this unit's slot.
  pos = BitRightShift(rel_addr, 2);
  offset = rel_addr % 4;
  shift = offset * 2;
  if (rel_sz == 1) {
    val = ONE_BIT;
    mask = 0xC0;  // only this unit's 2 bits
  } else if (rel_sz == 2) {
    val = TWO_BIT;
    mask = 0xF0;  // this unit's slot plus the one it covers
  } else {
    // Sizes >= 4: store log2(size) - 2 alongside MORE_BIT; the entry claims the whole byte.
    val = BitOr(log_sz - 2, MORE_BIT);
    mask = 0xFF;
  }
  if (st == STATE::kAlloc) {
    val = BitOr(val, ALLOC_BIT);
  } else if (st == STATE::kFree) {
    val = BitAnd(val, ~(static_cast<uint64_t>(ALLOC_BIT)));
  } else if (st == STATE::kEmpty) {
    val = 0;
  }
  // Clear the target bits, then merge the new value in at the right position.
  clr = static_cast<int>(~(BitRightShift(mask, shift)));
  map_[pos] = static_cast<char>(BitAnd(map_[pos], clr));
  map_[pos] = static_cast<char>(BitOr(map_[pos], BitRightShift(val, shift)));
  // Maintain the free-block count and the search hint for this level.
  if (st == STATE::kAlloc) {
    count_[log_sz]--;
  } else if (st == STATE::kFree) {
    count_[log_sz]++;
    if (rel_addr < hint_[log_sz]) {
      hint_[log_sz] = rel_addr;
    }
  }
}
// Coalesce a freed block with its buddy repeatedly. addr/blk_sz describe a
// free block; while its buddy (addr XOR blk_sz) is also free and of the same
// size, the pair merges into one block of twice the size and the loop moves up
// a level.
void BuddySpace::JoinBuddySeg(rel_addr_t addr, size_t blk_sz) {
  while (blk_sz < BitLeftShift(1, num_lvl_)) {
    rel_addr_t buddy = BitEx(addr, blk_sz);
    size_t sz = 0;
    STATE st;
    GetBuddySegState(buddy, &sz, &st);
    if (st == STATE::kFree && sz == blk_sz) {
      auto log_sz = static_cast<log_t>(Log2(blk_sz));
      rel_addr_t left = (buddy < addr) ? buddy : addr;
      rel_addr_t right = left + blk_sz;
      DS_ASSERT(count_[log_sz] >= 2);
      // Both halves leave this level; the merged block is re-marked free below
      // (SetBuddySegState will bump the next level's counter).
      count_[log_sz] -= 2;
      SetBuddySegState(right, blk_sz, STATE::kEmpty);
      SetBuddySegState(left, BitLeftShift(blk_sz, 1), STATE::kFree);
      // Hints that pointed at the vanished right half must follow the merge.
      for (int i = 0; i < log_sz; i++) {
        if (hint_[i] == right) {
          hint_[i] = left;
        }
      }
      addr = left;
      blk_sz <<= 1u;
    } else {
      break;
    }
  }
}
// Split a free block of blk_sz units down until ask_sz units are allocated.
// At each level the block is halved: both halves are first marked free, then
// the half that (partially) satisfies the request is flipped to allocated and
// the remainder continues to be split.
void BuddySpace::TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) {
  DS_ASSERT(ask_sz < blk_sz);
  uint32_t inx = Log2(blk_sz);
  size_t remaining_sz = ask_sz;
  for (int i = inx; i > 0; i--) {
    size_t b_size = BitLeftShift(1, i);
    size_t half_sz = BitRightShift(b_size, 1);
    // The block at this level is consumed by the split.
    count_[i]--;
    SetBuddySegState(addr, half_sz, STATE::kFree);
    SetBuddySegState(addr + half_sz, half_sz, STATE::kFree);
    if (remaining_sz >= half_sz) {
      SetBuddySegState(addr, half_sz, STATE::kAlloc);
      remaining_sz -= half_sz;
      if (remaining_sz == 0) {
        break;
      }
      addr += half_sz;
    }
  }
}
// Undo a TrimBuddySeg: walk the same split sequence, freeing each allocated
// half; when the last piece is released, coalesce upward via JoinBuddySeg.
void BuddySpace::UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) {
  DS_ASSERT(ask_sz < blk_sz);
  uint32_t inx = Log2(blk_sz);
  size_t remaining_sz = ask_sz;
  for (int i = inx; i > 0; i--) {
    size_t b_size = BitLeftShift(1, i);
    size_t half_sz = BitRightShift(b_size, 1);
    if (remaining_sz >= half_sz) {
#ifdef DEBUG
      {
        // Debug-only sanity check: this half must still be recorded as allocated.
        size_t sz = 0;
        STATE st;
        GetBuddySegState(addr, &sz, &st);
        DS_ASSERT(sz == half_sz && st == STATE::kAlloc);
      }
#endif
      SetBuddySegState(addr, half_sz, STATE::kFree);
      remaining_sz -= half_sz;
      if (remaining_sz == 0) {
        JoinBuddySeg(addr, half_sz);
        break;
      }
      addr += half_sz;
    }
  }
}
// Find and claim a segment of req_size units. Searches level by level from
// the smallest block size that fits, scanning each level from its hint. A
// larger-than-needed block is split via TrimBuddySeg. Returns the unit
// address, or NOSPACE when nothing fits.
rel_addr_t BuddySpace::AllocBuddySeg(uint32_t req_size) noexcept {
  uint32_t blk_size = NextPowerOf2(req_size);
  int start_inx = static_cast<int>(Log2(blk_size));
  bool found = false;
  rel_addr_t ask_addr = 0;
  auto max_addr = static_cast<rel_addr_t>(BitLeftShift(1, num_lvl_ - 1));
  STATE st;
  size_t sz = 0;
  for (int i = start_inx; !found && i < num_lvl_; i++) {
    DS_ASSERT(count_[i] >= 0);
    // No free block at this level — try a larger one.
    if (count_[i] == 0) {
      continue;
    }
    auto blk_sz = static_cast<size_t>(BitLeftShift(1, i));
    ask_addr = hint_[i];
    // Walk segments from the hint until a free block of exactly this size appears.
    while (ask_addr < max_addr && !found) {
      GetBuddySegState(ask_addr, &sz, &st);
      if (st == STATE::kFree && sz == blk_sz) {
        found = true;
      } else {
        DS_ASSERT(st != STATE::kEmpty);
        // Skip past the current segment (or at least one block of this level).
        ask_addr += ((sz > blk_sz) ? sz : blk_sz);
      }
    }
  }
  if (found) {
    if (sz > req_size) {
      TrimBuddySeg(ask_addr, sz, req_size);
    } else {
      SetBuddySegState(ask_addr, sz, STATE::kAlloc);
      hint_[start_inx] = ask_addr;
    }
    return ask_addr;
  } else {
    return static_cast<rel_addr_t>(NOSPACE);
  }
}
// Release a segment. A block allocated at its exact power-of-two size is freed
// and coalesced directly; a trimmed (split) block is unwound via UnTrimBuddySeg.
void BuddySpace::FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size) {
  if (req_size == blk_size) {
#ifdef DEBUG
    {
      // Debug-only read of the current state before freeing.
      size_t sz = 0;
      STATE st;
      GetBuddySegState(addr, &sz, &st);
    }
#endif
    SetBuddySegState(addr, blk_size, STATE::kFree);
    JoinBuddySeg(addr, blk_size);
  } else {
    UnTrimBuddySeg(addr, blk_size, req_size);
  }
}
// Report how much of the space is free, as an integer percentage.
// Walks the per-level counters without holding the lock, so under concurrent
// use the result is only an approximation.
int BuddySpace::PercentFree() const {
  uint64_t free_units = 0;
  const uint64_t total_units = BitLeftShift(1, num_lvl_ - 1);
  // Level i holds count_[i] free blocks of 2^i units each.
  for (int lvl = 0; lvl < num_lvl_; ++lvl) {
    const int num_free = count_[lvl];
    if (num_free != 0) {
      free_units += BitLeftShift(1, lvl) * static_cast<uint64_t>(num_free);
    }
  }
  return static_cast<int>(static_cast<float>(free_units) / static_cast<float>(total_units) * 100);
}
// Construct with configuration only; validation and the backing allocation
// happen in Init(), which CreateBuddySpace() calls.
BuddySpace::BuddySpace(int log_min, int num_lvl)
    : hint_(nullptr),
      count_(nullptr),
      map_(nullptr),
      log_min_(log_min),
      num_lvl_(num_lvl),
      min_(0),
      max_(0),
      ptr_(nullptr) {}
// Release the single backing allocation; hint_/count_/map_ all point into it.
// NOTE(review): ptr_ came from DeMalloc() — this pairs it with free(), which
// assumes DeMalloc is malloc-compatible; confirm against memory_pool.h.
BuddySpace::~BuddySpace() {
  if (ptr_ != nullptr) {
    free(ptr_);
  }
  hint_ = nullptr;
  count_ = nullptr;
  map_ = nullptr;
}
// Factory: construct and initialize a BuddySpace.
// \param[out] out_bs  Receives the ready-to-use buddy space on success.
// \param log_min      log2 of the minimum block size in bytes.
// \param num_lvl      Number of buddy levels (validated in Init()).
// \return kOutOfMemory if construction fails, otherwise the Status from Init().
Status BuddySpace::CreateBuddySpace(std::unique_ptr<BuddySpace> *out_bs, int log_min, int num_lvl) {
  // unique_ptr owns the object for the whole function, so the error path needs
  // no manual delete (the original used a raw pointer with an explicit delete).
  // make_unique is not usable here: the constructor is private.
  std::unique_ptr<BuddySpace> bs(new (std::nothrow) BuddySpace(log_min, num_lvl));
  if (bs == nullptr) {
    return Status(StatusCode::kOutOfMemory);
  }
  Status rc = bs->Init();
  if (rc.IsOk()) {
    *out_bs = std::move(bs);
  }
  return rc;
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_BUDDY_H_
#define DATASET_UTIL_BUDDY_H_
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include "dataset/util/status.h"
using addr_t = int64_t;
using rel_addr_t = int32_t;
using log_t = int;
#define ALLOC_BIT 0x80
#define ONE_BIT 0x40
#define TWO_BIT 0x20
#define MORE_BIT 0x10
#define NOSPACE ((addr_t)(-1))
namespace mindspore {
namespace dataset {
// Placement record filled in by BuddySpace::Alloc and required by Free.
struct BSpaceDescriptor {
  int32_t sig;       // signature (0xDEADBEEF) used to validate the descriptor on Free
  rel_addr_t addr;   // start address in minimum-size units
  size_t req_size;   // requested size in units
  size_t blk_size;   // power-of-two block size in units actually reserved
};
// A classic buddy allocator over an abstract address space measured in
// minimum-size units. Alloc/Free are thread-safe (guarded by mutex_).
class BuddySpace {
 public:
  // C++11 feature. Change STATE into a type safe class with
  // the keyword. Don't take out the keyword 'class'
  enum class STATE { kFree, kAlloc, kEmpty };

  // The space owns raw bookkeeping memory, so copying is forbidden.
  BuddySpace(const BuddySpace &) = delete;
  BuddySpace &operator=(const BuddySpace &) = delete;
  virtual ~BuddySpace();

  // Allocate sz bytes. Fills *desc (needed later by Free) and writes the byte
  // offset to the out parameter. Returns a kNoSpace status when full.
  Status Alloc(uint64_t sz, BSpaceDescriptor *desc, addr_t *) noexcept;

  // Release a segment previously returned by Alloc.
  void Free(const BSpaceDescriptor *desc);

  // Minimum / maximum allocatable sizes in bytes.
  uint64_t GetMinSize() const { return min_; }
  uint64_t GetMaxSize() const { return max_; }

  // Approximate percentage of the space that is free (reads counters unlocked).
  int PercentFree() const;

  friend std::ostream &operator<<(std::ostream &os, const BuddySpace &s);

  // Smallest power of two that is >= n (returns 1 for n <= 1).
  static uint64_t NextPowerOf2(uint64_t n) {
    if (n <= 1) {
      return 1;
    }
    n = n - 1;
    // Drop low set bits until only the highest remains, then double it.
    while (n & (n - 1)) {
      n = n & (n - 1);
    }
    return n << 1;
  }

  // Floor of log2(n) for n >= 1; returns 0 for n == 0.
  static uint32_t Log2(uint64_t n) {
    uint32_t cnt = 0;
    while (n >>= 1) {
      cnt++;
    }
    return cnt;
  }

  // Factory: allocates the space, runs Init(), and hands back ownership on success.
  static Status CreateBuddySpace(std::unique_ptr<BuddySpace> *out_bs, int log_min = 15, int num_lvl = 18);

 private:
  rel_addr_t *hint_;  // per-level scan-start hints
  int *count_;        // per-level free-block counters
  char *map_;         // packed per-unit state map
  int log_min_;       // log2 of the minimum block size in bytes
  int num_lvl_;       // number of buddy levels
  uint64_t min_;      // minimum block size in bytes (1 << log_min_)
  uint64_t max_;      // maximum block size in bytes
  void *ptr_;         // single backing allocation holding hint_/count_/map_
  std::mutex mutex_;  // guards Alloc/Free

  explicit BuddySpace(int log_min = 15, int num_lvl = 18);

  Status Init();

  addr_t AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept;

  void FreeNoLock(const BSpaceDescriptor *desc);

  // Round a byte size up to a whole number of minimum-size units.
  uint32_t SizeToBlock(const uint64_t sz) const {
    uint32_t reqSize = (sz / min_);
    if (sz % min_) {
      reqSize++;
    }
    return reqSize;
  }

  void GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const;

  void SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st);

  void JoinBuddySeg(rel_addr_t addr, size_t blk_sz);

  void TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz);

  void UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz);

  rel_addr_t AllocBuddySeg(uint32_t req_size) noexcept;

  void FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size);
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_UTIL_BUDDY_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "common/utils.h"
#include "dataset/util/cache_pool.h"
#include "dataset/util/services.h"
namespace mindspore {
namespace dataset {
// Construct an idle pool; the index and (optional) disk backend are created in
// DoServiceStart(). subfolder_ gets a process-unique name for the spill area.
CachePool::CachePool(const value_allocator &alloc, const std::string &root)
    : alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {}
// Bring the pool up: create the in-memory index and, when a spill root was
// given, create the spill directory and start a StorageManager over it.
Status CachePool::DoServiceStart() {
  tree_ = std::make_shared<data_index>();
  // If we are given a disk path, set up the StorageManager
  if (!root_.toString().empty()) {
    Path spill = GetSpillPath();
    RETURN_IF_NOT_OK(spill.CreateDirectories());
    sm_ = std::make_shared<StorageManager>(spill);
    RETURN_IF_NOT_OK(sm_->ServiceStart());
    MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString());
  }
  return Status::OK();
}
// Tear the pool down, continuing past individual failures so that as much as
// possible gets cleaned up; the first error encountered is remembered in rc2
// and returned at the end.
Status CachePool::DoServiceStop() {
  Status rc;
  Status rc2;
  if (sm_ != nullptr) {
    rc = sm_->ServiceStop();
    if (rc.IsError()) {
      rc2 = rc;
    }
  }
  sm_.reset();
  // Return every memory-resident buffer to the allocator before dropping the index.
  for (auto &bl : *tree_) {
    if (bl.ptr != nullptr) {
      alloc_.deallocate(bl.ptr, bl.sz);
    }
  }
  tree_.reset();
  if (!root_.toString().empty()) {
    Path spill = GetSpillPath();
    auto it = Path::DirIterator::OpenDirectory(&spill);
    // Fix: guard against OpenDirectory failing (e.g. the folder was removed
    // externally) — the original dereferenced the iterator unconditionally.
    if (it != nullptr) {
      while (it->hasNext()) {
        rc = it->next().Remove();
        if (rc.IsError() && rc2.IsOk()) {
          rc2 = rc;
        }
      }
    }
    rc = spill.Remove();
    if (rc.IsError() && rc2.IsOk()) {
      rc2 = rc;
    }
  }
  return rc2;
}
// Best-effort shutdown; ServiceStop errors are deliberately ignored in a destructor.
CachePool::~CachePool() noexcept { (void)ServiceStop(); }
// Consolidate the given slices into one contiguous buffer and cache it.
// Primary path: copy into memory from alloc_. If that allocation throws
// bad_alloc and a StorageManager exists, fall back to writing the slices to
// disk instead (bl.ptr then stays null and storage_key identifies the data).
// \param[in] buf   Slices to concatenate, in order.
// \param[out] key  Generated index key for later Read().
Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_type *key) {
  DataLocator bl;
  Status rc;
  size_t sz = 0;
  // We will consolidate all the slices into one piece.
  for (auto &v : buf) {
    sz += v.GetSize();
  }
  bl.sz = sz;
  try {
    bl.ptr = alloc_.allocate(sz);
    // We will do a piecewise copy.
    WritableSlice dest(bl.ptr, bl.sz);
    size_t pos = 0;
    for (auto &v : buf) {
      WritableSlice out(dest, pos);
      rc = WritableSlice::Copy(&out, v);
      if (rc.IsError()) {
        break;
      }
      pos += v.GetSize();
    }
    if (rc.IsError()) {
      // Copy failed part-way: release the buffer rather than caching garbage.
      alloc_.deallocate(bl.ptr, sz);
      bl.ptr = nullptr;
      return rc;
    }
  } catch (std::bad_alloc &e) {
    // Out of memory: spill to disk when a backend exists, otherwise surface OOM.
    if (sm_ != nullptr) {
      RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf));
      // We have an assumption 0 is not a valid key from the design of AutoIndexObj.
      // Make sure it is not 0.
      if (bl.storage_key == 0) {
        RETURN_STATUS_UNEXPECTED("Key 0 is returned which is unexpected");
      }
    } else {
      return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
    }
  }
  rc = tree_->insert(bl, key);
  // If indexing fails we must not leak the consolidated buffer.
  if (rc.IsError() && bl.ptr != nullptr) {
    alloc_.deallocate(bl.ptr, sz);
  }
  return rc;
}
// Copy a cached buffer into *dest. Memory-resident entries are copied
// directly; spilled entries are fetched through the StorageManager and their
// length is cross-checked against the size recorded at insert time.
// \param[in] key        Key previously returned by Insert().
// \param[out] dest      Destination slice.
// \param[out] bytesRead Optional; receives the entry size in bytes.
Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const {
  RETURN_UNEXPECTED_IF_NULL(dest);
  auto r = tree_->Search(key);
  if (r.second) {
    auto &it = r.first;
    if (it->ptr != nullptr) {
      // In-memory copy path.
      ReadableSlice src(it->ptr, it->sz);
      RETURN_IF_NOT_OK(WritableSlice::Copy(dest, src));
    } else if (sm_ != nullptr) {
      // Disk path: the entry was spilled; read it back and verify the length.
      size_t expectedLength = 0;
      RETURN_IF_NOT_OK(sm_->Read(it->storage_key, dest, &expectedLength));
      if (expectedLength != it->sz) {
        MS_LOG(ERROR) << "Unexpected length. Read " << expectedLength << ". Expected " << it->sz << "."
                      << " Internal key: " << key << "\n";
        RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
      }
    }
    if (bytesRead != nullptr) {
      *bytesRead = it->sz;
    }
  } else {
    RETURN_STATUS_UNEXPECTED("Key not found");
  }
  return Status::OK();
}
// Accessor for the allocator backing the memory-resident buffers.
const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; }
// Build the on-disk spill location: <root>/<unique subfolder>.
Path CachePool::GetSpillPath() const { return Path(root_) / subfolder_; }
// Count how many entries currently live in memory versus on disk.
// An entry with a null pointer has been written out (spilled) to disk.
CachePool::CacheStat CachePool::GetStat() const {
  CacheStat stats{0};
  for (auto &entry : *tree_) {
    if (entry.ptr == nullptr) {
      ++stats.num_disk_cached;
    } else {
      ++stats.num_mem_cached;
    }
  }
  return stats;
}
// Move a memory-resident buffer out to disk and free its in-memory copy.
// storage_key == 0 means the data was never written to disk (Insert()
// guarantees 0 is not a valid storage key), so it is written out first;
// otherwise the existing on-disk copy is reused.
Status CachePool::Spill(CachePool::DataLocator *dl) {
  if (sm_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("No disk storage to spill");
  }
  RETURN_UNEXPECTED_IF_NULL(dl);
  RETURN_UNEXPECTED_IF_NULL(dl->ptr);
  if (dl->storage_key == 0) {
    ReadableSlice data(dl->ptr, dl->sz);
    RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data}));
  }
  alloc_.deallocate(dl->ptr, dl->sz);
  dl->ptr = nullptr;
  return Status::OK();
}
// Ensure the data behind *dl is memory-resident: if it only exists on disk,
// allocate a buffer and read it back in. Inverse of Spill().
Status CachePool::Locate(CachePool::DataLocator *dl) {
  RETURN_UNEXPECTED_IF_NULL(dl);
  if (dl->ptr == nullptr) {
    if (sm_ == nullptr) {
      RETURN_STATUS_UNEXPECTED("No disk storage to locate the data");
    }
    try {
      dl->ptr = alloc_.allocate(dl->sz);
      WritableSlice dest(dl->ptr, dl->sz);
      // NOTE(review): this passes dl->storage_key to Read(), but Read() looks
      // up the pool's tree index by pool key, not by storage key. That only
      // works if the two key spaces coincide — confirm; the direct call would
      // be sm_->Read(dl->storage_key, ...).
      Status rc = Read(dl->storage_key, &dest);
      if (rc.IsError()) {
        // Read failed: release the buffer so the locator stays disk-only.
        alloc_.deallocate(dl->ptr, dl->sz);
        dl->ptr = nullptr;
        return rc;
      }
    } catch (const std::bad_alloc &e) {
      return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
    }
  }
  return Status::OK();
}
// Return the byte size of the entry stored under key, or 0 when the key is unknown.
size_t CachePool::GetSize(CachePool::key_type key) const {
  auto result = tree_->Search(key);
  if (!result.second) {
    return 0;
  }
  return result.first->sz;
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_CACHE_POOL_H_
#define DATASET_UTIL_CACHE_POOL_H_
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "dataset/util/allocator.h"
#include "dataset/util/service.h"
#include "dataset/util/slice.h"
#include "dataset/util/storage_manager.h"
#include "dataset/util/auto_index.h"
namespace mindspore {
namespace dataset {
/// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of
/// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to
/// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to
/// restore the buffer.
/// \see ReadableSlice
class CachePool : public Service {
 public:
  using base_type = uint8_t;
  using pointer = base_type *;
  using const_pointer = const base_type *;
  using reference = base_type &;
  using const_reference = const base_type &;
  using value_allocator = Allocator<base_type>;

  // An internal class to locate the whereabouts of a backed up buffer which can be either in
  // memory (ptr != nullptr) or on disk (ptr == nullptr with storage_key set).
  class DataLocator {
   public:
    DataLocator() : ptr(nullptr), sz(0), storage_key(0) {}
    ~DataLocator() = default;
    DataLocator(const DataLocator &other) = default;
    DataLocator &operator=(const DataLocator &other) = default;
    // Move transfers ownership of the raw buffer; the source is reset so it
    // cannot be double-freed.
    DataLocator(DataLocator &&other) noexcept {
      ptr = other.ptr;
      sz = other.sz;
      storage_key = other.storage_key;
      other.ptr = nullptr;
      other.sz = 0;
      other.storage_key = 0;
    }
    DataLocator &operator=(DataLocator &&other) noexcept {
      if (&other != this) {
        ptr = other.ptr;
        sz = other.sz;
        storage_key = other.storage_key;
        other.ptr = nullptr;
        other.sz = 0;
        other.storage_key = 0;
      }
      return *this;
    }
    pointer ptr;                           // in-memory buffer, or nullptr when spilled
    size_t sz;                             // buffer size in bytes
    StorageManager::key_type storage_key;  // on-disk key; 0 means never written to disk
  };

  using data_index = AutoIndexObj<DataLocator>;
  using key_type = data_index::key_type;
  using bl_alloc_type = typename value_allocator::template rebind<DataLocator>::other;

  /// \brief Simple statistics returned from CachePool like how many elements are cached in memory and
  /// how many elements are spilled to disk.
  struct CacheStat {
    int64_t num_mem_cached;
    int64_t num_disk_cached;
  };

  /// \brief Constructor
  /// \param alloc Allocator to allocate memory from
  /// \param root Optional disk folder to spill
  explicit CachePool(const value_allocator &alloc, const std::string &root = "");

  // The pool owns raw buffers, so neither copying nor moving is allowed.
  CachePool(const CachePool &) = delete;
  CachePool(CachePool &&) = delete;
  CachePool &operator=(const CachePool &) = delete;
  CachePool &operator=(CachePool &&) = delete;
  ~CachePool() noexcept;

  Status DoServiceStart() override;
  Status DoServiceStop() override;

  Path GetSpillPath() const;

  /// \brief Insert a sequence of ReadableSlice objects into the pool.
  /// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk.
  /// \param[in] buf A sequence of ReadableSlice objects.
  /// \param[out] key Generated key
  /// \return Error code
  Status Insert(const std::vector<ReadableSlice> &buf, key_type *key);

  /// \brief Restore a cached buffer (from memory or disk)
  /// \param[in] key A previous key returned from Insert
  /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice
  /// \param[out] bytesRead Optional. Number of bytes read.
  /// \return Error code
  Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const;

  // Move a memory-resident entry to disk and release its memory copy.
  Status Spill(DataLocator *dl);

  // Ensure a spilled entry is memory-resident again (inverse of Spill).
  Status Locate(DataLocator *dl);

  // Byte size of the entry under key, or 0 if not found.
  size_t GetSize(key_type key) const;

  /// \brief Get statistics.
  /// \return CacheStat object
  CacheStat GetStat() const;

  const value_allocator &get_allocator() const;

  std::string MyName() const { return subfolder_; }

 private:
  value_allocator alloc_;               // allocator for in-memory buffers
  Path root_;                           // spill root; empty means memory-only
  const std::string subfolder_;         // unique per-pool spill subfolder
  std::shared_ptr<StorageManager> sm_;  // disk backend, nullptr when memory-only
  std::shared_ptr<data_index> tree_;    // key -> DataLocator index
};
} // namespace dataset
} // namespace mindspore
#endif
...@@ -106,6 +106,24 @@ struct List { ...@@ -106,6 +106,24 @@ struct List {
++count; ++count;
} }
// Insert elem2 before elem1 in the list.
// Both elements are reached through the intrusive member pointer `node`.
// Assumes elem1 is already linked into this list and elem2 is not.
virtual void InsertBefore(pointer elem1, pointer elem2) {
  DS_ASSERT(elem1 != elem2);
  Node<T> &elem1_node = elem1->*node;
  Node<T> &elem2_node = elem2->*node;
  // Wire elem2 in between elem1's predecessor and elem1.
  elem2_node.next = elem1;
  elem2_node.prev = elem1_node.prev;
  if (elem1_node.prev != nullptr) {
    Node<T> &prev_node = elem1_node.prev->*node;
    prev_node.next = elem2;
  }
  elem1_node.prev = elem2;
  // Inserting before the head makes elem2 the new head.
  if (head == elem1) {
    head = elem2;
  }
  ++count;
}
// Remove an element in the list // Remove an element in the list
virtual void Remove(pointer elem) noexcept { virtual void Remove(pointer elem) noexcept {
Node<T> &elem_node = elem->*node; Node<T> &elem_node = elem->*node;
......
...@@ -44,20 +44,6 @@ class MemoryPool { ...@@ -44,20 +44,6 @@ class MemoryPool {
virtual ~MemoryPool() {} virtual ~MemoryPool() {}
}; };
// Used by unique_ptr
template <typename T>
class Deleter {
public:
explicit Deleter(std::shared_ptr<MemoryPool> &mp) : mp_(mp) {}
~Deleter() = default;
void operator()(T *ptr) const { mp_->Deallocate(ptr); }
private:
std::shared_ptr<MemoryPool> mp_;
};
Status DeMalloc(std::size_t s, void **p, bool); Status DeMalloc(std::size_t s, void **p, bool);
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include "dataset/util/path.h" #include "dataset/util/path.h"
#include <sys/stat.h> #include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <new> #include <new>
#include <sstream> #include <sstream>
#include <utility> #include <utility>
...@@ -26,7 +28,7 @@ ...@@ -26,7 +28,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
#ifdef _WIN32 #if defined(_WIN32) || defined(_WIN64)
char Path::separator_ = '\\'; char Path::separator_ = '\\';
#else #else
char Path::separator_ = '/'; char Path::separator_ = '/';
...@@ -132,7 +134,7 @@ Status Path::CreateDirectory() { ...@@ -132,7 +134,7 @@ Status Path::CreateDirectory() {
#if defined(_WIN32) || defined(_WIN64) #if defined(_WIN32) || defined(_WIN64)
int rc = mkdir(common::SafeCStr(path_)); int rc = mkdir(common::SafeCStr(path_));
#else #else
int rc = mkdir(common::SafeCStr(path_), 0700); int rc = mkdir(common::SafeCStr(path_), S_IRUSR | S_IWUSR | S_IXUSR);
#endif #endif
if (rc) { if (rc) {
std::ostringstream oss; std::ostringstream oss;
...@@ -182,6 +184,111 @@ Status Path::CreateDirectories() { ...@@ -182,6 +184,111 @@ Status Path::CreateDirectories() {
return Status::OK(); return Status::OK();
} }
// Delete the file or directory at path_. A non-existent path is not an error.
Status Path::Remove() {
  if (!Exists()) {
    return Status::OK();
  }
  // Directories and regular files need different syscalls.
  if (IsDirectory()) {
    if (rmdir(common::SafeCStr(path_)) == -1) {
      std::ostringstream oss;
      oss << "Unable to delete directory " << path_ << ". Errno = " << errno;
      RETURN_STATUS_UNEXPECTED(oss.str());
    }
  } else {
    if (unlink(common::SafeCStr(path_)) == -1) {
      std::ostringstream oss;
      oss << "Unable to delete file " << path_ << ". Errno = " << errno;
      RETURN_STATUS_UNEXPECTED(oss.str());
    }
  }
  return Status::OK();
}
// Create (or truncate) the file at path_ and return its descriptor; thin wrapper over OpenFile(create=true).
Status Path::CreateFile(int *file_descriptor) { return OpenFile(file_descriptor, true); }
// Open the file at path_ for read/write and return its descriptor.
// \param[out] file_descriptor Receives the opened fd on success.
// \param[in] create When true, create the file if it does not exist and truncate it if it does.
Status Path::OpenFile(int *file_descriptor, bool create) {
  int fd;
  if (file_descriptor == nullptr) {
    RETURN_STATUS_UNEXPECTED("null pointer");
  }
  if (IsDirectory()) {
    std::ostringstream oss;
    oss << "Unable to create file " << path_ << " which is a directory.";
    RETURN_STATUS_UNEXPECTED(oss.str());
  }
  // Convert to canonical form.
  // NOTE(review): errno is not necessarily set at this point, so strerror(errno) may
  // report "Success" for an over-long path — consider a fixed error message.
  if (strlen(common::SafeCStr(path_)) > PATH_MAX) {
    RETURN_STATUS_UNEXPECTED(strerror(errno));
  }
  char canonical_path[PATH_MAX + 1] = {0x00};
#if defined(_WIN32) || defined(_WIN64)
  if (_fullpath(canonical_path, common::SafeCStr(path_), PATH_MAX) == nullptr) {
#else
  if (realpath(common::SafeCStr(path_), canonical_path) == nullptr) {
#endif
    if (errno == ENOENT && create) {
      // File doesn't exist and we are to create it. Let's break it down.
      // realpath() fails on a non-existent path, so canonicalize the (existing) parent
      // directory and then append the file name manually.
      auto file_part = Basename();
      auto parent_part = ParentPath();
#if defined(_WIN32) || defined(_WIN64)
      if (_fullpath(canonical_path, common::SafeCStr(parent_part), PATH_MAX) == nullptr) {
#else
      if (realpath(common::SafeCStr(parent_part), canonical_path) == nullptr) {
#endif
        RETURN_STATUS_UNEXPECTED(strerror(errno));
      }
      auto cur_inx = strlen(canonical_path);
      // +1 accounts for the separator we are about to append.
      if ((cur_inx + file_part.length() + 1) > PATH_MAX) {
        RETURN_STATUS_UNEXPECTED(strerror(errno));
      }
      canonical_path[cur_inx++] = separator_;
      if (strncpy_s(canonical_path + cur_inx, PATH_MAX - cur_inx, common::SafeCStr(file_part), file_part.length()) !=
          EOK) {
        RETURN_STATUS_UNEXPECTED(strerror(errno));
      }
    } else {
      RETURN_STATUS_UNEXPECTED(strerror(errno));
    }
  }
  if (create) {
    // Create or truncate; permissions: owner read/write, group read.
    fd = open(canonical_path, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP);
  } else {
    fd = open(canonical_path, O_RDWR);
  }
  if (fd == -1) {
    RETURN_STATUS_UNEXPECTED(strerror(errno));
  }
  *file_descriptor = fd;
  return Status::OK();
}
// Close a file descriptor previously returned by CreateFile/OpenFile.
Status Path::CloseFile(int fd) const {
  // close(2) returns 0 on success and -1 on failure.
  int rc = close(fd);
  if (rc < 0) {
    RETURN_STATUS_UNEXPECTED(strerror(errno));
  }
  return Status::OK();
}
// Truncate the open file to zero length, releasing its disk space.
Status Path::TruncateFile(int fd) const {
  // ftruncate(2) returns 0 on success.
  if (ftruncate(fd, 0) != 0) {
    RETURN_STATUS_UNEXPECTED(strerror(errno));
  }
  return Status::OK();
}
// Return everything after the last path separator, or the whole path when
// no separator is present.
std::string Path::Basename() {
  auto pos = path_.find_last_of(separator_);
  return (pos == std::string::npos) ? path_ : path_.substr(pos + 1);
}
std::shared_ptr<Path::DirIterator> Path::DirIterator::OpenDirectory(Path *f) { std::shared_ptr<Path::DirIterator> Path::DirIterator::OpenDirectory(Path *f) {
auto it = new (std::nothrow) DirIterator(f); auto it = new (std::nothrow) DirIterator(f);
...@@ -208,7 +315,7 @@ Path::DirIterator::~DirIterator() { ...@@ -208,7 +315,7 @@ Path::DirIterator::~DirIterator() {
Path::DirIterator::DirIterator(Path *f) : dir_(f), dp_(nullptr), entry_(nullptr) { Path::DirIterator::DirIterator(Path *f) : dir_(f), dp_(nullptr), entry_(nullptr) {
MS_LOG(DEBUG) << "Open directory " << f->toString() << "."; MS_LOG(DEBUG) << "Open directory " << f->toString() << ".";
dp_ = opendir(common::SafeCStr(f->toString())); dp_ = opendir(f->toString().c_str());
} }
bool Path::DirIterator::hasNext() { bool Path::DirIterator::hasNext() {
...@@ -225,5 +332,10 @@ bool Path::DirIterator::hasNext() { ...@@ -225,5 +332,10 @@ bool Path::DirIterator::hasNext() {
} }
Path Path::DirIterator::next() { return (*(this->dir_) / Path(entry_->d_name)); } Path Path::DirIterator::next() { return (*(this->dir_) / Path(entry_->d_name)); }
// Stream insertion: prints the raw path string.
std::ostream &operator<<(std::ostream &os, const Path &s) {
  os << s.path_;
  return os;
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
...@@ -90,6 +90,20 @@ class Path { ...@@ -90,6 +90,20 @@ class Path {
std::string ParentPath(); std::string ParentPath();
Status Remove();
Status CreateFile(int *fd);
Status OpenFile(int *fd, bool create = false);
Status CloseFile(int fd) const;
Status TruncateFile(int fd) const;
std::string Basename();
friend std::ostream &operator<<(std::ostream &os, const Path &s);
private: private:
static char separator_; static char separator_;
std::string path_; std::string path_;
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/semaphore.h"
#include "dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
// Decrement the counter, blocking (interruptibly) until it is positive.
Status Semaphore::P() {
  std::unique_lock<std::mutex> lck(mutex_);
  // Wait can return early with an interrupt status; propagate it without decrementing.
  RETURN_IF_NOT_OK(wait_cond_.Wait(&lck, [this]() { return value_ > 0; }));
  --value_;
  return Status::OK();
}
// Increment the counter under the lock, then wake a single waiter.
void Semaphore::V() {
  std::lock_guard<std::mutex> guard(mutex_);
  ++value_;
  wait_cond_.NotifyOne();
}
// Snapshot of the current counter value, taken under the lock.
int Semaphore::Peek() {
  std::lock_guard<std::mutex> guard(mutex_);
  return value_;
}
// Hook the internal condition variable up to the task group's interrupt service
// so a blocked P() can be woken when the group is interrupted.
Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); }
Status Semaphore::Deregister() { return (wait_cond_.Deregister()); }
void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); }
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_SEMAPHORE_H_
#define DATASET_UTIL_SEMAPHORE_H_
#include "dataset/util/cond_var.h"
namespace mindspore {
namespace dataset {
class TaskGroup;
/// \brief A counting semaphore. There are two external functions P and V. P decrements the internal count and will be
/// blocked if the count is 0 (zero). V increments the internal count and wakes up one of the waiters.
class Semaphore {
 public:
  /// \brief Constructor
  /// \param init Initial value of the internal counter.
  explicit Semaphore(int init) : value_(init) {}
  virtual ~Semaphore() {}
  /// \brief Decrement the internal counter. Will be blocked if the value is 0.
  /// \return Error code. Can get interrupt.
  Status P();
  /// \brief Increment the internal counter. Wake up one of the waiters if any.
  void V();
  /// \brief Peek the internal value
  /// \return The internal value
  int Peek();
  /// \brief Attach the semaphore to the task group's interrupt service.
  Status Register(TaskGroup *vg);
  /// \brief Detach from the interrupt service previously registered.
  Status Deregister();
  /// \brief Clear any recorded interrupt state on the underlying condition variable.
  void ResetIntrpState();

 private:
  int value_;
  std::mutex mutex_;
  CondVar wait_cond_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_UTIL_SEMAPHORE_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/slice.h"
namespace mindspore {
namespace dataset {
// A writable view of the sub-range [offset, offset + len) of src.
WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset, size_t len) : ReadableSlice(src, offset, len) {
  mutable_data_ = static_cast<char *>(src.mutable_data_) + offset;
}
// A writable view of the tail of src starting at offset.
WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset)
    : WritableSlice(src, offset, src.GetSize() - offset) {}
// Copy the content of src into dest.
// \param[in,out] dest Destination slice; must be non-null with a non-null buffer.
// \param[in] src Source slice; its full content is copied.
// \return Error when dest is invalid, empty, or too small for src (memcpy_s fails).
Status WritableSlice::Copy(WritableSlice *dest, const ReadableSlice &src) {
  RETURN_UNEXPECTED_IF_NULL(dest);
  RETURN_UNEXPECTED_IF_NULL(dest->GetMutablePointer());
  // GetSize() returns an unsigned size_t, so the only invalid length is zero
  // (the original `<= 0` test was a tautological unsigned comparison).
  if (dest->GetSize() == 0) {
    RETURN_STATUS_UNEXPECTED("Destination length is zero");
  }
  // memcpy_s itself rejects src larger than dest and returns a non-zero error code.
  auto err = memcpy_s(dest->GetMutablePointer(), dest->GetSize(), src.GetPointer(), src.GetSize());
  if (err) {
    RETURN_STATUS_UNEXPECTED(std::to_string(err));
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_SLICE_H_
#define DATASET_UTIL_SLICE_H_
#include <unistd.h>
#include <cstddef>
#include <utility>
#include "./securec.h"
#include "dataset/util/allocator.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
/// \brief A ReadableSlice wraps a const pointer in memory and its size.
/// \see WritableSlice for a non-const version
///
class ReadableSlice {
 public:
  ReadableSlice() : ptr_(nullptr), sz_(0) {}
  ReadableSlice(const void *ptr, size_t sz) : ptr_(ptr), sz_(sz) {}
  /// View the sub-range [offset, offset + len) of an existing slice.
  ReadableSlice(const ReadableSlice &src, off64_t offset, size_t len)
      : ptr_(static_cast<const char *>(src.GetPointer()) + offset), sz_(len) {}
  /// View the tail of an existing slice starting at offset.
  ReadableSlice(const ReadableSlice &src, off64_t offset) : ReadableSlice(src, offset, src.sz_ - offset) {}
  ReadableSlice(const ReadableSlice &lhs) : ptr_(lhs.ptr_), sz_(lhs.sz_) {}
  ReadableSlice &operator=(const ReadableSlice &lhs) {
    if (this != &lhs) {
      ptr_ = lhs.ptr_;
      sz_ = lhs.sz_;
    }
    return *this;
  }
  /// Moved-from slices become empty.
  ReadableSlice(ReadableSlice &&lhs) noexcept : ptr_(std::exchange(lhs.ptr_, nullptr)), sz_(std::exchange(lhs.sz_, 0)) {}
  ReadableSlice &operator=(ReadableSlice &&lhs) noexcept {
    if (this != &lhs) {
      ptr_ = std::exchange(lhs.ptr_, nullptr);
      sz_ = std::exchange(lhs.sz_, 0);
    }
    return *this;
  }
  /// \brief Getter function
  /// \return Const version of the pointer
  const void *GetPointer() const { return ptr_; }
  /// \brief Getter function
  /// \return Size of the slice
  size_t GetSize() const { return sz_; }
  bool empty() const { return ptr_ == nullptr; }

 private:
  const void *ptr_;
  size_t sz_;
};
/// \brief A WritableSlice inherits from ReadableSlice to allow
/// one to write to the address pointed to by the pointer.
///
class WritableSlice : public ReadableSlice {
 public:
  friend class StorageContainer;
  /// \brief Default constructor
  WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {}
  /// \brief This form of a constructor takes a pointer and its size.
  WritableSlice(void *ptr, size_t sz) : ReadableSlice(ptr, sz), mutable_data_(ptr) {}
  WritableSlice(const WritableSlice &src, off64_t offset, size_t len);
  WritableSlice(const WritableSlice &src, off64_t offset);
  WritableSlice(const WritableSlice &lhs) : ReadableSlice(lhs), mutable_data_(lhs.mutable_data_) {}
  WritableSlice &operator=(const WritableSlice &lhs) {
    if (this != &lhs) {
      mutable_data_ = lhs.mutable_data_;
      ReadableSlice::operator=(lhs);
    }
    return *this;
  }
  /// Moved-from slices become empty.
  WritableSlice(WritableSlice &&lhs) noexcept
      : ReadableSlice(std::move(lhs)), mutable_data_(std::exchange(lhs.mutable_data_, nullptr)) {}
  WritableSlice &operator=(WritableSlice &&lhs) noexcept {
    if (this != &lhs) {
      mutable_data_ = std::exchange(lhs.mutable_data_, nullptr);
      ReadableSlice::operator=(std::move(lhs));
    }
    return *this;
  }
  /// \brief Copy the content from one slice onto another.
  static Status Copy(WritableSlice *dest, const ReadableSlice &src);

 private:
  void *mutable_data_;
  void *GetMutablePointer() { return mutable_data_; }
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_UTIL_SLICE_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/storage_container.h"
#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>
#include <vector>
#include "common/utils.h"
#include "dataset/util/de_error.h"
#include "dataset/util/path.h"
#include "dataset/util/status.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
// Set up the in-memory buddy-space allocator, then materialize the backing file on disk.
Status StorageContainer::Create() {
  RETURN_IF_NOT_OK(BuddySpace::CreateBuddySpace(&bs_));
  RETURN_IF_NOT_OK(cont_.CreateFile(&fd_));
  is_open_ = true;
  MS_LOG(INFO) << "Container " << cont_ << " created";
  return Status::OK();
}
// (Re)open the backing file if it is not already open; safe for concurrent callers.
Status StorageContainer::Open() noexcept {
  std::lock_guard<std::mutex> lck(mutex_);
  // Check again
  if (!is_open_) {
    RETURN_IF_NOT_OK(cont_.OpenFile(&fd_));
    is_open_ = true;
  }
  return Status::OK();
}
// Close the backing file if open.
// NOTE(review): double-checked locking on a plain bool; the unlocked first read of
// is_open_ is technically a data race — consider std::atomic<bool>. Verify intent.
Status StorageContainer::Close() noexcept {
  if (is_open_) {
    std::lock_guard<std::mutex> lck(mutex_);
    // Check again
    if (is_open_) {
      RETURN_IF_NOT_OK(cont_.CloseFile(fd_));
      is_open_ = false;
      fd_ = -1;
    }
  }
  return Status::OK();
}
// Read exactly dest->GetSize() bytes from the container at the given offset.
Status StorageContainer::Read(WritableSlice *dest, off64_t offset) const noexcept {
  DS_ASSERT(is_open_);
  RETURN_UNEXPECTED_IF_NULL(dest);
  auto sz = dest->GetSize();
#if defined(_WIN32) || defined(_WIN64)
  // Doesn't seem there is any pread64 on mingw.
  // So we will do a seek and then a read under
  // a protection of mutex.
  std::lock_guard<std::mutex> lck(mutex_);
  auto seek_err = lseek(fd_, offset, SEEK_SET);
  if (seek_err < 0) {
    RETURN_STATUS_UNEXPECTED(strerror(errno));
  }
  auto r_sz = read(fd_, dest->GetMutablePointer(), sz);
#else
  // pread64 takes an explicit offset, so no lock is needed on this path.
  auto r_sz = pread64(fd_, dest->GetMutablePointer(), sz, offset);
#endif
  // A short read (including 0 at EOF) is treated as an error; there is no retry loop.
  if (r_sz != sz) {
    errno_t err = (r_sz == 0) ? EOF : errno;
    RETURN_STATUS_UNEXPECTED(strerror(err));
  }
  return Status::OK();
}
// Write the full content of dest into the container at the given offset.
Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const noexcept {
  DS_ASSERT(is_open_);
  auto sz = dest.GetSize();
#if defined(_WIN32) || defined(_WIN64)
  // Doesn't seem there is any pwrite64 on mingw.
  // So we will do a seek and then a write under
  // a protection of mutex.
  std::lock_guard<std::mutex> lck(mutex_);
  auto seek_err = lseek(fd_, offset, SEEK_SET);
  if (seek_err < 0) {
    RETURN_STATUS_UNEXPECTED(strerror(errno));
  }
  auto r_sz = write(fd_, dest.GetPointer(), sz);
#else
  // pwrite64 takes an explicit offset, so no lock is needed on this path.
  auto r_sz = pwrite64(fd_, dest.GetPointer(), sz, offset);
#endif
  // A short write is treated as an error; there is no retry loop.
  if (r_sz != sz) {
    errno_t err = (r_sz == 0) ? EOF : errno;
    RETURN_STATUS_UNEXPECTED(strerror(err));
  }
  return Status::OK();
}
// Allocate container space for all the slices and write them out back to back.
// On success, *offset is the starting offset of the data within the container.
Status StorageContainer::Insert(const std::vector<ReadableSlice> &buf, off64_t *offset) noexcept {
  // Total number of bytes across all the slices.
  size_t total = 0;
  for (const auto &slice : buf) {
    total += slice.GetSize();
  }
  if (total == 0) {
    RETURN_STATUS_UNEXPECTED("Unexpected 0 length");
  }
  if (total > bs_->GetMaxSize()) {
    RETURN_STATUS_UNEXPECTED("Request size too big");
  }
  // Reserve space in the buddy allocator, then copy each slice to disk in order.
  BSpaceDescriptor bspd{0};
  addr_t addr = 0;
  RETURN_IF_NOT_OK(bs_->Alloc(total, &bspd, &addr));
  *offset = static_cast<off64_t>(addr);
  addr_t cursor = addr;
  for (const auto &slice : buf) {
    RETURN_IF_NOT_OK(Write(slice, cursor));
    cursor += slice.GetSize();
  }
  return Status::OK();
}
// Release the disk space by truncating the backing file to zero length (no-op when closed).
Status StorageContainer::Truncate() const noexcept {
  if (is_open_) {
    RETURN_IF_NOT_OK(cont_.TruncateFile(fd_));
    MS_LOG(INFO) << "Container " << cont_ << " truncated";
  }
  return Status::OK();
}
// Best-effort cleanup: free the disk space and close the fd. Errors are
// deliberately ignored because a destructor cannot fail.
StorageContainer::~StorageContainer() noexcept {
  (void)Truncate();
  (void)Close();
}
// Dump the backing file path followed by the buddy-space state (for diagnostics).
std::ostream &operator<<(std::ostream &os, const StorageContainer &s) {
  os << "File path : " << s.cont_ << "\n" << *(s.bs_.get());
  return os;
}
// Factory method: the constructor is private, so we allocate manually and
// only hand ownership to the shared_ptr once Create() has succeeded.
Status StorageContainer::CreateStorageContainer(std::shared_ptr<StorageContainer> *out_sc, const std::string &path) {
  auto container = new (std::nothrow) StorageContainer(path);
  if (container == nullptr) {
    return Status(StatusCode::kOutOfMemory);
  }
  Status rc = container->Create();
  if (rc.IsOk()) {
    out_sc->reset(container);
  } else {
    delete container;
  }
  return rc;
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_STORAGE_CONTAINER_H_
#define DATASET_UTIL_STORAGE_CONTAINER_H_
#include <limits.h>
#include <unistd.h>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "dataset/util/system_pool.h"
#include "dataset/util/buddy.h"
#include "dataset/util/path.h"
#include "dataset/util/slice.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
class StorageManager;
/// \brief A single spill file on disk whose internal space is managed by a
/// BuddySpace allocator. Instances are created via CreateStorageContainer() only.
class StorageContainer {
 public:
  friend class StorageManager;
  ~StorageContainer() noexcept;
  StorageContainer(const StorageContainer &) = delete;
  StorageContainer &operator=(const StorageContainer &) = delete;
  friend std::ostream &operator<<(std::ostream &os, const StorageContainer &s);
  /// \brief Open the backing file if it is not already open.
  Status Open() noexcept;
  /// \brief Close the backing file if open.
  Status Close() noexcept;
  /// \brief Allocate space for the slices and write them out back to back.
  /// \param[out] offset Starting offset of the data within the container.
  Status Insert(const std::vector<ReadableSlice> &buf, off64_t *offset) noexcept;
  /// \brief Write one slice at the given offset.
  Status Write(const ReadableSlice &dest, off64_t offset) const noexcept;
  /// \brief Read dest->GetSize() bytes starting at the given offset.
  Status Read(WritableSlice *dest, off64_t offset) const noexcept;
  /// \brief Truncate the backing file to zero length (if open).
  Status Truncate() const noexcept;
  bool IsOpen() const { return is_open_; }
  /// \brief Factory method: allocate a container and create its backing file.
  static Status CreateStorageContainer(std::shared_ptr<StorageContainer> *out_sc, const std::string &path);

 private:
  mutable std::mutex mutex_;
  Path cont_;                       // path of the backing file
  int fd_;                          // file descriptor; -1 when closed
  bool is_open_;
  std::unique_ptr<BuddySpace> bs_;  // space allocator within the file
  // Use the default value of BuddySpace
  // which can map upto 4G of space.
  explicit StorageContainer(const std::string &path) : cont_(path), fd_(-1), is_open_(false), bs_(nullptr) {}
  Status Create();
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_UTIL_STORAGE_CONTAINER_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/storage_manager.h"
#include <iomanip>
#include <sstream>
#include <stdexcept>
#include <utility>
#include "common/utils.h"
#include "dataset/util/path.h"
#include "dataset/util/services.h"
#include "dataset/util//de_error.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
// Compose the base file name, e.g. prefix "IMG" + id 7 -> "IMG00007" (zero-padded to width 5).
std::string StorageManager::GetBaseName(const std::string &prefix, int32_t file_id) {
  std::ostringstream formatter;
  formatter << prefix << std::setfill('0') << std::setw(5) << file_id;
  return formatter.str();
}
// Compose "<prefix><5-digit id>.<suffix>", e.g. "IMG00007.LB".
std::string StorageManager::ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix) {
  std::string base_name = GetBaseName(prefix, file_id);
  return (base_name + "." + suffix);
}
// Create a new spill container under root_ and append it to containers_.
// NOTE(review): callers appear to serialize access via rw_lock_ — confirm before
// calling from a new code path.
Status StorageManager::AddOneContainer() {
  // Container files are named IMG<5-digit id>.LB under the root directory.
  const std::string kPrefix = "IMG";
  const std::string kSuffix = "LB";
  Path container_name = root_ / ConstructFileName(kPrefix, file_id_, kSuffix);
  std::shared_ptr<StorageContainer> sc;
  RETURN_IF_NOT_OK(StorageContainer::CreateStorageContainer(&sc, container_name.toString()));
  containers_.push_back(sc);
  // Bump the id so the next container gets a fresh file name.
  file_id_++;
  return Status::OK();
}
// Service start-up: verify the spill root is a directory and create the first container.
Status StorageManager::DoServiceStart() {
  containers_.reserve(1000);
  if (!root_.IsDirectory()) {
    RETURN_STATUS_UNEXPECTED("Not a directory");
  }
  RETURN_IF_NOT_OK(AddOneContainer());
  return Status::OK();
}
// Spill the concatenation of the slices into the newest container and return a
// generated key through *key for later retrieval via Read(). When the newest
// container runs out of space, a new one is added (under the exclusive lock)
// and the insert is retried.
Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &buf) {
  RETURN_UNEXPECTED_IF_NULL(key);
  size_t sz = 0;
  for (auto &v : buf) {
    sz += v.GetSize();
  }
  if (sz == 0) {
    RETURN_STATUS_UNEXPECTED("Unexpected 0 length");
  }
  std::shared_ptr<StorageContainer> cont;
  key_type out_key;
  value_type out_value;
  bool create_new_container = false;
  do {
    SharedLock lock_s(&rw_lock_);
    size_t num_containers = containers_.size();
    if (create_new_container) {
      // Upgrade to exclusive lock.
      lock_s.Upgrade();
      create_new_container = false;
      // Check again if someone has already added a
      // new container after we got the x lock
      if (containers_.size() == num_containers) {
        RETURN_IF_NOT_OK(AddOneContainer());
      }
      // Refresh how many containers there are.
      num_containers = containers_.size();
      // Downgrade back to shared lock
      lock_s.Downgrade();
    }
    if (num_containers == 0) {
      RETURN_STATUS_UNEXPECTED("num_containers is zero");
    }
    // Go to the last container to insert.
    cont = containers_.at(num_containers - 1);
    off64_t offset;
    Status rc = cont->Insert(buf, &offset);
    if (rc.IsNoSpace()) {
      // Container is full: loop around and add a new one.
      create_new_container = true;
    } else if (rc.IsOk()) {
      // Record (container index, (offset, size)) in the index and hand back the key.
      out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz));
      RETURN_IF_NOT_OK(index_.insert(out_value, &out_key));
      *key = out_key;
      break;
    } else {
      return rc;
    }
  } while (true);
  return Status::OK();
}
// Fetch the buffer previously stored under key into *dest. Fails when the key is
// unknown or dest is smaller than the stored data.
Status StorageManager::Read(StorageManager::key_type key, WritableSlice *dest, size_t *bytesRead) const {
  RETURN_UNEXPECTED_IF_NULL(dest);
  auto r = index_.Search(key);
  if (r.second) {
    auto &it = r.first;
    // The index maps key -> (container index, (offset, size)).
    value_type v = *it;
    int container_inx = v.first;
    off_t offset = v.second.first;
    size_t sz = v.second.second;
    if (dest->GetSize() < sz) {
      std::string errMsg = "Destination buffer too small. Expect at least " + std::to_string(sz) +
                           " but length = " + std::to_string(dest->GetSize());
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
    if (bytesRead != nullptr) {
      *bytesRead = sz;
    }
    auto cont = containers_.at(container_inx);
    RETURN_IF_NOT_OK(cont->Read(dest, offset));
  } else {
    RETURN_STATUS_UNEXPECTED("Key not found");
  }
  return Status::OK();
}
// Service shutdown: truncate every container to release its disk space, then drop
// all references. All containers are processed even on error; the last error
// encountered (if any) is the one returned.
Status StorageManager::DoServiceStop() noexcept {
  Status rc;
  Status rc1;
  for (auto const &p : containers_) {
    // The destructor of StorageContainer is not called automatically until the use
    // count drops to 0. But it is not always the case. We will do it ourselves.
    rc = p.get()->Truncate();
    if (rc.IsError()) {
      rc1 = rc;
    }
  }
  containers_.clear();
  file_id_ = 0;
  return rc1;
}
StorageManager::StorageManager(const Path &root) : root_(root), file_id_(0), index_() {}
// Ensure containers are truncated even if the service was never formally stopped.
StorageManager::~StorageManager() { (void)StorageManager::DoServiceStop(); }
// Dump every container (path + buddy-space state) for diagnostics.
std::ostream &operator<<(std::ostream &os, const StorageManager &s) {
  os << "Dumping all containers ..."
     << "\n";
  for (auto const &p : s.containers_) {
    os << *(p.get());
  }
  return os;
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_STORAGE_MANAGER_H_
#define DATASET_UTIL_STORAGE_MANAGER_H_
#include <unistd.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/util/allocator.h"
#include "dataset/util/auto_index.h"
#include "dataset/util/lock.h"
#include "dataset/util/memory_pool.h"
#include "dataset/util/path.h"
#include "dataset/util/service.h"
#include "dataset/util/slice.h"
#include "dataset/util/storage_container.h"
using ListOfContainers = std::vector<std::shared_ptr<mindspore::dataset::StorageContainer>>;
namespace mindspore {
namespace dataset {
/// \brief Manages a growing set of StorageContainer spill files under a root
/// directory and an index that maps a generated key to the data's location.
class StorageManager : public Service {
 public:
  // index value: (container index, (offset within container, size in bytes))
  using storage_index = AutoIndexObj<std::pair<int, std::pair<off_t, size_t>>>;
  using key_type = storage_index::key_type;
  using value_type = storage_index::value_type;
  explicit StorageManager(const Path &);
  ~StorageManager() override;
  StorageManager(const StorageManager &) = delete;
  StorageManager &operator=(const StorageManager &) = delete;
  /// \brief Spill the slices into the newest container; *out_key receives the retrieval key.
  Status Write(key_type *out_key, const std::vector<ReadableSlice> &buf);
  /// \brief Read back a buffer previously written. dest must be at least as large as the data.
  /// \param[out] bytesRead Optional. Number of bytes of the stored data.
  Status Read(key_type key, WritableSlice *dest, size_t *bytesRead) const;
  Status DoServiceStart() override;
  Status DoServiceStop() noexcept override;
  friend std::ostream &operator<<(std::ostream &os, const StorageManager &s);

 private:
  Path root_;                    // directory in which containers are created
  ListOfContainers containers_;  // all containers; inserts go to the last one
  int file_id_;                  // numeric suffix for the next container file name
  RWLock rw_lock_;               // guards containers_
  storage_index index_;          // key -> location of the spilled data
  std::string GetBaseName(const std::string &prefix, int32_t file_id);
  std::string ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix);
  Status AddOneContainer();
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_UTIL_STORAGE_MANAGER_H_
...@@ -19,8 +19,10 @@ ...@@ -19,8 +19,10 @@
#include <cstddef> #include <cstddef>
#include <cstdlib> #include <cstdlib>
#include <limits> #include <limits>
#include <memory>
#include <new> #include <new>
#include "./securec.h" #include "./securec.h"
#include "dataset/util/allocator.h"
#include "dataset/util/memory_pool.h" #include "dataset/util/memory_pool.h"
namespace mindspore { namespace mindspore {
...@@ -61,6 +63,11 @@ class SystemPool : public MemoryPool { ...@@ -61,6 +63,11 @@ class SystemPool : public MemoryPool {
uint64_t get_max_size() const override { return std::numeric_limits<uint64_t>::max(); } uint64_t get_max_size() const override { return std::numeric_limits<uint64_t>::max(); }
int PercentFree() const override { return 100; } int PercentFree() const override { return 100; }
  /// \brief Convenience factory: an Allocator<T> backed by a freshly created SystemPool.
  template <typename T>
  static Allocator<T> GetAllocator() {
    return Allocator<T>(std::make_shared<SystemPool>());
  }
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "kernel/common_utils.h" #include "kernel/common_utils.h"
#include "kernel/oplib/oplib.h" #include "kernel/oplib/oplib.h"
#include "ir/value.h" #include "ir/value.h"
#include "pre_activate/common/helper.h"
using mindspore::kernel::Address; using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr; using mindspore::kernel::AddressPtr;
...@@ -632,7 +633,7 @@ void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) { ...@@ -632,7 +633,7 @@ void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
} }
} }
void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const mindspore::AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces, AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
AddressPtrList *kernel_outputs) { AddressPtrList *kernel_outputs) {
MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(kernel);
...@@ -644,9 +645,15 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod ...@@ -644,9 +645,15 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) { if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
return GenAddrCleanLaunchArgs(cnode, kernel_inputs); return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
} }
auto is_all_nop_node = opt::IsAllNopNode(&graph);
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
auto real_input = AnfAlgo::GetRealInputIndex(kernel, i); auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input); DeviceAddressPtr device_address;
if (is_all_nop_node) {
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, false);
} else {
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, true);
}
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr input = std::make_shared<kernel::Address>(); kernel::AddressPtr input = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
...@@ -656,8 +663,16 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod ...@@ -656,8 +663,16 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
kernel_inputs->emplace_back(input); kernel_inputs->emplace_back(input);
} }
for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) { auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
auto device_address = AnfAlgo::GetOutputAddr(kernel, i); MS_EXCEPTION_IF_NULL(kernel_mod);
for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
DeviceAddressPtr device_address;
if (is_all_nop_node) {
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
} else {
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, true);
}
MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr output = std::make_shared<kernel::Address>(); kernel::AddressPtr output = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(output); MS_EXCEPTION_IF_NULL(output);
output->addr = device_address->ptr_; output->addr = device_address->ptr_;
...@@ -666,7 +681,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod ...@@ -666,7 +681,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
kernel_outputs->emplace_back(output); kernel_outputs->emplace_back(output);
} }
for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) { for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(workspace); MS_EXCEPTION_IF_NULL(workspace);
...@@ -721,7 +736,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { ...@@ -721,7 +736,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
AddressPtrList kernel_inputs; AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces; AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs; AddressPtrList kernel_outputs;
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); GenLaunchArgs(graph, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret) { if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed."; MS_LOG(ERROR) << "Launch kernel failed.";
......
...@@ -96,8 +96,8 @@ class KernelRuntime { ...@@ -96,8 +96,8 @@ class KernelRuntime {
private: private:
void AssignStaticMemoryOutput(const session::KernelGraph *graph); void AssignStaticMemoryOutput(const session::KernelGraph *graph);
void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, void GenLaunchArgs(const session::KernelGraph &graph, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs);
bool LaunchKernelMod(const session::KernelGraph &graph); bool LaunchKernelMod(const session::KernelGraph &graph);
void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs);
size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index); size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index);
......
...@@ -17,13 +17,23 @@ ...@@ -17,13 +17,23 @@
#ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ #ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ #define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#include <memory>
#include "ir/anf.h" #include "ir/anf.h"
#include "optimizer/opt.h"
namespace mindspore { namespace mindspore {
namespace opt {
class Optimizer;
using OptimizerPtr = std::shared_ptr<Optimizer>;
using OptimizerWeakPtr = std::weak_ptr<Optimizer>;
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
} // namespace opt
class OptimizerCaller { class OptimizerCaller {
public: public:
virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; } virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; }
}; };
using OptimizerCallerPtr = std::shared_ptr<OptimizerCaller>;
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ #endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" #include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
#include "kernel/akg/akg_kernel_metadata.h" #include "kernel/akg/akg_kernel_metadata.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "utils/context/ms_context.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
...@@ -97,6 +98,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel ...@@ -97,6 +98,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
std::string op_name = AnfAlgo::GetCNodeName(kernel_node); std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) {
kernel_type = KernelType::AKG_KERNEL;
}
switch (kernel_type) { switch (kernel_type) {
case KernelType::AKG_KERNEL: case KernelType::AKG_KERNEL:
AkgMetadataInfo(kernel_node, kernel_info_list); AkgMetadataInfo(kernel_node, kernel_info_list);
......
...@@ -89,15 +89,28 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { ...@@ -89,15 +89,28 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
return changed; return changed;
} }
// The op like print, summary, or the op do not has true output, and always as a depend node input.
static bool HasSideEffect(const AnfNodePtr &node) {
auto prim = GetCNodePrimitive(node);
if (prim == nullptr) {
return false;
}
auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT);
if (side_effect_v != nullptr && side_effect_v->isa<BoolImm>()) {
return GetValue<bool>(side_effect_v);
}
return false;
}
// If true do not merge the node.
bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const { bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
bool has_random_effect = false; bool has_random_effect = false;
auto prim_main = GetCNodePrimitive(main); auto prim_main = GetCNodePrimitive(main);
auto prim_node = GetCNodePrimitive(node); auto prim_node = GetCNodePrimitive(node);
if (prim_main == prim_node) { // if has random effect, when generate by different op (not same object), do not merge.
return false;
}
if (prim_main != nullptr) { if (prim_main != nullptr) {
if (prim_main == prim_node) {
return false;
}
auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT); auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
if (effect_val != nullptr && effect_val->isa<BoolImm>()) { if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
has_random_effect = GetValue<bool>(effect_val); has_random_effect = GetValue<bool>(effect_val);
...@@ -106,45 +119,58 @@ bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) cons ...@@ -106,45 +119,58 @@ bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) cons
return has_random_effect; return has_random_effect;
} }
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const {
MS_EXCEPTION_IF_NULL(main); MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
bool replace = false;
if (main->isa<ValueNode>() && node->isa<ValueNode>()) { if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
auto main_value = GetValueNode(main); auto main_value = GetValueNode(main);
auto node_value = GetValueNode(node); auto node_value = GetValueNode(node);
replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
} else if (main->isa<CNode>() && node->isa<CNode>()) { } else if (main->isa<CNode>() && node->isa<CNode>()) {
auto c_main = main->cast<CNodePtr>(); auto c_main = main->cast<CNodePtr>();
auto c_node = node->cast<CNodePtr>(); auto c_node = node->cast<CNodePtr>();
// When appsame is true, check if has side effect, do not merge.
if (check_side_effect && HasSideEffect(main)) {
return false;
}
const auto &inp1 = c_main->inputs(); const auto &inp1 = c_main->inputs();
const auto &inp2 = c_node->inputs(); const auto &inp2 = c_node->inputs();
if (inp1.size() == inp2.size()) { if (inp1.size() != inp2.size()) {
bool appsame = true; return false;
for (size_t j = 0; j < inp1.size(); j++) { }
MS_EXCEPTION_IF_NULL(inp1[j]); for (size_t j = 0; j < inp1.size(); j++) {
MS_EXCEPTION_IF_NULL(inp2[j]); auto inp1_j = inp1[j];
if (!(*inp1[j] == *inp2[j])) { auto inp2_j = inp2[j];
// Handle the case of two different Tensor, but with the same value MS_EXCEPTION_IF_NULL(inp1_j);
if (IsValueNode<tensor::Tensor>(inp1[j]) && IsValueNode<tensor::Tensor>(inp2[j])) { MS_EXCEPTION_IF_NULL(inp2_j);
auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1[j]); if (!(*inp1_j == *inp2_j)) {
auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2[j]); // Handle the case of two different Tensor, but with the same value
if (tensor1->ValueEqual(*tensor2)) { if (IsValueNode<tensor::Tensor>(inp1_j) && IsValueNode<tensor::Tensor>(inp2_j)) {
continue; auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1_j);
} auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2_j);
if (tensor1->ValueEqual(*tensor2)) {
continue;
}
} else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) {
// When the same side effect node as another two nodes' inputs, we still merge the node.
// Because the node only can be the inputs of `depend`, when the `depend` is duplicated merge the depend the
// node.
if (CheckReplace(inp1_j, inp2_j, false)) {
continue;
} }
appsame = false;
break;
} }
return false;
} }
if (CheckRandomEffect(c_main, c_node)) {
appsame = false;
}
replace = appsame;
} }
// When appsame is true, check if has random effect do not merge
if (CheckRandomEffect(c_main, c_node)) {
return false;
}
return true;
} }
return replace; // a parameter node.
return false;
} }
bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group, bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,
......
...@@ -41,7 +41,7 @@ class CSE { ...@@ -41,7 +41,7 @@ class CSE {
return chg && report_changes_; return chg && report_changes_;
} }
virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const; virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const;
virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const; virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const;
......
...@@ -14,140 +14,154 @@ ...@@ -14,140 +14,154 @@
* limitations under the License. * limitations under the License.
*/ */
#include "optimizer/irpass.h"
#include <string> #include <string>
#include "optimizer/irpass/symbol_resolver.h" #include "optimizer/irpass.h"
#include "optimizer/irpass/arithmetic_simplify.h" #include "optimizer/irpass/arithmetic_simplify.h"
#include "optimizer/irpass/special_op_eliminate.h"
#include "optimizer/irpass/item_tuple_eliminate.h"
#include "optimizer/irpass/env_item_eliminate.h"
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/cast_eliminate.h"
#include "optimizer/irpass/reshape_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/irpass/reduce_eliminate.h"
#include "optimizer/irpass/partial_eliminate.h"
#include "optimizer/irpass/ref_eliminate.h"
#include "optimizer/irpass/merge_addn.h"
#include "optimizer/irpass/branch_culling.h" #include "optimizer/irpass/branch_culling.h"
#include "optimizer/irpass/cast_eliminate.h"
#include "optimizer/irpass/convert.h"
#include "optimizer/irpass/env_item_eliminate.h"
#include "optimizer/irpass/grad_var_prepare.h"
#include "optimizer/irpass/gradient_eliminate.h" #include "optimizer/irpass/gradient_eliminate.h"
#include "optimizer/irpass/minmax_grad.h"
#include "optimizer/irpass/inline.h" #include "optimizer/irpass/inline.h"
#include "optimizer/irpass/convert.h"
#include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/incorporate_getitem.h"
#include "optimizer/irpass/incorporate_call.h" #include "optimizer/irpass/incorporate_call.h"
#include "optimizer/irpass/grad_var_prepare.h" #include "optimizer/irpass/incorporate_getitem.h"
#include "optimizer/irpass/param_replace.h" #include "optimizer/irpass/item_tuple_eliminate.h"
#include "optimizer/irpass/mark_interface_fusion.h" #include "optimizer/irpass/mark_interface_fusion.h"
#include "optimizer/irpass/merge_addn.h"
#include "optimizer/irpass/minmax_grad.h"
#include "optimizer/irpass/param_replace.h"
#include "optimizer/irpass/partial_eliminate.h"
#include "optimizer/irpass/reduce_eliminate.h"
#include "optimizer/irpass/ref_eliminate.h"
#include "optimizer/irpass/reshape_eliminate.h"
#include "optimizer/irpass/special_op_eliminate.h"
#include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/symbol_resolver.h"
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/opt.h" #include "optimizer/opt.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
OptimizeIRPassLib::OptimizeIRPassLib() { OptimizeIRPassLib::OptimizeIRPassLib() {
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify",
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul}); arithmetic_simplify2_ =
MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul});
special_op_eliminate_ = special_op_eliminate_ =
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike); zero_like_fill_zero_ =
adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN); MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);
adjust_all_reduce_mul_add_ =
MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
// ops eliminate // ops eliminate
item_tuple_eliminate_ = item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate",
MakeSubstitution(ItemTupleEliminater(), "item_tuple_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem});
tile_eliminate_ = MakeSubstitution(TileMultiplyByOne(), "tile_eliminate", prim::kPrimTile); tile_eliminate_ = MakeSubstitution(std::make_shared<TileMultiplyByOne>(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(CastEliminater(), "cast_eliminate", prim::kPrimCast); cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
reshape_eliminate_ = MakeSubstitution(ReshapeEliminater(), "reshape_eliminate", prim::kPrimReshape); reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);
transpose_eliminate_ = MakeSubstitution(TransposeSameIOEliminater(), "transpose_eliminate", prim::kPrimTranspose); transpose_eliminate_ =
MakeSubstitution(std::make_shared<TransposeSameIOEliminater>(), "transpose_eliminate", prim::kPrimTranspose);
reduce_eliminate_ = MakeSubstitution( reduce_eliminate_ = MakeSubstitution(
ReduceOneEliminater(), "reduce_eliminate", std::make_shared<ReduceOneEliminater>(), "reduce_eliminate",
{prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup); partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape); same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop); check_bprop_eliminate_ =
reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>); MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend); reset_defer_inline_ =
MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);
// Env Item Eliminate // Env Item Eliminate
env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem); env_get_item_eliminate_ =
new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem); MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_ = incorporate_env_getitem_ =
MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem); MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ = incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
MakeSubstitution(IncorporateEnvGetitemSwitch(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); "incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
// Ref eliminate // Ref eliminate
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); make_ref_eliminate_ =
get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate", MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate",
{prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate",
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
replace_refkey_by_param_ = replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM); IsValueNode<RefKey>, opt::FORCE_RENORM);
replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
// Gradient transforms // Gradient transforms
expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ);
minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem); minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);
// branch culling // branch culling
switch_simplify_ = MakeSubstitution(SwitchSimplify(), "switch_simplify", prim::kPrimSwitch); switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch);
float_tuple_getitem_switch_ = float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(),
MakeSubstitution(FloatTupleGetItemSwitch(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem); "float_tuple_getitem_switch", prim::kPrimTupleGetItem);
float_env_getitem_switch_ = float_env_getitem_switch_ =
MakeSubstitution(FloatEnvGetItemSwitch(), "float_env_getitem_switch", prim::kPrimEnvGetItem); MakeSubstitution(std::make_shared<FloatEnvGetItemSwitch>(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
convert_switch_replacement_ = MakeSubstitution(ConvertSwitchReplacement(), "convert_switch_replacement", IsCNodeDup); convert_switch_replacement_ =
MakeSubstitution(std::make_shared<ConvertSwitchReplacement>(), "convert_switch_replacement", IsCNodeDup);
// Addn // Addn
merge_addn_ = MakeSubstitution(MergeAddN(), "merge_addn", prim::kPrimAddN); merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
addn_zero_filter_ = MakeSubstitution(AddNZeroFilter(), "addn_zero_filter", prim::kPrimAddN); addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);
// inline // inline
inline_ = MakeSubstitution(Inliner(), "inline", IsCNodeGraph); inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
replace_applicator_ = MakeSubstitution(ReplaceApplicator(), "replace_applicator", IsValueNode<FuncGraph>); replace_applicator_ =
specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph); MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
specialize_transform_ =
MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph);
// Incorporation // Incorporation
incorporate_getitem_set_ = incorporate_getitem_set_ =
MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem); MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
incorporate_getitem_from_param_ = incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared<IncorporateGetitemFromParam>(),
MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel); "incorporate_getitem_from_param", IsCNodeGraphKernel);
incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup); incorporate_call_switch_ =
MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);
// Virtual Dataset // Virtual Dataset
virtual_dataset_eliminate_ = virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
MakeSubstitution(VirtualDatasetEliminater(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset); "virtual_dataset_eliminate", prim::kPrimVirtualDataset);
// Convert // Convert
print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint); print_tuple_wrapper_ =
MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);
// Unused parameter eliminate // Unused parameter eliminate
unused_parameter_eliminate_ = unused_parameter_eliminate_ =
MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel); MakeSubstitution(std::make_shared<UnusedParasEliminater>(), "unused_parameter_eliminate", IsCNodeGraphKernel);
unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel); unused_output_eliminate_ =
MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel);
// AddN eliminate // AddN eliminate
addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel); addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);
// Mark interface fusion // Mark interface fusion
mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect); mark_interface_fusion_ =
MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
} }
ResolveIRPassLib::ResolveIRPassLib() { ResolveIRPassLib::ResolveIRPassLib() {
resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve); resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr); resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr);
} }
InferenceOptPrepareLib::InferenceOptPrepareLib() { InferenceOptPrepareLib::InferenceOptPrepareLib() {
grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode); grad_var_prepare_ = MakeSubstitution(std::make_shared<GradVarPrepare>(), "grad_var_prepare", IsCNode);
} }
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
......
...@@ -17,15 +17,16 @@ ...@@ -17,15 +17,16 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#include <vector>
#include <memory>
#include <algorithm> #include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/optimizer.h" #include "ir/optimizer_caller.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "ir/visitor.h" #include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
...@@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor { ...@@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
FuncGraphPtr all_reduce_fg_{nullptr}; FuncGraphPtr all_reduce_fg_{nullptr};
}; };
class ArithmeticSimplify { class ArithmeticSimplify : public OptimizerCaller {
public: public:
ArithmeticSimplify() ArithmeticSimplify()
: multiply_by_zero_or_one_(), : multiply_by_zero_or_one_(std::make_shared<MultiplyByZeroOrOne>()),
tensor_multiply_by_one_(), tensor_multiply_by_one_(std::make_shared<TensorMultiplyByOne>()),
add_by_zero_(), add_by_zero_(std::make_shared<AddByZero>()),
tensor_add_by_zero_(), tensor_add_by_zero_(std::make_shared<TensorAddByZero>()),
identity_(prim::kPrimIdentity), identity_(std::make_shared<PrimEliminater>(prim::kPrimIdentity)),
opt_update_zero_tensor_(), opt_update_zero_tensor_(std::make_shared<OptUpdateZeroTensor>()),
constant_duplicate_mul_(), constant_duplicate_mul_(std::make_shared<ConstantDuplicateMul>()),
power_one_() { power_one_(std::make_shared<PowerOneEliminate>()) {
eliminaters_.emplace_back(multiply_by_zero_or_one_); eliminaters_.emplace_back(multiply_by_zero_or_one_);
eliminaters_.emplace_back(tensor_multiply_by_one_); eliminaters_.emplace_back(tensor_multiply_by_one_);
eliminaters_.emplace_back(add_by_zero_); eliminaters_.emplace_back(add_by_zero_);
...@@ -761,10 +762,10 @@ class ArithmeticSimplify { ...@@ -761,10 +762,10 @@ class ArithmeticSimplify {
} }
~ArithmeticSimplify() = default; ~ArithmeticSimplify() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node; AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) { for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node); new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
} }
...@@ -773,15 +774,9 @@ class ArithmeticSimplify { ...@@ -773,15 +774,9 @@ class ArithmeticSimplify {
} }
private: private:
MultiplyByZeroOrOne multiply_by_zero_or_one_; OptimizerCallerPtr multiply_by_zero_or_one_, tensor_multiply_by_one_, add_by_zero_, tensor_add_by_zero_, identity_,
TensorMultiplyByOne tensor_multiply_by_one_; opt_update_zero_tensor_, constant_duplicate_mul_, power_one_;
AddByZero add_by_zero_; std::vector<OptimizerCallerPtr> eliminaters_{};
TensorAddByZero tensor_add_by_zero_;
PrimEliminater identity_;
OptUpdateZeroTensor opt_update_zero_tensor_;
ConstantDuplicateMul constant_duplicate_mul_;
PowerOneEliminate power_one_;
std::vector<TransformFuncType> eliminaters_{};
}; };
// Arithmetic Simplifications should be done after step_parallel. // Arithmetic Simplifications should be done after step_parallel.
...@@ -789,15 +784,17 @@ class ArithmeticSimplify { ...@@ -789,15 +784,17 @@ class ArithmeticSimplify {
// with shape(weight), but after step_parallel, shape of weight may be changed, so the // with shape(weight), but after step_parallel, shape of weight may be changed, so the
// shape of the constant tensor should also be changed. So this pass is seperated from // shape of the constant tensor should also be changed. So this pass is seperated from
// ArithmeticSimplify and deferred until step_parallel. // ArithmeticSimplify and deferred until step_parallel.
class ArithmeticSimplify2 { class ArithmeticSimplify2 : public OptimizerCaller {
public: public:
ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); } ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) {
eliminaters_.emplace_back(tensor_multiply_by_zero_);
}
~ArithmeticSimplify2() = default; ~ArithmeticSimplify2() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node; AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) { for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node); new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
} }
...@@ -806,8 +803,8 @@ class ArithmeticSimplify2 { ...@@ -806,8 +803,8 @@ class ArithmeticSimplify2 {
} }
private: private:
TensorMultiplyByZero tensor_multiply_by_zero_; OptimizerCallerPtr tensor_multiply_by_zero_;
std::vector<TransformFuncType> eliminaters_{}; std::vector<OptimizerCallerPtr> eliminaters_{};
}; };
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
......
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#include "ir/visitor.h"
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "optimizer/optimizer.h" #include "optimizer/optimizer.h"
#include "ir/visitor.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
...@@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor { ...@@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor {
AnfNodePtr x_{nullptr}, t_{nullptr}; AnfNodePtr x_{nullptr}, t_{nullptr};
}; };
class CastEliminater { class CastEliminater : public OptimizerCaller {
public: public:
CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {} CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {}
~CastEliminater() = default; ~CastEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
auto new_node = cast_same_type_eliminater_(optimizer, node); auto new_node = cast_same_type_eliminater_(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
......
...@@ -17,18 +17,19 @@ ...@@ -17,18 +17,19 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#include <vector>
#include <utility>
#include <algorithm> #include <algorithm>
#include <unordered_map>
#include <memory> #include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "utils/symbolic.h" #include "utils/symbolic.h"
namespace mindspore { namespace mindspore {
...@@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor { ...@@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor {
bool is_match_{false}; bool is_match_{false};
}; };
class EnvGetItemEliminater { class EnvGetItemEliminater : public OptimizerCaller {
public: public:
EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), env_get_set_item_() { EnvGetItemEliminater()
: new_env_get_item_(std::make_shared<NewEnvGetItem>()),
add_env_get_item_(std::make_shared<AddEnvGetItem>()),
env_get_set_item_(std::make_shared<EnvGetSetItem>()) {
eliminaters_.emplace_back(new_env_get_item_); eliminaters_.emplace_back(new_env_get_item_);
eliminaters_.emplace_back(add_env_get_item_); eliminaters_.emplace_back(add_env_get_item_);
eliminaters_.emplace_back(env_get_set_item_); eliminaters_.emplace_back(env_get_set_item_);
} }
~EnvGetItemEliminater() = default; ~EnvGetItemEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node; AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) { for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node); new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
} }
...@@ -246,10 +250,8 @@ class EnvGetItemEliminater { ...@@ -246,10 +250,8 @@ class EnvGetItemEliminater {
} }
private: private:
NewEnvGetItem new_env_get_item_; OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_;
AddEnvGetItem add_env_get_item_; std::vector<OptimizerCallerPtr> eliminaters_{};
EnvGetSetItem env_get_set_item_;
std::vector<TransformFuncType> eliminaters_{};
}; };
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y} // {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
......
...@@ -17,18 +17,20 @@ ...@@ -17,18 +17,20 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#include <vector>
#include <algorithm> #include <algorithm>
#include <unordered_map>
#include <memory> #include <memory>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
...@@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor { ...@@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor {
internal::GetitemTransform getitem_transform_; internal::GetitemTransform getitem_transform_;
}; };
class IncorporateGetitemSet { class IncorporateGetitemSet : public OptimizerCaller {
public: public:
IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() { IncorporateGetitemSet()
: incorporate_getitem_(std::make_shared<IncorporateGetitem>()),
incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()) {
eliminaters_.emplace_back(incorporate_getitem_); eliminaters_.emplace_back(incorporate_getitem_);
eliminaters_.emplace_back(incorporate_getitem_switch_); eliminaters_.emplace_back(incorporate_getitem_switch_);
} }
~IncorporateGetitemSet() = default; ~IncorporateGetitemSet() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node; AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) { for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node); new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
} }
...@@ -403,9 +407,8 @@ class IncorporateGetitemSet { ...@@ -403,9 +407,8 @@ class IncorporateGetitemSet {
} }
private: private:
IncorporateGetitem incorporate_getitem_; OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_;
IncorporateGetitemSwitch incorporate_getitem_switch_; std::vector<OptimizerCallerPtr> eliminaters_{};
std::vector<TransformFuncType> eliminaters_{};
}; };
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
......
...@@ -17,13 +17,15 @@ ...@@ -17,13 +17,15 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#include <vector>
#include <algorithm> #include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/irpass.h" #include "ir/optimizer_caller.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h" #include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
...@@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor { ...@@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor {
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
}; };
class ItemTupleEliminater { class ItemTupleEliminater : public OptimizerCaller {
public: public:
ItemTupleEliminater() ItemTupleEliminater()
: get_item_eliminater_(), : get_item_eliminater_(std::make_shared<GetitemEliminater>()),
get_item_const_eliminater_(), get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()),
set_item_eliminater_(), set_item_eliminater_(std::make_shared<SetitemEliminater>()),
get_set_item_eliminater_(), get_set_item_eliminater_(std::make_shared<GetSetitemEliminater>()),
get_item_depend_reorder_() { get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()) {
eliminaters_.emplace_back(get_item_eliminater_); eliminaters_.emplace_back(get_item_eliminater_);
eliminaters_.emplace_back(get_item_const_eliminater_); eliminaters_.emplace_back(get_item_const_eliminater_);
eliminaters_.emplace_back(set_item_eliminater_); eliminaters_.emplace_back(set_item_eliminater_);
...@@ -277,10 +279,10 @@ class ItemTupleEliminater { ...@@ -277,10 +279,10 @@ class ItemTupleEliminater {
} }
~ItemTupleEliminater() = default; ~ItemTupleEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node; AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) { for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node); new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
} }
...@@ -289,12 +291,9 @@ class ItemTupleEliminater { ...@@ -289,12 +291,9 @@ class ItemTupleEliminater {
} }
private: private:
GetitemEliminater get_item_eliminater_; OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_,
GetitemConstEliminater get_item_const_eliminater_; get_item_depend_reorder_;
SetitemEliminater set_item_eliminater_; std::vector<OptimizerCallerPtr> eliminaters_{};
GetSetitemEliminater get_set_item_eliminater_;
GetitemDependReorder get_item_depend_reorder_;
std::vector<TransformFuncType> eliminaters_{};
}; };
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
......
...@@ -19,9 +19,9 @@ ...@@ -19,9 +19,9 @@
#include <memory> #include <memory>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/pattern_matcher.h" #include "ir/pattern_matcher.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
......
...@@ -19,11 +19,12 @@ ...@@ -19,11 +19,12 @@
#include <vector> #include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "pipeline/static_analysis/dshape.h" #include "pipeline/static_analysis/dshape.h"
namespace mindspore { namespace mindspore {
...@@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor { ...@@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor {
AnfNodePtr x_{nullptr}, shape_{nullptr}; AnfNodePtr x_{nullptr}, shape_{nullptr};
}; };
class ReshapeEliminater { class ReshapeEliminater : public OptimizerCaller {
public: public:
ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {} ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {}
~ReshapeEliminater() = default; ~ReshapeEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
auto new_node = reshape_same_shape_eliminater_(optimizer, node); auto new_node = reshape_same_shape_eliminater_(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
......
...@@ -18,31 +18,31 @@ ...@@ -18,31 +18,31 @@
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_
#include <securec.h> #include <securec.h>
#include <vector>
#include <memory>
#include <algorithm> #include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/optimizer_caller.h" #include "ir/optimizer_caller.h"
#include "optimizer/irpass/prim_eliminate.h" #include "ir/pattern_matcher.h"
#include "ir/visitor.h" #include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "ir/pattern_matcher.h" #include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
class SpecialOpEliminater { class SpecialOpEliminater : public OptimizerCaller {
public: public:
SpecialOpEliminater() SpecialOpEliminater()
: insert_gradient_of_(prim::kPrimInsertGradientOf), : insert_gradient_of_(std::make_shared<PrimEliminater>(prim::kPrimInsertGradientOf)),
stop_gradient_(prim::kPrimStopGradient), stop_gradient_(std::make_shared<PrimEliminater>(prim::kPrimStopGradient)),
hook_backward_(prim::kPrimHookBackward), hook_backward_(std::make_shared<PrimEliminater>(prim::kPrimHookBackward)),
print_shape_type_(prim::kPrimPrintShapeType), print_shape_type_(std::make_shared<PrimEliminater>(prim::kPrimPrintShapeType)),
get_ref_value_(prim::kPrimGetRefValue), get_ref_value_(std::make_shared<PrimEliminater>(prim::kPrimGetRefValue)),
mirror_(prim::kPrimMirror), mirror_(std::make_shared<PrimEliminater>(prim::kPrimMirror)),
virtual_div_(prim::kPrimVirtualDiv) { virtual_div_(std::make_shared<PrimEliminater>(prim::kPrimVirtualDiv)) {
eliminaters_.emplace_back(insert_gradient_of_); eliminaters_.emplace_back(insert_gradient_of_);
eliminaters_.emplace_back(stop_gradient_); eliminaters_.emplace_back(stop_gradient_);
eliminaters_.emplace_back(hook_backward_); eliminaters_.emplace_back(hook_backward_);
...@@ -53,10 +53,10 @@ class SpecialOpEliminater { ...@@ -53,10 +53,10 @@ class SpecialOpEliminater {
} }
~SpecialOpEliminater() = default; ~SpecialOpEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node; AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) { for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node); new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
} }
...@@ -65,9 +65,9 @@ class SpecialOpEliminater { ...@@ -65,9 +65,9 @@ class SpecialOpEliminater {
} }
private: private:
PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_,
virtual_div_; virtual_div_;
std::vector<TransformFuncType> eliminaters_{}; std::vector<OptimizerCallerPtr> eliminaters_{};
}; };
// {PrimVirtualDataset, X} -> X // {PrimVirtualDataset, X} -> X
......
...@@ -16,28 +16,27 @@ ...@@ -16,28 +16,27 @@
#include "optimizer/opt.h" #include "optimizer/opt.h"
#include <algorithm>
#include <deque>
#include <memory> #include <memory>
#include <unordered_set> #include <unordered_set>
#include <deque>
#include <algorithm>
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/manager.h" #include "ir/manager.h"
#include "utils/ordered_set.h"
#include "utils/log_adapter.h"
#include "optimizer/optimizer.h" #include "optimizer/optimizer.h"
#include "utils/log_adapter.h"
#include "utils/ordered_set.h"
namespace mindspore { namespace mindspore {
/* namespace to support opt */ /* namespace to support opt */
namespace opt { namespace opt {
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &renorm_action) { const RenormAction &renorm_action) {
auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
return std::make_shared<Substitution>(transform, name, fn, renorm_action); return std::make_shared<Substitution>(transform, name, fn, renorm_action);
} }
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) { const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) {
auto fn = [prims](const AnfNodePtr &node) -> bool { auto fn = [prims](const AnfNodePtr &node) -> bool {
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
...@@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std:: ...@@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::
return std::make_shared<Substitution>(transform, name, fn, renorm_action); return std::make_shared<Substitution>(transform, name, fn, renorm_action);
} }
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &renorm_action) { const PredicateFuncType &predicate, const RenormAction &renorm_action) {
return std::make_shared<Substitution>(transform, name, predicate, renorm_action); return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
} }
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const { AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
#ifdef ENABLE_PROFILE #ifdef ENABLE_PROFILE
double t = GetTime(); double t = GetTime();
#endif #endif
AnfNodePtr result = transform_(optimizer, node); AnfNodePtr result = (*transform_)(optimizer, node);
#ifdef ENABLE_PROFILE #ifdef ENABLE_PROFILE
if (optimizer != nullptr) { if (optimizer != nullptr) {
auto time = GetTime(); auto time = GetTime();
......
...@@ -17,24 +17,18 @@ ...@@ -17,24 +17,18 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ #ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ #define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#include <vector>
#include <string>
#include <memory> #include <memory>
#include <string>
#include <vector>
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "operator/ops.h" #include "operator/ops.h"
namespace mindspore { namespace mindspore {
/* namespace to support opt */ /* namespace to support opt */
namespace opt { namespace opt {
class Optimizer;
using OptimizerPtr = std::shared_ptr<Optimizer>;
using OptimizerWeakPtr = std::weak_ptr<Optimizer>;
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>;
// Define the interaction mode between an Optimize pass and Renormalize pass // Define the interaction mode between an Optimize pass and Renormalize pass
// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed // FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed
...@@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM }; ...@@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM };
class Substitution { class Substitution {
public: public:
TransformFuncType transform_{nullptr}; OptimizerCallerPtr transform_;
std::string name_; std::string name_;
PredicateFuncType predicate_{nullptr}; PredicateFuncType predicate_{nullptr};
// an enum to mark this Substitution relation to renormalize pass // an enum to mark this Substitution relation to renormalize pass
RenormAction renorm_action_; RenormAction renorm_action_;
Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate,
const RenormAction &renorm_action) const RenormAction &renorm_action)
: transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
~Substitution() = default; ~Substitution() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node);
}; };
using SubstitutionPtr = std::shared_ptr<Substitution>; using SubstitutionPtr = std::shared_ptr<Substitution>;
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &action_renorm = CHECK_RENORM); const RenormAction &action_renorm = CHECK_RENORM);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims, const std::vector<PrimitivePtr> &prims,
const RenormAction &action_renorm = CHECK_RENORM); const RenormAction &action_renorm = CHECK_RENORM);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
class SubstitutionList { class SubstitutionList {
......
...@@ -465,7 +465,7 @@ double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, co ...@@ -465,7 +465,7 @@ double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, co
CheckGlobalDeviceManager(); CheckGlobalDeviceManager();
MS_EXCEPTION_IF_NULL(g_device_manager); MS_EXCEPTION_IF_NULL(g_device_manager);
RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
TensorRedistribution tensor_redistribution; TensorRedistribution tensor_redistribution(false, true);
if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) {
MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
} }
...@@ -503,7 +503,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inp ...@@ -503,7 +503,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inp
CheckGlobalDeviceManager(); CheckGlobalDeviceManager();
MS_EXCEPTION_IF_NULL(g_device_manager); MS_EXCEPTION_IF_NULL(g_device_manager);
RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
TensorRedistribution tensor_redistribution; TensorRedistribution tensor_redistribution(false, true);
if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) {
MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
} }
......
...@@ -62,6 +62,7 @@ void ParallelContext::Reset() { ...@@ -62,6 +62,7 @@ void ParallelContext::Reset() {
enable_all_reduce_fusion_ = false; enable_all_reduce_fusion_ = false;
strategy_ckpt_load_file_ = ""; strategy_ckpt_load_file_ = "";
strategy_ckpt_save_file_ = ""; strategy_ckpt_save_file_ = "";
enable_parallel_optimizer_ = false;
} }
void ParallelContext::set_device_num(int32_t device_num) { void ParallelContext::set_device_num(int32_t device_num) {
......
...@@ -100,6 +100,11 @@ class ParallelContext { ...@@ -100,6 +100,11 @@ class ParallelContext {
void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file); void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; } std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
enable_parallel_optimizer_ = enable_parallel_optimizer;
}
bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
void Reset(); void Reset();
private: private:
...@@ -123,6 +128,7 @@ class ParallelContext { ...@@ -123,6 +128,7 @@ class ParallelContext {
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_; std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
std::string strategy_ckpt_load_file_; std::string strategy_ckpt_load_file_;
std::string strategy_ckpt_save_file_; std::string strategy_ckpt_save_file_;
bool enable_parallel_optimizer_;
}; };
void ParallelParameterContextInit(const FuncGraphPtr &func_graph); void ParallelParameterContextInit(const FuncGraphPtr &func_graph);
......
...@@ -205,6 +205,10 @@ PYBIND11_MODULE(_c_expression, m) { ...@@ -205,6 +205,10 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.") .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
.def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.") .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
.def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.") .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
.def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
"Set enable/disable parallel optimizer.")
.def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
"Get enable/disable parallel optimizer.")
.def("reset", &ParallelContext::Reset, "Reset auto parallel context."); .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext") (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
......
...@@ -35,7 +35,7 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) { ...@@ -35,7 +35,7 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) {
} }
} // namespace } // namespace
bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const {
MS_EXCEPTION_IF_NULL(main); MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
......
...@@ -31,7 +31,7 @@ class BackendCSE : public CSE { ...@@ -31,7 +31,7 @@ class BackendCSE : public CSE {
public: public:
BackendCSE() = default; BackendCSE() = default;
~BackendCSE() override = default; ~BackendCSE() override = default;
bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const override; bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
......
...@@ -33,5 +33,6 @@ const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll"; ...@@ -33,5 +33,6 @@ const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect"; const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order"; const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect"; const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect";
} // namespace mindspore } // namespace mindspore
...@@ -34,7 +34,7 @@ extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[]; ...@@ -34,7 +34,7 @@ extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
extern const char GRAPH_FLAG_HAS_EFFECT[]; extern const char GRAPH_FLAG_HAS_EFFECT[];
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[]; extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
extern const char GRAPH_FLAG_RANDOM_EFFECT[]; extern const char GRAPH_FLAG_RANDOM_EFFECT[];
extern const char GRAPH_FLAG_SIDE_EFFECT[];
} // namespace mindspore } // namespace mindspore
#endif // PYBIND_API_EXPORT_FLAGS_H_ #endif // PYBIND_API_EXPORT_FLAGS_H_
...@@ -33,6 +33,21 @@ static constexpr size_t kCNodeSwitchLayerLength = 3; ...@@ -33,6 +33,21 @@ static constexpr size_t kCNodeSwitchLayerLength = 3;
namespace mindspore { namespace mindspore {
namespace session { namespace session {
static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) {
auto &nodes = parent_graph->execution_order();
for (auto &node : nodes) {
if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) {
return node;
} else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) &&
(child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) {
return node;
}
}
MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
return nullptr;
}
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set, static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
const NotNull<std::set<KernelGraphPtr> *> memo) { const NotNull<std::set<KernelGraphPtr> *> memo) {
if (memo->find(kg.get()) != memo->end()) { if (memo->find(kg.get()) != memo->end()) {
...@@ -200,7 +215,8 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr ...@@ -200,7 +215,8 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
if (target_graph_iter == graph_id_map.end()) { if (target_graph_iter == graph_id_map.end()) {
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found."; MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
} }
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter)); InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg),
NOT_NULL(parameter));
} }
} }
} }
...@@ -263,7 +279,7 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr ...@@ -263,7 +279,7 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
} }
} }
kg->SetExecOrderByDefault();
MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString(); MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString();
return NOT_NULL(start_label); return NOT_NULL(start_label);
} }
...@@ -433,7 +449,8 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A ...@@ -433,7 +449,8 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
return {partial_cnode, branch_kg}; return {partial_cnode, branch_kg};
} }
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph,
NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> from,
NotNull<AnfNodePtr> to) { NotNull<AnfNodePtr> to) {
std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem});
std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem});
...@@ -443,18 +460,24 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg ...@@ -443,18 +460,24 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg
<< to_outputs.size() << "]"; << to_outputs.size() << "]";
} }
for (size_t i = 0; i < from_outputs.size(); i++) { for (size_t i = 0; i < from_outputs.size(); i++) {
InsertAssignToGraph(kg, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
if (assign_node != nullptr) {
auto jump_node = GetJumpNode(from_graph, to_graph);
if (jump_node != nullptr) {
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
}
}
} }
} }
void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
NotNull<AnfNodePtr> to) { NotNull<AnfNodePtr> to) {
if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
return; return nullptr;
} }
if (from.get() == to.get()) { if (from.get() == to.get()) {
return; return nullptr;
} }
MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to "
<< to->DebugString(); << to->DebugString();
...@@ -466,6 +489,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul ...@@ -466,6 +489,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
assign_node->set_abstract(to->abstract()); assign_node->set_abstract(to->abstract());
// append the assign at the end of from graph // append the assign at the end of from graph
InsertDependToGraph(kg, NOT_NULL(assign_node)); InsertDependToGraph(kg, NOT_NULL(assign_node));
return assign_node;
} }
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph, std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
......
...@@ -52,8 +52,9 @@ class AscendControlParser { ...@@ -52,8 +52,9 @@ class AscendControlParser {
const CNodePtr &last_label); const CNodePtr &last_label);
static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node); static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node);
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph,
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
// root graph order // root graph order
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
......
...@@ -521,6 +521,47 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) { ...@@ -521,6 +521,47 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
return output_nodes; return output_nodes;
} }
// Find control_depend real input nodes.
void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, std::set<AnfNodePtr> *visited) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(result);
MS_EXCEPTION_IF_NULL(visited);
if (visited->find(anf_node) != visited->end()) {
MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
return;
}
visited->insert(anf_node);
if (AnfAlgo::IsRealKernel(anf_node)) {
result->emplace_back(anf_node);
return;
}
if (!anf_node->isa<CNode>()) {
return;
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString();
}
auto input0 = cnode->input(0);
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
GetAllFatherRealNode(cnode->input(i), result, visited);
}
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
if (cnode->inputs().size() != kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
}
GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
} else if (IsPrimitive(input0, prim::kPrimDepend)) {
if (cnode->inputs().size() != kDependInputSize) {
MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
}
GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
}
}
// update the depend relations of control depend // update the depend relations of control depend
void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) { void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
for (const auto &node : depends) { for (const auto &node : depends) {
...@@ -551,11 +592,24 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de ...@@ -551,11 +592,24 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
if (depend_node->isa<Parameter>() && depend_mode == 1) { if (depend_node->isa<Parameter>() && depend_mode == 1) {
depend_nodes = GetOutputNodes(depend_node); depend_nodes = GetOutputNodes(depend_node);
} }
for (auto &first_node : prior_nodes) {
std::vector<AnfNodePtr> real_prior_nodes;
std::set<AnfNodePtr> prior_visited;
for (const auto &tmp : prior_nodes) {
GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
}
std::vector<AnfNodePtr> real_depend_nodes;
std::set<AnfNodePtr> depend_visited;
for (const auto &tmp : depend_nodes) {
GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
}
for (auto &first_node : real_prior_nodes) {
if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) {
continue; continue;
} }
for (auto &second_node : depend_nodes) { for (auto &second_node : real_depend_nodes) {
if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) {
continue; continue;
} }
......
...@@ -33,9 +33,14 @@ ...@@ -33,9 +33,14 @@
namespace py = pybind11; namespace py = pybind11;
namespace mindspore::inference { namespace mindspore::inference {
std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device) { std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device) {
inference::Session::RegAllOp(); try {
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); inference::Session::RegAllOp();
return anf_graph; auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
return anf_graph;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference LoadModel failed";
return nullptr;
}
} }
void ExitInference() { void ExitInference() {
...@@ -51,12 +56,17 @@ void ExitInference() { ...@@ -51,12 +56,17 @@ void ExitInference() {
} }
std::shared_ptr<MSSession> MSSession::CreateSession(const std::string &device, uint32_t device_id) { std::shared_ptr<MSSession> MSSession::CreateSession(const std::string &device, uint32_t device_id) {
auto session = std::make_shared<inference::Session>(); try {
auto ret = session->Init(device, device_id); auto session = std::make_shared<inference::Session>();
if (ret != 0) { auto ret = session->Init(device, device_id);
if (ret != 0) {
return nullptr;
}
return session;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference CreatSession failed";
return nullptr; return nullptr;
} }
return session;
} }
void Session::RegAllOp() { void Session::RegAllOp() {
...@@ -113,47 +123,71 @@ void Session::RegAllOp() { ...@@ -113,47 +123,71 @@ void Session::RegAllOp() {
uint32_t Session::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) { uint32_t Session::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) {
MS_ASSERT(session_impl_ != nullptr); MS_ASSERT(session_impl_ != nullptr);
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); try {
py::gil_scoped_release gil_release; auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
return graph_id; py::gil_scoped_release gil_release;
return graph_id;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference CompileGraph failed";
return static_cast<uint32_t>(-1);
}
} }
MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) { MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) {
std::vector<tensor::TensorPtr> inTensors; try {
inTensors.resize(inputs.size()); std::vector<tensor::TensorPtr> inTensors;
bool has_error = false; inTensors.resize(inputs.size());
std::transform(inputs.begin(), inputs.end(), inTensors.begin(), bool has_error = false;
[&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr { std::transform(inputs.begin(), inputs.end(), inTensors.begin(),
if (tensor_ptr == nullptr) { [&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr {
MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr"; if (tensor_ptr == nullptr) {
has_error = true; MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr";
return nullptr; has_error = true;
} return nullptr;
auto tensor = static_cast<inference::Tensor *>(tensor_ptr.get()); }
if (tensor == nullptr) { auto tensor = static_cast<inference::Tensor *>(tensor_ptr.get());
MS_LOG(ERROR) << "Can not cast input MSTensor to tensor"; if (tensor == nullptr) {
has_error = true; MS_LOG(ERROR) << "Can not cast input MSTensor to tensor";
return nullptr; has_error = true;
} return nullptr;
return tensor->tensor(); }
}); return tensor->tensor();
if (has_error) { });
MS_LOG(ERROR) << "Init Tensor failed, returning empty result"; if (has_error) {
std::vector<std::shared_ptr<inference::MSTensor>> multiTensor; MS_LOG(ERROR) << "Init Tensor failed, returning empty result";
return multiTensor; std::vector<std::shared_ptr<inference::MSTensor>> multiTensor;
} return multiTensor;
VectorRef outputs; }
session_impl_->RunGraph(graph_id, inTensors, &outputs); VectorRef outputs;
session_impl_->RunGraph(graph_id, inTensors, &outputs);
return TransformVectorRefToMultiTensor(outputs); return TransformVectorRefToMultiTensor(outputs);
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference Rungraph failed";
return MultiTensor();
}
} }
namespace {
string AjustTargetName(const std::string &device) {
if (device == kAscendDevice) {
return std::string(kAscendDevice) + "Inference";
} else {
MS_LOG(ERROR) << "Only support device Ascend right now";
return "";
}
}
} // namespace
int Session::Init(const std::string &device, uint32_t device_id) { int Session::Init(const std::string &device, uint32_t device_id) {
RegAllOp(); RegAllOp();
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
ms_context->set_execution_mode(kGraphMode); ms_context->set_execution_mode(kGraphMode);
ms_context->set_device_target(kAscendDevice); ms_context->set_device_id(device_id);
session_impl_ = session::SessionFactory::Get().Create(device); auto ajust_device = AjustTargetName(device);
if (ajust_device == "") {
return -1;
}
ms_context->set_device_target(device);
session_impl_ = session::SessionFactory::Get().Create(ajust_device);
if (session_impl_ == nullptr) { if (session_impl_ == nullptr) {
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available."; MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
return -1; return -1;
......
...@@ -81,7 +81,15 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne ...@@ -81,7 +81,15 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
} }
} }
// if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode) // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
auto address = AnfAlgo::GetOutputAddr(node, output_index); DeviceAddressPtr address;
auto is_all_nop_node = opt::IsAllNopNode(&graph);
if (is_all_nop_node) {
// The graph does not remove the nop node.
address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
} else {
// The graph removes the nop node.
address = AnfAlgo::GetMutableOutputAddr(node, output_index, true);
}
MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(address);
auto shape = AnfAlgo::GetOutputInferShape(node, output_index); auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
TypeId type_id = kNumberTypeFloat32; TypeId type_id = kNumberTypeFloat32;
...@@ -93,7 +101,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne ...@@ -93,7 +101,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) { if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index)); tensor->set_device_address(address);
tensor->set_dirty(false); tensor->set_dirty(false);
} else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) { LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) {
......
...@@ -1646,7 +1646,7 @@ bool DfGraphConvertor::GetControlDependList(const CNodePtr &node, ...@@ -1646,7 +1646,7 @@ bool DfGraphConvertor::GetControlDependList(const CNodePtr &node,
dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end());
} }
if (src_ops_list->empty() || dst_ops_list->empty()) { if (src_ops_list->empty() || dst_ops_list->empty()) {
MS_LOG(WARNING) << "Control depend node's src or dest node is not a apply node, ignore it"; MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it";
error_ = SUCCESS; error_ = SUCCESS;
} }
return true; return true;
...@@ -1690,6 +1690,8 @@ void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) { ...@@ -1690,6 +1690,8 @@ void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) {
}); });
} else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) { } else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) {
control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]}); control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]});
} else if (src_ops_list->empty() || dst_ops_list->empty()) {
MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it";
} else { } else {
MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size() MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size()
<< " -> dst:" << dst_ops_list->size(); << " -> dst:" << dst_ops_list->size();
......
...@@ -463,7 +463,7 @@ void InitSubModulesLogLevel() { ...@@ -463,7 +463,7 @@ void InitSubModulesLogLevel() {
// set submodule's log level // set submodule's log level
auto submodule = GetEnv("MS_SUBMODULE_LOG_v"); auto submodule = GetEnv("MS_SUBMODULE_LOG_v");
MS_LOG(INFO) << "MS_SUBMODULE_LOG_v=`" << submodule << "`"; MS_LOG(DEBUG) << "MS_SUBMODULE_LOG_v=`" << submodule << "`";
LogConfigParser parser(submodule); LogConfigParser parser(submodule);
auto configs = parser.Parse(); auto configs = parser.Parse();
for (const auto &cfg : configs) { for (const auto &cfg : configs) {
...@@ -489,22 +489,14 @@ void InitSubModulesLogLevel() { ...@@ -489,22 +489,14 @@ void InitSubModulesLogLevel() {
} // namespace mindspore } // namespace mindspore
extern "C" { extern "C" {
// shared lib init hook
#if defined(_WIN32) || defined(_WIN64) #if defined(_WIN32) || defined(_WIN64)
__attribute__((constructor)) void mindspore_log_init(void) { __attribute__((constructor)) void common_log_init(void) {
#else #else
void mindspore_log_init(void) { void common_log_init(void) {
#endif #endif
#ifdef USE_GLOG #ifdef USE_GLOG
// do not use glog predefined log prefix // do not use glog predefined log prefix
FLAGS_log_prefix = false; FLAGS_log_prefix = false;
static bool is_glog_initialzed = false;
if (!is_glog_initialzed) {
#if !defined(_WIN32) && !defined(_WIN64)
google::InitGoogleLogging("mindspore");
#endif
is_glog_initialzed = true;
}
// set default log level to WARNING // set default log level to WARNING
if (mindspore::GetEnv("GLOG_v").empty()) { if (mindspore::GetEnv("GLOG_v").empty()) {
FLAGS_v = mindspore::WARNING; FLAGS_v = mindspore::WARNING;
...@@ -525,4 +517,22 @@ void mindspore_log_init(void) { ...@@ -525,4 +517,22 @@ void mindspore_log_init(void) {
#endif #endif
mindspore::InitSubModulesLogLevel(); mindspore::InitSubModulesLogLevel();
} }
// shared lib init hook
#if defined(_WIN32) || defined(_WIN64)
__attribute__((constructor)) void mindspore_log_init(void) {
#else
void mindspore_log_init(void) {
#endif
#ifdef USE_GLOG
static bool is_glog_initialzed = false;
if (!is_glog_initialzed) {
#if !defined(_WIN32) && !defined(_WIN64)
google::InitGoogleLogging("mindspore");
#endif
is_glog_initialzed = true;
}
#endif
common_log_init();
}
} }
...@@ -252,6 +252,7 @@ constexpr auto kControlDependMode = "depend_mode"; ...@@ -252,6 +252,7 @@ constexpr auto kControlDependMode = "depend_mode";
// index define of depend // index define of depend
constexpr auto kRealInputIndexInDepend = 1; constexpr auto kRealInputIndexInDepend = 1;
constexpr auto kDependAttachNodeIndex = 2; constexpr auto kDependAttachNodeIndex = 2;
constexpr auto kDependInputSize = 3;
// format // format
constexpr auto kOpFormat_DEFAULT = "DefaultFormat"; constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0"; constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0";
......
...@@ -22,6 +22,10 @@ from . import dtype as mstype ...@@ -22,6 +22,10 @@ from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry from ._register_for_tensor import tensor_operator_registry
__all__ = ['Tensor', 'MetaTensor'] __all__ = ['Tensor', 'MetaTensor']
np_types = (np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
np.float32, np.float64, np.bool_)
class Tensor(Tensor_): class Tensor(Tensor_):
...@@ -54,6 +58,10 @@ class Tensor(Tensor_): ...@@ -54,6 +58,10 @@ class Tensor(Tensor_):
""" """
def __init__(self, input_data, dtype=None): def __init__(self, input_data, dtype=None):
# If input data is numpy number, convert it to np array
if isinstance(input_data, np_types):
input_data = np.array(input_data)
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method. # If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
check_type('tensor input_data', input_data, (Tensor_, float, int)) check_type('tensor input_data', input_data, (Tensor_, float, int))
if dtype is not None: if dtype is not None:
......
...@@ -1040,7 +1040,7 @@ class Dataset: ...@@ -1040,7 +1040,7 @@ class Dataset:
Args: Args:
columns (list[str], optional): List of columns to be used to specify the order of columns columns (list[str], optional): List of columns to be used to specify the order of columns
(defaults=None, means all columns). (default=None, means all columns).
Returns: Returns:
Iterator, list of ndarray. Iterator, list of ndarray.
...@@ -3382,7 +3382,7 @@ class ManifestDataset(MappableDataset): ...@@ -3382,7 +3382,7 @@ class ManifestDataset(MappableDataset):
class_indexing (dict, optional): A str-to-int mapping from label name to index class_indexing (dict, optional): A str-to-int mapping from label name to index
(default=None, the folder names will be sorted alphabetically and each (default=None, the folder names will be sorted alphabetically and each
class will be given a unique index starting from 0). class will be given a unique index starting from 0).
decode (bool, optional): decode the images after reading (defaults=False). decode (bool, optional): decode the images after reading (default=False).
num_shards (int, optional): Number of shards that the dataset should be divided num_shards (int, optional): Number of shards that the dataset should be divided
into (default=None). into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This shard_id (int, optional): The shard ID within num_shards (default=None). This
...@@ -4760,7 +4760,7 @@ class _NumpySlicesDataset: ...@@ -4760,7 +4760,7 @@ class _NumpySlicesDataset:
def process_dict(self, input_data): def process_dict(self, input_data):
""" """
Convert the dict like data into tuple format, when input is a tuple of dict then compose it into a dict first. Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first.
""" """
# Convert pandas like dict(has "values" column) into General dict # Convert pandas like dict(has "values" column) into General dict
data_keys = list(input_data.keys()) data_keys = list(input_data.keys())
......
...@@ -202,7 +202,7 @@ class RandomHorizontalFlip(cde.RandomHorizontalFlipOp): ...@@ -202,7 +202,7 @@ class RandomHorizontalFlip(cde.RandomHorizontalFlipOp):
Flip the input image horizontally, randomly with a given probability. Flip the input image horizontally, randomly with a given probability.
Args: Args:
prob (float): Probability of the image being flipped (default=0.5). prob (float, optional): Probability of the image being flipped (default=0.5).
""" """
@check_prob @check_prob
...@@ -217,7 +217,7 @@ class RandomHorizontalFlipWithBBox(cde.RandomHorizontalFlipWithBBoxOp): ...@@ -217,7 +217,7 @@ class RandomHorizontalFlipWithBBox(cde.RandomHorizontalFlipWithBBoxOp):
Maintains data integrity by also flipping bounding boxes in an object detection pipeline. Maintains data integrity by also flipping bounding boxes in an object detection pipeline.
Args: Args:
prob (float): Probability of the image being flipped (default=0.5). prob (float, optional): Probability of the image being flipped (default=0.5).
""" """
@check_prob @check_prob
...@@ -231,7 +231,7 @@ class RandomVerticalFlip(cde.RandomVerticalFlipOp): ...@@ -231,7 +231,7 @@ class RandomVerticalFlip(cde.RandomVerticalFlipOp):
Flip the input image vertically, randomly with a given probability. Flip the input image vertically, randomly with a given probability.
Args: Args:
prob (float): Probability of the image being flipped (default=0.5). prob (float, optional): Probability of the image being flipped (default=0.5).
""" """
@check_prob @check_prob
......
...@@ -29,8 +29,9 @@ from .optimizer import Optimizer ...@@ -29,8 +29,9 @@ from .optimizer import Optimizer
_adam_opt = C.MultitypeFuncGraph("adam_opt") _adam_opt = C.MultitypeFuncGraph("adam_opt")
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): "Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag, optim_filter):
""" """
Update parameters. Update parameters.
...@@ -44,38 +45,44 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad ...@@ -44,38 +45,44 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
m (Tensor): m value of parameters. m (Tensor): m value of parameters.
v (Tensor): v value of parameters. v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters. gradient (Tensor): Gradient of parameters.
decay_flag (bool): Applies weight decay or not.
optim_filter (bool): Applies parameter update or not.
Returns: Returns:
Tensor, the new value of v after updating. Tensor, the new value of v after updating.
""" """
op_mul = P.Mul() if optim_filter:
op_square = P.Square() op_mul = P.Mul()
op_sqrt = P.Sqrt() op_square = P.Square()
op_cast = P.Cast() op_sqrt = P.Sqrt()
op_reshape = P.Reshape() op_cast = P.Cast()
op_shape = P.Shape() op_reshape = P.Reshape()
op_shape = P.Shape()
param_fp32 = op_cast(param, mstype.float32) param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32) m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32) v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32) gradient_fp32 = op_cast(gradient, mstype.float32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32) next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta1, gradient_fp32)
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta2, op_square(gradient_fp32)) - beta2, op_square(gradient_fp32))
update = next_m / (eps + op_sqrt(next_v))
if decay_flag:
update = op_mul(weight_decay_tensor, param_fp32) + update
update_with_lr = op_mul(lr, update) update = next_m / (eps + op_sqrt(next_v))
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) if decay_flag:
update = op_mul(weight_decay_tensor, param_fp32) + update
next_v = F.depend(next_v, F.assign(param, op_cast(next_param, F.dtype(param)))) update_with_lr = op_mul(lr, update)
next_v = F.depend(next_v, F.assign(m, op_cast(next_m, F.dtype(m)))) next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
next_v = F.depend(next_v, F.assign(v, op_cast(next_v, F.dtype(v))))
return next_v next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
return next_param
return gradient
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
...@@ -300,7 +307,7 @@ class AdamWeightDecay(Optimizer): ...@@ -300,7 +307,7 @@ class AdamWeightDecay(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs: Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`. tuple[bool], all elements are True.
Examples: Examples:
>>> net = Net() >>> net = Net()
...@@ -328,11 +335,13 @@ class AdamWeightDecay(Optimizer): ...@@ -328,11 +335,13 @@ class AdamWeightDecay(Optimizer):
def construct(self, gradients): def construct(self, gradients):
lr = self.get_lr() lr = self.get_lr()
updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor), self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag) self.params, self.moments1, self.moments2, gradients,
self.decay_flag, self.optim_filter)
return updated_velocity if self.use_parallel:
optim_result = self.broadcast_params(optim_result)
return optim_result
class AdamWeightDecayDynamicLR(Optimizer): class AdamWeightDecayDynamicLR(Optimizer):
...@@ -363,7 +372,7 @@ class AdamWeightDecayDynamicLR(Optimizer): ...@@ -363,7 +372,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs: Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`. tuple[bool], all elements are True.
Examples: Examples:
>>> net = Net() >>> net = Net()
...@@ -424,12 +433,14 @@ class AdamWeightDecayDynamicLR(Optimizer): ...@@ -424,12 +433,14 @@ class AdamWeightDecayDynamicLR(Optimizer):
warmup_lr = self.start_learning_rate * warmup_percent warmup_lr = self.start_learning_rate * warmup_percent
is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32) is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor), self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag) self.params, self.moments1, self.moments2, gradients,
self.decay_flag, self.optim_filter)
if self.use_parallel:
optim_result = self.broadcast_params(optim_result)
added_global_step = self.global_step + self.one added_global_step = self.global_step + self.one
F.control_depend(lr, added_global_step) F.control_depend(lr, added_global_step)
self.global_step = added_global_step self.global_step = added_global_step
return updated_velocity return optim_result
...@@ -32,11 +32,10 @@ num_one = Tensor(np.ones([1]), mstype.float32) ...@@ -32,11 +32,10 @@ num_one = Tensor(np.ones([1]), mstype.float32)
_lamb_opt = C.MultitypeFuncGraph("lamb_opt") _lamb_opt = C.MultitypeFuncGraph("lamb_opt")
@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
"Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v, def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v,
gradient, decay_flag): gradient, decay_flag, optim_filter):
""" """
Update parameters. Update parameters.
...@@ -52,66 +51,66 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para ...@@ -52,66 +51,66 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
v (Tensor): v value of parameters. v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters. gradient (Tensor): Gradient of parameters.
decay_flag (bool): Specifies whether param update with weight decay. decay_flag (bool): Specifies whether param update with weight decay.
optim_filter(bool): Applies parameter update or not.
Returns: Returns:
Tensor, the new value of v after updating. Tensor, the new value of v after updating.
""" """
op_mul = P.Mul() if optim_filter:
op_sqrt = P.Sqrt() op_mul = P.Mul()
op_rsqrt = P.Rsqrt() op_sqrt = P.Sqrt()
op_square = P.Square() op_rsqrt = P.Rsqrt()
op_cast = P.Cast() op_square = P.Square()
op_reshape = P.Reshape() op_cast = P.Cast()
op_shape = P.Shape() op_reshape = P.Reshape()
op_pow = P.Pow() op_shape = P.Shape()
op_norm = layer.Norm() op_pow = P.Pow()
op_select = P.Select() op_norm = layer.Norm()
op_greater = P.Greater() op_select = P.Select()
op_fill = P.Fill() op_greater = P.Greater()
op_dtype = P.DType() op_fill = P.Fill()
op_dtype = P.DType()
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32) param_fp32 = op_cast(param, mstype.float32)
v_fp32 = op_cast(v, mstype.float32) m_fp32 = op_cast(m, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32) v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one,
mstype.float32) - beta1, gradient_fp32) next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))
mstype.float32) - beta2, op_square(gradient_fp32))
next_mm = next_m / (op_cast(num_one, mstype.float32)
next_mm = next_m / (op_cast(num_one, mstype.float32) - op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
- op_pow(beta1, op_cast(global_step + num_one, mstype.float32))) next_vv = next_v / (op_cast(num_one, mstype.float32) -
next_vv = next_v / (op_cast(num_one, mstype.float32) - op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
op_pow(beta2, op_cast(global_step + num_one, mstype.float32))) w_norm = op_norm(param_fp32)
w_norm = op_norm(param_fp32) g_norm = op_norm(gradient_fp32)
g_norm = op_norm(gradient_fp32)
g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32)
g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt( zeros = F.zeros_like(w_norm)
next_vv + eps)) + weight_decay_tensor * param_fp32) ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
zeros = F.zeros_like(w_norm) trust_ratio = op_select(
ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0) op_greater(w_norm, zeros),
trust_ratio = op_select( op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
op_greater(w_norm, zeros), ones)
op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones), tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
ones) trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0) update = next_mm / (op_sqrt(next_vv) + eps)
trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
update = next_mm / (op_sqrt(next_vv) + eps) if decay_flag:
update = update + op_mul(weight_decay_tensor, param_fp32)
if decay_flag:
update = update + op_mul(weight_decay_tensor, param_fp32) update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
update_with_lr = op_mul(op_mul(trust_ratio, lr), update) next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) next_param = F.depend(next_param, F.assign(param, next_param))
next_param = F.depend(next_param, F.assign(m, next_m))
next_v = F.depend(next_v, F.assign(param, next_param)) next_param = F.depend(next_param, F.assign(v, next_v))
next_v = F.depend(next_v, F.assign(m, next_m))
next_v = F.depend(next_v, F.assign(v, next_v)) return next_param
return gradient
return next_v
lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel") lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel")
...@@ -238,7 +237,7 @@ class Lamb(Optimizer): ...@@ -238,7 +237,7 @@ class Lamb(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs: Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`. tuple[bool], all elements are True.
Examples: Examples:
>>> net = Net() >>> net = Net()
...@@ -311,18 +310,21 @@ class Lamb(Optimizer): ...@@ -311,18 +310,21 @@ class Lamb(Optimizer):
self.warmup_steps, self.global_step), mstype.float32) self.warmup_steps, self.global_step), mstype.float32)
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
if self.enable_graph_kernel: if self.enable_graph_kernel:
updated_velocity = self.hyper_map(F.partial(lamb_opt_graph_kernel, optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel,
self.beta1, self.beta2, self.eps, lr, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step), self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients, self.decay_flag) self.params, self.moments1, self.moments2, gradients, self.decay_flag)
else: else:
updated_velocity = self.hyper_map(F.partial(_lamb_opt, optim_result = self.hyper_map(F.partial(_lamb_opt,
self.beta1, self.beta2, self.eps, lr, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step), self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients, self.decay_flag) self.params, self.moments1, self.moments2, gradients,
self.decay_flag, self.optim_filter)
if self.use_parallel:
optim_result = self.broadcast_params(optim_result)
added_global_step = self.global_step + self.one added_global_step = self.global_step + self.one
F.control_depend(lr, added_global_step) F.control_depend(lr, added_global_step)
self.global_step = added_global_step self.global_step = added_global_step
return updated_velocity return optim_result
...@@ -22,11 +22,14 @@ from mindspore.ops import functional as F, composite as C, operations as P ...@@ -22,11 +22,14 @@ from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from mindspore.common.tensor import Tensor
from mindspore import log as logger from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.parallel_utils import ParallelMode
__all__ = ['Optimizer'] __all__ = ['Optimizer']
...@@ -155,6 +158,27 @@ class Optimizer(Cell): ...@@ -155,6 +158,27 @@ class Optimizer(Cell):
self.param_length = len(self.parameters) self.param_length = len(self.parameters)
self.map_ = C.Map() self.map_ = C.Map()
use_parallel = auto_parallel_context().get_enable_parallel_optimizer()
self.use_parallel = use_parallel
if use_parallel:
if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]:
raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL,
ParallelMode.AUTO_PARALLEL]:
raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format
(_get_parallel_mode()))
self.dev_num = _get_device_num()
if self.dev_num > self.param_length:
raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
" less than the number of devices {}".format(self.param_length, self.dev_num))
self.param_rank = self._get_parameter_group_id()
self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
self.param_names = []
for param in self.parameters:
self.param_names.append(param.name)
else:
self.optim_filter = (True,) * self.param_length
def decay_weight(self, gradients): def decay_weight(self, gradients):
""" """
Weight decay. Weight decay.
...@@ -219,8 +243,32 @@ class Optimizer(Cell): ...@@ -219,8 +243,32 @@ class Optimizer(Cell):
raise TypeError("Learning rate should be float, Tensor or Iterable.") raise TypeError("Learning rate should be float, Tensor or Iterable.")
return lr return lr
def _check_group_params(self, parameters):
"""Check group params."""
parse_keys = ['params', 'lr', 'weight_decay', 'order_params']
for group_param in parameters:
invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
if invalid_key:
raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.')
if 'order_params' in group_param.keys():
if len(group_param.keys()) > 1:
raise ValueError("The order params dict in group parameters should "
"only include the 'order_params' key.")
if not isinstance(group_param['order_params'], Iterable):
raise TypeError("The value of 'order_params' should be an Iterable type.")
continue
if not group_param['params']:
raise ValueError("Optimizer got an empty group parameter list.")
for param in group_param['params']:
if not isinstance(param, Parameter):
raise TypeError("The group param should be an iterator of Parameter type.")
def _parse_group_params(self, parameters, learning_rate): def _parse_group_params(self, parameters, learning_rate):
"""Parse group params.""" """Parse group params."""
self._check_group_params(parameters)
if self.dynamic_lr: if self.dynamic_lr:
dynamic_lr_length = learning_rate.size() dynamic_lr_length = learning_rate.size()
else: else:
...@@ -250,9 +298,6 @@ class Optimizer(Cell): ...@@ -250,9 +298,6 @@ class Optimizer(Cell):
if dynamic_lr_length not in (lr_length, 0): if dynamic_lr_length not in (lr_length, 0):
raise ValueError("The dynamic learning rate in group should be the same size.") raise ValueError("The dynamic learning rate in group should be the same size.")
if not group_param['params']:
raise ValueError("Optimizer got an empty group parameter list.")
dynamic_lr_length = lr_length dynamic_lr_length = lr_length
self.dynamic_lr_length = dynamic_lr_length self.dynamic_lr_length = dynamic_lr_length
...@@ -384,6 +429,51 @@ class Optimizer(Cell): ...@@ -384,6 +429,51 @@ class Optimizer(Cell):
lr = self.learning_rate lr = self.learning_rate
return lr return lr
def _get_parameter_group_id(self):
"""
Get the parameter partition group id, which is less than the number of devices.
Returns:
tuple, the group id tuple of parameters.
"""
rank_list = ()
count = 0
for _ in range(self.param_length):
rank_list = rank_list + (count,)
count = count + 1
if count == self.dev_num:
count = 0
return rank_list
def broadcast_params(self, optim_result):
    """
    Apply Broadcast operations in the sequential order of parameter groups.

    Groups the updated parameters by the device rank that owns them
    (``self.param_rank``), broadcasts each group from its owning device, and
    writes the received values back through ref keys.

    Args:
        optim_result: Tuple of per-parameter optimizer outputs, indexed the
            same way as ``self.parameters``.

    Returns:
         bool, the status flag.
    """
    # One (params, keys) bucket per device; F.make_tuple() starts each as empty.
    param_group = []
    key_group = []
    for _ in range(self.dev_num):
        param_group.append(F.make_tuple())
        key_group.append(F.make_tuple())
    # Route each parameter (and a ref key to assign it back) to its owner's bucket.
    for i in range(self.param_length):
        param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],)
        key = P.MakeRefKey(self.param_names[i])()
        key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
    new_param_group = []
    for root in range(self.dev_num):
        # Broadcast this device's bucket from `root` to all devices, then
        # assign the received tensors back to the original parameters.
        ops = P.Broadcast(root)
        next_params = ops(param_group[root])
        new_param_group.append(next_params)
        for i in range(F.tuple_len(next_params)):
            F.assign(key_group[root][i], next_params[i])
    # Chain the broadcasts with control dependencies so they execute in
    # sequential group order (presumably required for collective ordering —
    # NOTE(review): confirm against the graph executor's scheduling rules).
    status = True
    for i in range(self.dev_num - 1):
        status = F.control_depend(new_param_group[i][0], new_param_group[i+1])
    return status
def construct(self, *hyper_params): def construct(self, *hyper_params):
raise NotImplementedError raise NotImplementedError
......
...@@ -220,7 +220,9 @@ class DataWrapper(Cell): ...@@ -220,7 +220,9 @@ class DataWrapper(Cell):
def __init__(self, network, dataset_types, dataset_shapes, queue_name): def __init__(self, network, dataset_types, dataset_shapes, queue_name):
super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
# Also copy the flag in `network` construct
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
self.add_flags(**flags)
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
self.network = network self.network = network
......
...@@ -47,6 +47,7 @@ from .gather_v2 import _gather_v2_akg ...@@ -47,6 +47,7 @@ from .gather_v2 import _gather_v2_akg
from .less import _less_akg from .less import _less_akg
from .log import _log_akg from .log import _log_akg
from .matmul import _matmul_akg from .matmul import _matmul_akg
from .batchmatmul import _batchmatmul_akg
from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_akg from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_akg
from .max_pool_with_argmax import _max_pool_with_argmax_akg from .max_pool_with_argmax import _max_pool_with_argmax_akg
from .max import _max_akg from .max import _max_akg
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BatchMatMul op"""
from mindspore.ops.op_info_register import op_info_register
# Registers the AKG (AutoDiff) implementation of BatchMatMul: float16
# inputs/outputs in FRACTAL_NZ format, with optional transpose_a/transpose_b
# attributes. The registration is carried entirely by the JSON payload; the
# decorated function body is intentionally empty.
@op_info_register("""{
    "op_name": "BatchMatMul",
    "imply_type": "AutoDiff",
    "fusion_type": "OPAQUE",
    "attr": [
        {
            "name": "transpose_a",
            "param_type": "optional",
            "type": "bool"
        },
        {
            "name": "transpose_b",
            "param_type": "optional",
            "type": "bool"
        }
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float16"
            ],
            "format": [
                "FRACTAL_NZ"
            ],
            "name": "x1"
        },
        {
            "index": 1,
            "dtype": [
                "float16"
            ],
            "format": [
                "FRACTAL_NZ"
            ],
            "name": "x2"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float16"
            ],
            "format": [
                "FRACTAL_NZ"
            ],
            "name": "output"
        }
    ]
}""")
def _batchmatmul_akg():
    """BatchMatMul AKG register. Registration side effect only; returns None."""
    return
...@@ -28,26 +28,8 @@ confusion_transpose_d_op_info = TBERegOp("ConfusionTransposeD") \ ...@@ -28,26 +28,8 @@ confusion_transpose_d_op_info = TBERegOp("ConfusionTransposeD") \
.attr("transpose_first", "required", "bool", "all") \ .attr("transpose_first", "required", "bool", "all") \
.input(0, "x", False, "required", "all") \ .input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \ .output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_FracNZ, DataType.I8_FracNZ) \ .op_pattern("dynamicFormat") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \ .dtype_format(DataType.None_None, DataType.None_None) \
.dtype_format(DataType.U8_FracNZ, DataType.U8_FracNZ) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I16_FracNZ, DataType.I16_FracNZ) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_FracNZ, DataType.U16_FracNZ) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_FracNZ, DataType.U32_FracNZ) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.I64_FracNZ, DataType.I64_FracNZ) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_FracNZ, DataType.U64_FracNZ) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info() .get_op_info()
......
...@@ -85,6 +85,22 @@ def _list_setitem_with_List(data, number_index, value): ...@@ -85,6 +85,22 @@ def _list_setitem_with_List(data, number_index, value):
return F.list_setitem(data, number_index, value) return F.list_setitem(data, number_index, value)
@setitem.register("List", "Number", "Tuple")
def _list_setitem_with_Tuple(data, number_index, value):
    """
    Assigns a tuple value to a list element.

    Inputs:
        data (list): Data of type list.
        number_index (Number): Index of data.
        value (tuple): Value given.

    Outputs:
        list, type is same as the element type of data.
    """
    return F.list_setitem(data, number_index, value)
@setitem.register("Dictionary", "String", "Tensor") @setitem.register("Dictionary", "String", "Tensor")
def _dict_setitem_with_tensor(data, key, value): def _dict_setitem_with_tensor(data, key, value):
""" """
......
...@@ -98,6 +98,7 @@ class AllReduce(PrimitiveWithInfer): ...@@ -98,6 +98,7 @@ class AllReduce(PrimitiveWithInfer):
self.op = op self.op = op
self.add_prim_attr('group', _get_group(group)) self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0) self.add_prim_attr('fusion', 0)
self.add_prim_attr('index', 0)
def vm_impl(self, x): def vm_impl(self, x):
"""Implement by vm mode.""" """Implement by vm mode."""
......
...@@ -309,12 +309,6 @@ class Print(PrimitiveWithInfer): ...@@ -309,12 +309,6 @@ class Print(PrimitiveWithInfer):
Output tensor or string to stdout. Output tensor or string to stdout.
Note: Note:
The print operation cannot support the following cases currently.
1. The type of tensor is float64 or bool.
2. The data of tensor is a scalar type.
In pynative mode, please use python print function. In pynative mode, please use python print function.
Inputs: Inputs:
...@@ -334,7 +328,7 @@ class Print(PrimitiveWithInfer): ...@@ -334,7 +328,7 @@ class Print(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
pass self.add_prim_attr("_side_effect", True)
def __call__(self, *args): def __call__(self, *args):
for arg in args: for arg in args:
......
...@@ -888,7 +888,8 @@ class Neg(PrimitiveWithInfer): ...@@ -888,7 +888,8 @@ class Neg(PrimitiveWithInfer):
def infer_value(self, input_x): def infer_value(self, input_x):
if input_x is not None: if input_x is not None:
input_x = input_x.asnumpy() input_x = input_x.asnumpy()
return Tensor(-input_x) out = np.array(-input_x, input_x.dtype)
return Tensor(out)
return None return None
...@@ -1667,7 +1668,8 @@ class Div(_MathBinaryOp): ...@@ -1667,7 +1668,8 @@ class Div(_MathBinaryOp):
if x is not None and y is not None: if x is not None and y is not None:
x = x.asnumpy() x = x.asnumpy()
y = y.asnumpy() y = y.asnumpy()
return Tensor(x / y) out = np.array(x / y, x.dtype)
return Tensor(out)
return None return None
......
...@@ -59,8 +59,7 @@ class Assign(PrimitiveWithInfer): ...@@ -59,8 +59,7 @@ class Assign(PrimitiveWithInfer):
return variable return variable
def infer_dtype(self, variable, value): def infer_dtype(self, variable, value):
args = {"variable": variable, "value": value} # Add a type validation later when we don't have to assign a value to RefKey.
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return variable return variable
......
...@@ -400,6 +400,23 @@ class _AutoParallelContext: ...@@ -400,6 +400,23 @@ class _AutoParallelContext:
self.check_context_handle() self.check_context_handle()
return self._context_handle.get_global_rank_is_set() return self._context_handle.get_global_rank_is_set()
def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
    """
    Set enable/disable parallel optimizer.

    Args:
        enable_parallel_optimizer (bool): Enable/disable parallel optimizer.

    Raises:
        TypeError: If `enable_parallel_optimizer` is not a bool.
    """
    self.check_context_handle()
    if not isinstance(enable_parallel_optimizer, bool):
        raise TypeError('enable_parallel_optimizer is invalid type')
    self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
def get_enable_parallel_optimizer(self):
    """Get parallel optimizer flag.

    Returns:
        bool, whether parallel optimizer is enabled in the backend context.
    """
    self.check_context_handle()
    return self._context_handle.get_enable_parallel_optimizer()
def reset(self): def reset(self):
"""Reset all settings.""" """Reset all settings."""
self.check_context_handle() self.check_context_handle()
...@@ -433,7 +450,8 @@ _set_auto_parallel_context_func_map = { ...@@ -433,7 +450,8 @@ _set_auto_parallel_context_func_map = {
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast, "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file, "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
"full_batch": auto_parallel_context().set_full_batch} "full_batch": auto_parallel_context().set_full_batch,
"enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer}
_get_auto_parallel_context_func_map = { _get_auto_parallel_context_func_map = {
...@@ -447,13 +465,15 @@ _get_auto_parallel_context_func_map = { ...@@ -447,13 +465,15 @@ _get_auto_parallel_context_func_map = {
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast, "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file, "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
"full_batch": auto_parallel_context().get_full_batch} "full_batch": auto_parallel_context().get_full_batch,
"enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
parameter_broadcast=bool, strategy_ckpt_load_file=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool) strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
def _set_auto_parallel_context(**kwargs): def _set_auto_parallel_context(**kwargs):
""" """
Set auto parallel context. Set auto parallel context.
...@@ -493,6 +513,7 @@ def _set_auto_parallel_context(**kwargs): ...@@ -493,6 +513,7 @@ def _set_auto_parallel_context(**kwargs):
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
full_batch (bool): Whether to load the whole batch on each device. Default: False. full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer (bool): Enable using optimizer segmentation or noe. Default: False.
Raises: Raises:
ValueError: If input key is not attribute in auto parallel context. ValueError: If input key is not attribute in auto parallel context.
...@@ -535,5 +556,6 @@ def _reset_auto_parallel_context(): ...@@ -535,5 +556,6 @@ def _reset_auto_parallel_context():
- parameter_broadcast: False. - parameter_broadcast: False.
- strategy_ckpt_load_file: "" - strategy_ckpt_load_file: ""
- strategy_ckpt_save_file: "" - strategy_ckpt_save_file: ""
- enable_parallel_optimizer: False
""" """
auto_parallel_context().reset() auto_parallel_context().reset()
...@@ -166,8 +166,11 @@ class SummaryCollector(Callback): ...@@ -166,8 +166,11 @@ class SummaryCollector(Callback):
self._has_saved_custom_data = False self._has_saved_custom_data = False
self._is_parse_loss_success = True self._is_parse_loss_success = True
self._first_step = True self._first_step = True
self._dataset_sink_mode = True
def __enter__(self): def __enter__(self):
self._first_step = True
self._dataset_sink_mode = True
self._record = SummaryRecord(log_dir=self._summary_dir) self._record = SummaryRecord(log_dir=self._summary_dir)
return self return self
...@@ -279,15 +282,15 @@ class SummaryCollector(Callback): ...@@ -279,15 +282,15 @@ class SummaryCollector(Callback):
def step_end(self, run_context): def step_end(self, run_context):
cb_params = run_context.original_args() cb_params = run_context.original_args()
if self._first_step:
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num)
if cb_params.mode == ModeEnum.TRAIN.value: if cb_params.mode == ModeEnum.TRAIN.value:
# Make sure the first step data is recorded if not self._is_collect_this_step(cb_params):
if not self._first_step and cb_params.cur_step_num % self._collect_freq:
return return
self._first_step = False
if not self._has_saved_train_network: if not self._has_saved_train_network:
self._collect_graphs(cb_params) self._collect_graphs(cb_params)
...@@ -295,6 +298,7 @@ class SummaryCollector(Callback): ...@@ -295,6 +298,7 @@ class SummaryCollector(Callback):
self._collect_metric(cb_params) self._collect_metric(cb_params)
self._collect_histogram(cb_params) self._collect_histogram(cb_params)
self._first_step = False
self._record.record(cb_params.cur_step_num) self._record.record(cb_params.cur_step_num)
def end(self, run_context): def end(self, run_context):
...@@ -320,6 +324,18 @@ class SummaryCollector(Callback): ...@@ -320,6 +324,18 @@ class SummaryCollector(Callback):
raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list," raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list,"
f"but expected only one {self.__class__.__name__} instance.") f"but expected only one {self.__class__.__name__} instance.")
def _is_collect_this_step(self, cb_params):
"""Decide whether to collect data for the current step."""
# Make sure the first step data is recorded
if not self._first_step:
if self._dataset_sink_mode:
if cb_params.cur_epoch_num % self._collect_freq:
return False
else:
if cb_params.cur_step_num % self._collect_freq:
return False
return True
@staticmethod @staticmethod
def _package_custom_lineage_data(custom_lineage_data): def _package_custom_lineage_data(custom_lineage_data):
""" """
......
...@@ -318,10 +318,6 @@ def preprocess_fn(image, box, is_training): ...@@ -318,10 +318,6 @@ def preprocess_fn(image, box, is_training):
else: else:
input_data = resize_column(*input_data) input_data = resize_column(*input_data)
photo = (np.random.rand() < config.photo_ratio)
if photo:
input_data = photo_crop_column(*input_data)
input_data = image_bgr_rgb(*input_data) input_data = image_bgr_rgb(*input_data)
output_data = input_data output_data = input_data
...@@ -432,19 +428,19 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fast ...@@ -432,19 +428,19 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fast
writer.write_raw_data([row]) writer.write_raw_data([row])
writer.commit() writer.commit()
def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0, def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0,
is_training=True, num_parallel_workers=8): is_training=True, num_parallel_workers=4):
"""Creatr FasterRcnn dataset with MindDataset.""" """Creatr FasterRcnn dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id, ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,
num_parallel_workers=num_parallel_workers, shuffle=is_training) num_parallel_workers=1, shuffle=is_training)
decode = C.Decode() decode = C.Decode()
ds = ds.map(input_columns=["image"], operations=decode) ds = ds.map(input_columns=["image"], operations=decode, num_parallel_workers=1)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
hwc_to_chw = C.HWC2CHW() hwc_to_chw = C.HWC2CHW()
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375)) normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
horizontally_op = C.RandomHorizontalFlip(1) horizontally_op = C.RandomHorizontalFlip(1)
type_cast0 = CC.TypeCast(mstype.float32)
type_cast1 = CC.TypeCast(mstype.float16) type_cast1 = CC.TypeCast(mstype.float16)
type_cast2 = CC.TypeCast(mstype.int32) type_cast2 = CC.TypeCast(mstype.int32)
type_cast3 = CC.TypeCast(mstype.bool_) type_cast3 = CC.TypeCast(mstype.bool_)
...@@ -453,17 +449,18 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi ...@@ -453,17 +449,18 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
ds = ds.map(input_columns=["image", "annotation"], ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"], output_columns=["image", "image_shape", "box", "label", "valid_num"],
columns_order=["image", "image_shape", "box", "label", "valid_num"], columns_order=["image", "image_shape", "box", "label", "valid_num"],
operations=compose_map_func, num_parallel_workers=4) operations=compose_map_func, num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0],
num_parallel_workers=num_parallel_workers)
flip = (np.random.rand() < config.flip_ratio) flip = (np.random.rand() < config.flip_ratio)
if flip: if flip:
ds = ds.map(input_columns=["image"], operations=[horizontally_op], ds = ds.map(input_columns=["image"], operations=[normalize_op, horizontally_op, hwc_to_chw, type_cast1],
num_parallel_workers=num_parallel_workers) num_parallel_workers=24)
ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"], ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"],
operations=flipped_generation, num_parallel_workers=4) operations=flipped_generation, num_parallel_workers=num_parallel_workers)
else:
ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1],
num_parallel_workers=24)
else: else:
ds = ds.map(input_columns=["image", "annotation"], ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"], output_columns=["image", "image_shape", "box", "label", "valid_num"],
...@@ -471,11 +468,10 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi ...@@ -471,11 +468,10 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
operations=compose_map_func, operations=compose_map_func,
num_parallel_workers=num_parallel_workers) num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0], ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1],
num_parallel_workers=num_parallel_workers) num_parallel_workers=24)
# transpose_column from python to c # transpose_column from python to c
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1])
ds = ds.map(input_columns=["image_shape"], operations=[type_cast1]) ds = ds.map(input_columns=["image_shape"], operations=[type_cast1])
ds = ds.map(input_columns=["box"], operations=[type_cast1]) ds = ds.map(input_columns=["box"], operations=[type_cast1])
ds = ds.map(input_columns=["label"], operations=[type_cast2]) ds = ds.map(input_columns=["label"], operations=[type_cast2])
......
...@@ -19,7 +19,9 @@ from easydict import EasyDict as edict ...@@ -19,7 +19,9 @@ from easydict import EasyDict as edict
cifar_cfg = edict({ cifar_cfg = edict({
'num_classes': 10, 'num_classes': 10,
'lr_init': 0.05, 'lr_init': 0.01,
'lr_max': 0.1,
'warmup_epochs': 5,
'batch_size': 64, 'batch_size': 64,
'epoch_size': 70, 'epoch_size': 70,
'momentum': 0.9, 'momentum': 0.9,
......
...@@ -38,20 +38,25 @@ random.seed(1) ...@@ -38,20 +38,25 @@ random.seed(1)
np.random.seed(1) np.random.seed(1)
def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""Set learning rate.""" """Set learning rate."""
lr_each_step = [] lr_each_step = []
total_steps = steps_per_epoch * total_epochs total_steps = steps_per_epoch * total_epochs
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps] warmup_steps = steps_per_epoch * warmup_epochs
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps): for i in range(total_steps):
if i < decay_epoch_index[0]: if i < warmup_steps:
lr_each_step.append(lr_max) lr_value = float(lr_init) + inc_each_step * float(i)
elif i < decay_epoch_index[1]:
lr_each_step.append(lr_max * 0.1)
elif i < decay_epoch_index[2]:
lr_each_step.append(lr_max * 0.01)
else: else:
lr_each_step.append(lr_max * 0.001) base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr_value = float(lr_max) * base * base
if lr_value < 0.0:
lr_value = 0.0
lr_each_step.append(lr_value)
current_step = global_step current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32) lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:] learning_rate = lr_each_step[current_step:]
...@@ -86,7 +91,8 @@ if __name__ == '__main__': ...@@ -86,7 +91,8 @@ if __name__ == '__main__':
if args_opt.pre_trained: if args_opt.pre_trained:
load_param_into_net(net, load_checkpoint(args_opt.pre_trained)) load_param_into_net(net, load_checkpoint(args_opt.pre_trained))
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) lr = lr_steps(0, lr_init=cfg.lr_init, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs,
total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum,
weight_decay=cfg.weight_decay) weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <memory> #include <memory>
#include <future>
#include "mindspore/ccsrc/utils/log_adapter.h" #include "mindspore/ccsrc/utils/log_adapter.h"
#include "serving/ms_service.grpc.pb.h" #include "serving/ms_service.grpc.pb.h"
...@@ -40,7 +41,7 @@ namespace serving { ...@@ -40,7 +41,7 @@ namespace serving {
using MSTensorPtr = std::shared_ptr<inference::MSTensor>; using MSTensorPtr = std::shared_ptr<inference::MSTensor>;
Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) { Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) {
session_ = inference::MSSession::CreateSession(device + "Inference", device_id); session_ = inference::MSSession::CreateSession(device, device_id);
if (session_ == nullptr) { if (session_ == nullptr) {
MS_LOG(ERROR) << "Creat Session Failed"; MS_LOG(ERROR) << "Creat Session Failed";
return FAILED; return FAILED;
...@@ -67,6 +68,7 @@ Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::Multi ...@@ -67,6 +68,7 @@ Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::Multi
MS_LOG(INFO) << "run Predict"; MS_LOG(INFO) << "run Predict";
*outputs = session_->RunGraph(graph_id_, inputs); *outputs = session_->RunGraph(graph_id_, inputs);
MS_LOG(INFO) << "run Predict finished";
return SUCCESS; return SUCCESS;
} }
...@@ -80,12 +82,16 @@ Status Session::Warmup(const MindSporeModelPtr model) { ...@@ -80,12 +82,16 @@ Status Session::Warmup(const MindSporeModelPtr model) {
std::string file_name = model->GetModelPath() + '/' + model->GetModelName(); std::string file_name = model->GetModelPath() + '/' + model->GetModelName();
char *graphBuf = ReadFile(file_name.c_str(), &size); char *graphBuf = ReadFile(file_name.c_str(), &size);
if (graphBuf == nullptr) { if (graphBuf == nullptr) {
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
return FAILED; return FAILED;
} }
last_graph_ = inference::LoadModel(graphBuf, size, device_type_); last_graph_ = inference::LoadModel(graphBuf, size, device_type_);
if (last_graph_ == nullptr) {
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
return FAILED;
}
graph_id_ = session_->CompileGraph(last_graph_); graph_id_ = session_->CompileGraph(last_graph_);
MS_LOG(INFO) << "Session Warmup"; MS_LOG(INFO) << "Session Warmup finished";
return SUCCESS; return SUCCESS;
} }
...@@ -95,6 +101,9 @@ Status Session::Clear() { ...@@ -95,6 +101,9 @@ Status Session::Clear() {
} }
namespace { namespace {
static const uint32_t uint32max = 0x7FFFFFFF;
std::promise<void> exit_requested;
const std::map<ms_serving::DataType, TypeId> type2id_map{ const std::map<ms_serving::DataType, TypeId> type2id_map{
{ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool}, {ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool},
{ms_serving::MS_INT8, TypeId::kNumberTypeInt8}, {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8}, {ms_serving::MS_INT8, TypeId::kNumberTypeInt8}, {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8},
...@@ -141,7 +150,7 @@ MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) { ...@@ -141,7 +150,7 @@ MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) {
} }
TypeId type = iter->second; TypeId type = iter->second;
auto ms_tensor = std::shared_ptr<inference::MSTensor>(inference::MSTensor::CreateTensor(type, shape)); auto ms_tensor = std::shared_ptr<inference::MSTensor>(inference::MSTensor::CreateTensor(type, shape));
memcpy_s(ms_tensor->MutableData(), tensor.data().size(), tensor.data().data(), tensor.data().size()); memcpy_s(ms_tensor->MutableData(), ms_tensor->Size(), tensor.data().data(), tensor.data().size());
return ms_tensor; return ms_tensor;
} }
...@@ -166,10 +175,7 @@ void ClearEnv() { ...@@ -166,10 +175,7 @@ void ClearEnv() {
Session::Instance().Clear(); Session::Instance().Clear();
inference::ExitInference(); inference::ExitInference();
} }
void HandleSignal(int sig) { void HandleSignal(int sig) { exit_requested.set_value(); }
ClearEnv();
exit(0);
}
#ifdef ENABLE_D #ifdef ENABLE_D
static rtContext_t g_ctx = nullptr; static rtContext_t g_ctx = nullptr;
...@@ -247,6 +253,7 @@ Status Server::BuildAndStart() { ...@@ -247,6 +253,7 @@ Status Server::BuildAndStart() {
rtError_t rt_ret = rtCtxGetCurrent(&ctx); rtError_t rt_ret = rtCtxGetCurrent(&ctx);
if (rt_ret != RT_ERROR_NONE || ctx == nullptr) { if (rt_ret != RT_ERROR_NONE || ctx == nullptr) {
MS_LOG(ERROR) << "the ascend device context is null"; MS_LOG(ERROR) << "the ascend device context is null";
ClearEnv();
return FAILED; return FAILED;
} }
g_ctx = ctx; g_ctx = ctx;
...@@ -258,6 +265,7 @@ Status Server::BuildAndStart() { ...@@ -258,6 +265,7 @@ Status Server::BuildAndStart() {
auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0); auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
grpc::ServerBuilder builder; grpc::ServerBuilder builder;
builder.SetOption(std::move(option)); builder.SetOption(std::move(option));
builder.SetMaxMessageSize(uint32max);
// Listen on the given address without any authentication mechanism. // Listen on the given address without any authentication mechanism.
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
// Register "service" as the instance through which we'll communicate with // Register "service" as the instance through which we'll communicate with
...@@ -265,13 +273,20 @@ Status Server::BuildAndStart() { ...@@ -265,13 +273,20 @@ Status Server::BuildAndStart() {
builder.RegisterService(&service); builder.RegisterService(&service);
// Finally assemble the server. // Finally assemble the server.
std::unique_ptr<grpc::Server> server(builder.BuildAndStart()); std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
if (server == nullptr) {
MS_LOG(ERROR) << "The serving server create failed";
ClearEnv();
return FAILED;
}
auto grpc_server_run = [&server]() { server->Wait(); };
std::thread serving_thread(grpc_server_run);
MS_LOG(INFO) << "Server listening on " << server_address << std::endl; MS_LOG(INFO) << "Server listening on " << server_address << std::endl;
auto exit_future = exit_requested.get_future();
// Wait for the server to shutdown. Note that some other thread must be exit_future.wait();
// responsible for shutting down the server for this call to ever return. ClearEnv();
server->Wait(); server->Shutdown();
serving_thread.join();
return SUCCESS; return SUCCESS;
} }
} // namespace serving } // namespace serving
} // namespace mindspore } // namespace mindspore
...@@ -29,7 +29,6 @@ ...@@ -29,7 +29,6 @@
namespace mindspore { namespace mindspore {
namespace serving { namespace serving {
char *ReadFile(const char *file, size_t *size) { char *ReadFile(const char *file, size_t *size) {
if (file == nullptr) { if (file == nullptr) {
MS_LOG(ERROR) << "file is nullptr"; MS_LOG(ERROR) << "file is nullptr";
...@@ -70,8 +69,8 @@ bool DirOrFileExist(const std::string &file_path) { ...@@ -70,8 +69,8 @@ bool DirOrFileExist(const std::string &file_path) {
} }
std::vector<std::string> GetAllSubDirs(const std::string &dir_path) { std::vector<std::string> GetAllSubDirs(const std::string &dir_path) {
DIR *dir; DIR *dir = nullptr;
struct dirent *ptr; struct dirent *ptr = nullptr;
std::vector<std::string> SubDirs; std::vector<std::string> SubDirs;
if ((dir = opendir(dir_path.c_str())) == NULL) { if ((dir = opendir(dir_path.c_str())) == NULL) {
......
...@@ -36,17 +36,16 @@ bool RemovePrefix(std::string *str, const std::string &prefix) { ...@@ -36,17 +36,16 @@ bool RemovePrefix(std::string *str, const std::string &prefix) {
bool Option::ParseInt32(std::string *arg) { bool Option::ParseInt32(std::string *arg) {
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) { if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
char extra;
int32_t parsed_value; int32_t parsed_value;
if (sscanf(arg->data(), "%d%c", &parsed_value, &extra) != 1) { try {
std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl; parsed_value = std::stoi(arg->data());
} catch (std::invalid_argument) {
std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl;
return false; return false;
} else {
*int32_default_ = parsed_value;
} }
*int32_default_ = parsed_value;
return true; return true;
} }
return false; return false;
} }
...@@ -76,17 +75,16 @@ bool Option::ParseString(std::string *arg) { ...@@ -76,17 +75,16 @@ bool Option::ParseString(std::string *arg) {
bool Option::ParseFloat(std::string *arg) { bool Option::ParseFloat(std::string *arg) {
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) { if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
char extra;
float parsed_value; float parsed_value;
if (sscanf(arg->data(), "%f%c", &parsed_value, &extra) != 1) { try {
std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl; parsed_value = std::stof(arg->data());
} catch (std::invalid_argument) {
std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl;
return false; return false;
} else {
*float_default_ = parsed_value;
} }
*float_default_ = parsed_value;
return true; return true;
} }
return false; return false;
} }
...@@ -159,10 +157,11 @@ Options::Options() : args_(nullptr) { CreateOptions(); } ...@@ -159,10 +157,11 @@ Options::Options() : args_(nullptr) { CreateOptions(); }
void Options::CreateOptions() { void Options::CreateOptions() {
args_ = std::make_shared<Arguments>(); args_ = std::make_shared<Arguments>();
std::vector<Option> options = { std::vector<Option> options = {
Option("port", &args_->grpc_port, "Port to listen on for gRPC API, default is 5500"), Option("port", &args_->grpc_port,
Option("model_name", &args_->model_name, "model name "), "[Optional] Port to listen on for gRPC API, default is 5500, range from 1 to 65535"),
Option("model_path", &args_->model_path, "the path of the model files"), Option("model_name", &args_->model_name, "[Required] model name "),
Option("device_id", &args_->device_id, "the device id, default is 0"), Option("model_path", &args_->model_path, "[Required] the path of the model files"),
Option("device_id", &args_->device_id, "[Optional] the device id, default is 0, range from 0 to 7"),
}; };
options_ = options; options_ = options;
} }
...@@ -176,6 +175,14 @@ bool Options::CheckOptions() { ...@@ -176,6 +175,14 @@ bool Options::CheckOptions() {
std::cout << "device_type only support Ascend right now" << std::endl; std::cout << "device_type only support Ascend right now" << std::endl;
return false; return false;
} }
if (args_->device_id > 7) {
std::cout << "the device_id should be in [0~7]" << std::endl;
return false;
}
if (args_->grpc_port < 1 || args_->grpc_port > 65535) {
std::cout << "the port should be in [1~65535]" << std::endl;
return false;
}
return true; return true;
} }
...@@ -238,6 +245,5 @@ void Options::Usage() { ...@@ -238,6 +245,5 @@ void Options::Usage() {
<< option.usage_ << std::endl; << option.usage_ << std::endl;
} }
} }
} // namespace serving } // namespace serving
} // namespace mindspore } // namespace mindspore
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
namespace mindspore { namespace mindspore {
namespace serving { namespace serving {
struct Arguments { struct Arguments {
int32_t grpc_port = 5500; int32_t grpc_port = 5500;
std::string grpc_socket_path; std::string grpc_socket_path;
...@@ -40,6 +39,7 @@ class Option { ...@@ -40,6 +39,7 @@ class Option {
Option(const std::string &name, bool *default_point, const std::string &usage); Option(const std::string &name, bool *default_point, const std::string &usage);
Option(const std::string &name, std::string *default_point, const std::string &usage); Option(const std::string &name, std::string *default_point, const std::string &usage);
Option(const std::string &name, float *default_point, const std::string &usage); Option(const std::string &name, float *default_point, const std::string &usage);
~Option() = default;
private: private:
friend class Options; friend class Options;
...@@ -77,7 +77,6 @@ class Options { ...@@ -77,7 +77,6 @@ class Options {
std::vector<Option> options_; std::vector<Option> options_;
std::shared_ptr<Arguments> args_; std::shared_ptr<Arguments> args_;
}; };
} // namespace serving } // namespace serving
} // namespace mindspore } // namespace mindspore
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
namespace mindspore { namespace mindspore {
namespace serving { namespace serving {
MindSporeModel::MindSporeModel(const std::string &model_name, const std::string &model_path, MindSporeModel::MindSporeModel(const std::string &model_name, const std::string &model_path,
const std::string &model_version, const time_t &last_update_time) const std::string &model_version, const time_t &last_update_time)
: model_name_(model_name), : model_name_(model_name),
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
namespace mindspore { namespace mindspore {
namespace serving { namespace serving {
volatile bool stop_poll = false; volatile bool stop_poll = false;
std::string GetVersionFromPath(const std::string &path) { std::string GetVersionFromPath(const std::string &path) {
...@@ -102,10 +101,10 @@ Status VersionController::CreateInitModels() { ...@@ -102,10 +101,10 @@ Status VersionController::CreateInitModels() {
} }
std::vector<std::string> SubDirs = GetAllSubDirs(models_path_); std::vector<std::string> SubDirs = GetAllSubDirs(models_path_);
if (version_control_strategy_ == kLastest) { if (version_control_strategy_ == kLastest) {
auto path = SubDirs.empty() ? models_path_ : SubDirs.back(); std::string model_version = GetVersionFromPath(models_path_);
std::string model_version = GetVersionFromPath(path); time_t last_update_time = GetModifyTime(models_path_);
time_t last_update_time = GetModifyTime(path); MindSporeModelPtr model_ptr =
MindSporeModelPtr model_ptr = std::make_shared<MindSporeModel>(model_name_, path, model_version, last_update_time); std::make_shared<MindSporeModel>(model_name_, models_path_, model_version, last_update_time);
valid_models_.emplace_back(model_ptr); valid_models_.emplace_back(model_ptr);
} else { } else {
for (auto &dir : SubDirs) { for (auto &dir : SubDirs) {
...@@ -119,8 +118,8 @@ Status VersionController::CreateInitModels() { ...@@ -119,8 +118,8 @@ Status VersionController::CreateInitModels() {
MS_LOG(ERROR) << "There is no valid model for serving"; MS_LOG(ERROR) << "There is no valid model for serving";
return FAILED; return FAILED;
} }
Session::Instance().Warmup(valid_models_.back()); auto ret = Session::Instance().Warmup(valid_models_.back());
return SUCCESS; return ret;
} }
void VersionController::StartPollModelPeriodic() { void VersionController::StartPollModelPeriodic() {
...@@ -129,6 +128,5 @@ void VersionController::StartPollModelPeriodic() { ...@@ -129,6 +128,5 @@ void VersionController::StartPollModelPeriodic() {
} }
void VersionController::StopPollModelPeriodic() {} void VersionController::StopPollModelPeriodic() {}
} // namespace serving } // namespace serving
} // namespace mindspore } // namespace mindspore
...@@ -64,7 +64,6 @@ class PeriodicFunction { ...@@ -64,7 +64,6 @@ class PeriodicFunction {
VersionController::VersionControllerStrategy version_control_strategy_; VersionController::VersionControllerStrategy version_control_strategy_;
std::vector<MindSporeModelPtr> valid_models_; std::vector<MindSporeModelPtr> valid_models_;
}; };
} // namespace serving } // namespace serving
} // namespace mindspore } // namespace mindspore
......
...@@ -214,6 +214,7 @@ PredictRequest ReadBertInput() { ...@@ -214,6 +214,7 @@ PredictRequest ReadBertInput() {
class MSClient { class MSClient {
public: public:
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {} explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
~MSClient() = default;
std::string Predict(const std::string &type) { std::string Predict(const std::string &type) {
// Data we are sending to the server. // Data we are sending to the server.
...@@ -310,7 +311,6 @@ int main(int argc, char **argv) { ...@@ -310,7 +311,6 @@ int main(int argc, char **argv) {
type = "add"; type = "add";
} }
} }
} else { } else {
target_str = "localhost:5500"; target_str = "localhost:5500";
type = "add"; type = "add";
......
...@@ -81,7 +81,7 @@ function checkopts() ...@@ -81,7 +81,7 @@ function checkopts()
checkopts "$@" checkopts "$@"
# switch to project root path, which contains clang-format config file '.clang-format' # switch to project root path, which contains clang-format config file '.clang-format'
cd "${SCRIPTS_PATH}/.." || exit 1 cd "${SCRIPTS_PATH}/../.." || exit 1
FMT_FILE_LIST='__format_files_list__' FMT_FILE_LIST='__format_files_list__'
......
...@@ -161,6 +161,7 @@ setup( ...@@ -161,6 +161,7 @@ setup(
description='MindSpore is a new open source deep learning training/inference ' description='MindSpore is a new open source deep learning training/inference '
'framework that could be used for mobile, edge and cloud scenarios.', 'framework that could be used for mobile, edge and cloud scenarios.',
long_description="\n\n".join([readme, release]), long_description="\n\n".join([readme, release]),
long_description_content_type="text/markdown",
packages=find_packages(), packages=find_packages(),
package_data=package_data, package_data=package_data,
include_package_data=True, include_package_data=True,
......
...@@ -190,9 +190,9 @@ TEST_F(MindDataTestBPlusTree, Test3) { ...@@ -190,9 +190,9 @@ TEST_F(MindDataTestBPlusTree, Test3) {
EXPECT_TRUE(rc.IsOk()); EXPECT_TRUE(rc.IsOk());
uint64_t min = ai.min_key(); uint64_t min = ai.min_key();
uint64_t max = ai.max_key(); uint64_t max = ai.max_key();
EXPECT_EQ(min, 1); EXPECT_EQ(min, 0);
EXPECT_EQ(max, 4); EXPECT_EQ(max, 3);
auto r = ai.Search(3); auto r = ai.Search(2);
auto &it = r.first; auto &it = r.first;
EXPECT_EQ(it.value(), "b"); EXPECT_EQ(it.value(), "b");
MS_LOG(INFO) << "Dump all the values using [] operator."; MS_LOG(INFO) << "Dump all the values using [] operator.";
......
...@@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common { ...@@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common {
}; };
void SetUp() { void SetUp() {
elim_Z = MakeSubstitution(irpass::AddByZero(), "elim_Z", prim::kPrimScalarAdd); elim_Z = MakeSubstitution(std::make_shared<irpass::AddByZero>(), "elim_Z", prim::kPrimScalarAdd);
elim_R = MakeSubstitution(irpass::PrimEliminater(R), "elim_R", R); elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
idempotent_P = MakeSubstitution(IdempotentEliminater(), "idempotent_P", P); idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
Qct_to_P = MakeSubstitution(QctToP(), "Qct_to_P", Q); Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);
} }
bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) { bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) {
......
...@@ -327,6 +327,9 @@ TEST_F(TestStepParallel, CreatOpInstance) { ...@@ -327,6 +327,9 @@ TEST_F(TestStepParallel, CreatOpInstance) {
} else if (name == "instance_name") { } else if (name == "instance_name") {
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret); parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
ASSERT_EQ(converted_ret->ToString(), "test"); ASSERT_EQ(converted_ret->ToString(), "test");
} else if (name == "index") {
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
ASSERT_EQ(converted_ret->ToString(), "0");
} else { } else {
MS_LOG(EXCEPTION) << "Test failed"; MS_LOG(EXCEPTION) << "Test failed";
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
"numParallelWorkers": 4, "numParallelWorkers": 4,
"workerConnectorSize": 16, "workerConnectorSize": 16,
"opConnectorSize": 16, "opConnectorSize": 16,
"seed": 5489 "seed": 5489,
"monitor_sampling_interval": 15
} }
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册