Commit 5a886794 authored by Y yanghaitao

Merge branch 'master' of gitee.com:mindspore/mindspore

akg @ df57a6cf
Subproject commit c460176523d039c8995f1d71089753725ebc0792
Subproject commit df57a6cf9450e347d1854687d1fe66a420ee3b35
......@@ -277,10 +277,11 @@ endif ()
if (USE_GLOG)
target_link_libraries(inference PRIVATE mindspore::glog)
else()
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
target_link_options(inference PRIVATE -Wl,-init,mindspore_log_init)
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
endif ()
endif()
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
target_link_options(inference PRIVATE -Wl,-init,common_log_init)
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
endif ()
......@@ -30,8 +30,7 @@ Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow
BOUNDING_BOX_CHECK(input);
CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of input is abnormal");
(*output).push_back(nullptr); // init memory for return vector
(*output).push_back(nullptr);
output->resize(2);
(*output)[1] = std::move(input[1]); // move boxes over to output
size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor
......
......@@ -36,8 +36,7 @@ Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output)
int32_t padded_image_h;
int32_t padded_image_w;
(*output).push_back(nullptr);
(*output).push_back(nullptr);
output->resize(2);
(*output)[1] = std::move(input[1]); // since some boxes may be removed
bool crop_further = true; // Whether further cropping will be required or not, true unless required size matches
......
......@@ -45,8 +45,7 @@ Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow *
RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y));
}
(*output).push_back(nullptr);
(*output).push_back(nullptr);
output->resize(2);
(*output)[1] = std::move(input[1]);
return VerticalFlip(input[0], &(*output)[0]);
......
......@@ -2,6 +2,8 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(utils OBJECT
arena.cc
buddy.cc
cache_pool.cc
circular_pool.cc
memory_pool.cc
cond_var.cc
......@@ -11,7 +13,11 @@ add_library(utils OBJECT
service.cc
services.cc
lock.cc
semaphore.cc
status.cc
storage_container.cc
storage_manager.cc
slice.cc
path.cc
wait_post.cc
sig_handler.cc)
......@@ -17,8 +17,10 @@
#define DATASET_UTIL_ALLOCATOR_H_
#include <cstdlib>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include "dataset/util/memory_pool.h"
namespace mindspore {
......@@ -84,6 +86,91 @@ class Allocator {
private:
std::shared_ptr<MemoryPool> pool_;
};
/// \brief A wrapper around unique_ptr with a custom allocator. Like std::lock_guard, the memory is released when
/// the object goes out of scope.
/// \tparam T The type of object to be allocated
/// \tparam C Allocator. Defaults to std::allocator
template <typename T, typename C = std::allocator<T>>
class MemGuard {
public:
using allocator = C;
MemGuard() : n_(0) {}
explicit MemGuard(allocator a) : n_(0), alloc_(a) {}
// There is no copy constructor or assignment operator because the memory is solely owned by this object.
MemGuard(const MemGuard &) = delete;
MemGuard &operator=(const MemGuard &) = delete;
// On the other hand, we can support a move constructor.
MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {}
MemGuard &operator=(MemGuard &&lhs) noexcept {
if (this != &lhs) {
this->deallocate();
n_ = lhs.n_;
alloc_ = std::move(lhs.alloc_);
ptr_ = std::move(lhs.ptr_);
}
return *this;
}
/// \brief Explicitly deallocate the memory if allocated
void deallocate() {
if (ptr_) {
auto *p = ptr_.release();
if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) {
for (auto i = 0; i < n_; ++i) {
p[i].~T();
}
}
alloc_.deallocate(p, n_);
n_ = 0;
}
}
/// \brief Allocate memory (with emplace semantics). Any previously allocated memory is released first. If n is 0,
/// no new memory is allocated.
/// \param n Number of objects of type T to be allocated
/// \tparam Args Extra arguments passed to the constructor of T
template <typename... Args>
Status allocate(size_t n, Args &&... args) noexcept {
try {
deallocate();
if (n > 0) {
T *data = alloc_.allocate(n);
if (!std::is_arithmetic<T>::value) {
for (auto i = 0; i < n; i++) {
std::allocator_traits<C>::construct(alloc_, &(data[i]), std::forward<Args>(args)...);
}
}
ptr_ = std::unique_ptr<T[]>(data);
n_ = n;
}
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
} catch (std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
return Status::OK();
}
~MemGuard() noexcept { deallocate(); }
/// \brief Getter function
/// \return The pointer to the memory allocated
T *GetPointer() const { return ptr_.get(); }
/// \brief Getter function
/// \return The pointer to the memory allocated
T *GetMutablePointer() { return ptr_.get(); }
/// \brief Overload [] operator to access a particular element
/// \param x Index of the element. Must be less than the number of elements allocated.
/// \return pointer to the x-th element
T *operator[](size_t x) { return GetMutablePointer() + x; }
/// \brief Overload [] operator to access a particular element
/// \param x Index of the element. Must be less than the number of elements allocated.
/// \return pointer to the x-th element
T *operator[](size_t x) const { return GetPointer() + x; }
/// \brief Return how many bytes are allocated in total
/// \return Number of bytes allocated in total
size_t GetSizeInBytes() const { return n_ * sizeof(T); }
private:
allocator alloc_;
std::unique_ptr<T[], std::function<void(T *)>> ptr_;
size_t n_;
};
} // namespace dataset
} // namespace mindspore
......
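A minimal usage sketch of the MemGuard class above (the element type, count, and values are illustrative):

// Allocate 128 ints guarded by RAII; the memory is released when buf goes out of scope.
MemGuard<int> buf;
Status rc = buf.allocate(128);
if (rc.IsOk()) {
  *(buf[0]) = 42;                        // operator[] returns a pointer to the x-th element
  size_t nbytes = buf.GetSizeInBytes();  // 128 * sizeof(int)
}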
......@@ -91,7 +91,7 @@ class AutoIndexObj : public BPlusTree<int64_t, T, A> {
}
private:
static constexpr key_type kMinKey = 1;
static constexpr key_type kMinKey = 0;
std::atomic<key_type> inx_;
};
} // namespace dataset
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/buddy.h"
#include <iomanip>
#include <stdexcept>
#include "dataset/util/de_error.h"
#include "dataset/util/memory_pool.h"
#include "dataset/util/system_pool.h"
#include "./securec.h"
inline uint64_t BitLeftShift(uint64_t v, uint64_t n) { return (v << n); }
inline uint64_t BitRightShift(uint64_t v, uint64_t n) { return (v >> n); }
inline uint64_t BitOr(uint64_t rhs, uint64_t lhs) { return rhs | lhs; }
inline uint64_t BitEx(uint64_t rhs, uint64_t lhs) { return rhs ^ lhs; }
inline uint64_t BitAnd(uint64_t rhs, uint64_t lhs) { return rhs & lhs; }
namespace mindspore {
namespace dataset {
Status BuddySpace::Init() {
if (log_min_ < 0) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"log_min must be positive : " + std::to_string(log_min_));
}
if (num_lvl_ < 3 || num_lvl_ > 18) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"num_lvl must be between 3 and 18 : " + std::to_string(num_lvl_));
}
min_ = BitLeftShift(1, log_min_);
max_ = BitLeftShift(1, log_min_ + num_lvl_ - 1);
size_t offset_1 = sizeof(rel_addr_t) * num_lvl_;
size_t offset_2 = sizeof(int) * num_lvl_ + offset_1;
size_t offset_3 = sizeof(char) * BitLeftShift(1, num_lvl_ - 3) + offset_2;
RETURN_IF_NOT_OK(DeMalloc(offset_3, &ptr_, true));
hint_ = reinterpret_cast<rel_addr_t *>(ptr_);
count_ = reinterpret_cast<int *>((reinterpret_cast<char *>(ptr_) + offset_1));
map_ = reinterpret_cast<char *>(ptr_) + offset_2;
count_[num_lvl_ - 1] = 1;
map_[0] = BitOr(MORE_BIT, num_lvl_ - 3);
return Status::OK();
}
Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) noexcept {
std::lock_guard<std::mutex> lock(mutex_);
addr_t addr = AllocNoLock(sz, desc);
if (addr != NOSPACE) {
*p = addr;
return Status::OK();
} else {
return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error. Please ignore.");
}
}
addr_t BuddySpace::AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept {
DS_ASSERT(sz <= max_);
uint32_t reqSize = SizeToBlock(sz);
rel_addr_t rel_addr = AllocBuddySeg(reqSize);
if (rel_addr != static_cast<rel_addr_t>(NOSPACE)) {
(void)memset_s(desc, sizeof(BSpaceDescriptor), 0, sizeof(BSpaceDescriptor));
desc->sig = static_cast<int>(0xDEADBEEF);
desc->addr = rel_addr;
desc->req_size = reqSize;
desc->blk_size = NextPowerOf2(reqSize);
return static_cast<addr_t>(rel_addr * min_);
} else {
return NOSPACE;
}
}
void BuddySpace::FreeNoLock(const BSpaceDescriptor *desc) {
DS_ASSERT(desc->sig == 0XDEADBEEF);
rel_addr_t rel_addr = desc->addr;
size_t blk_size = desc->blk_size;
size_t req_size = desc->req_size;
FreeBuddySeg(rel_addr, blk_size, req_size);
}
void BuddySpace::Free(const BSpaceDescriptor *desc) {
std::lock_guard<std::mutex> lock(mutex_);
return FreeNoLock(desc);
}
std::ostream &operator<<(std::ostream &os, const BuddySpace &s) {
os << "1 unit = " << s.GetMinSize() << "\n"
<< "Size of buddy space = " << s.GetMaxSize() << "\n"
<< "Number of levels = " << s.num_lvl_ << "\n\n"
<< "Percent free = " << s.PercentFree() << "\n"
<< "Dumping count array : "
<< "\n";
for (int i = 0; i < s.num_lvl_; i++) {
os << "[" << i << "] = " << s.count_[i] << " ";
if (((i + 1) % 4) == 0) {
os << "\n";
}
}
os << "\n";
os << "Dumping allocation info:"
<< "\n";
auto max_addr = static_cast<rel_addr_t>(BitLeftShift(1, s.num_lvl_ - 1));
rel_addr_t addr = 0;
while (addr < max_addr) {
size_t sz = 0;
BuddySpace::STATE st;
s.GetBuddySegState(addr, &sz, &st);
os << "Address : " << std::left << std::setw(8) << addr << " Size : " << std::setw(8) << sz << " State : "
<< ((st == BuddySpace::STATE::kAlloc) ? "ALLOC" : ((st == BuddySpace::STATE::kFree) ? "FREE" : "Unknown"))
<< "\n";
addr += sz;
}
return os;
}
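// Note on the map_ encoding, reconstructed from Get/SetBuddySegState below (an explanatory aid, not
// authoritative documentation): every minimal block owns a 2-bit slot in map_, so one byte covers four
// consecutive relative addresses. A size-1 segment is described by its 2-bit slot (ALLOC_BIT + ONE_BIT),
// a size-2 segment by a 4-bit nibble (TWO_BIT), and a segment of size 4 or more starts on a 4-block
// boundary and uses a whole byte: MORE_BIT plus a 4-bit log field lg, with the size recovered as 1 << (lg + 2).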
void BuddySpace::GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const {
char byte;
int pos;
int offset;
uint64_t val = 0;
int shift;
pos = BitRightShift(rel_addr, 2);
offset = rel_addr % 4;
shift = offset * 2;
byte = map_[pos];
switch (offset) {
case 0:
val = byte;
break;
case 1:
case 3:
if (offset == 1) {
val = BitLeftShift(BitAnd(byte, 0x30), shift);
} else {
val = BitLeftShift(BitAnd(byte, 0x03), shift);
}
break;
case 2:
val = BitLeftShift(BitAnd(byte, 0x0F), shift);
break;
}
if (BitAnd(val, ONE_BIT)) {
*rel_sz = 1;
} else if (BitAnd(val, TWO_BIT)) {
*rel_sz = 2;
} else if (BitAnd(val, MORE_BIT)) {
log_t lg = BitAnd(val, 0x0F);
*rel_sz = BitLeftShift(1, lg + 2);
} else {
*st = STATE::kEmpty;
return;
}
*st = BitAnd(val, ALLOC_BIT) ? STATE::kAlloc : STATE::kFree;
}
void BuddySpace::SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st) {
int clr;
int mask;
int pos;
int offset;
int val = 0;
int shift;
auto log_sz = static_cast<log_t>(Log2(rel_sz));
pos = BitRightShift(rel_addr, 2);
offset = rel_addr % 4;
shift = offset * 2;
if (rel_sz == 1) {
val = ONE_BIT;
mask = 0xC0;
} else if (rel_sz == 2) {
val = TWO_BIT;
mask = 0xF0;
} else {
val = BitOr(log_sz - 2, MORE_BIT);
mask = 0xFF;
}
if (st == STATE::kAlloc) {
val = BitOr(val, ALLOC_BIT);
} else if (st == STATE::kFree) {
val = BitAnd(val, ~(static_cast<uint64_t>(ALLOC_BIT)));
} else if (st == STATE::kEmpty) {
val = 0;
}
clr = static_cast<int>(~(BitRightShift(mask, shift)));
map_[pos] = static_cast<char>(BitAnd(map_[pos], clr));
map_[pos] = static_cast<char>(BitOr(map_[pos], BitRightShift(val, shift)));
if (st == STATE::kAlloc) {
count_[log_sz]--;
} else if (st == STATE::kFree) {
count_[log_sz]++;
if (rel_addr < hint_[log_sz]) {
hint_[log_sz] = rel_addr;
}
}
}
void BuddySpace::JoinBuddySeg(rel_addr_t addr, size_t blk_sz) {
while (blk_sz < BitLeftShift(1, num_lvl_)) {
rel_addr_t buddy = BitEx(addr, blk_sz);
size_t sz = 0;
STATE st;
GetBuddySegState(buddy, &sz, &st);
if (st == STATE::kFree && sz == blk_sz) {
auto log_sz = static_cast<log_t>(Log2(blk_sz));
rel_addr_t left = (buddy < addr) ? buddy : addr;
rel_addr_t right = left + blk_sz;
DS_ASSERT(count_[log_sz] >= 2);
count_[log_sz] -= 2;
SetBuddySegState(right, blk_sz, STATE::kEmpty);
SetBuddySegState(left, BitLeftShift(blk_sz, 1), STATE::kFree);
for (int i = 0; i < log_sz; i++) {
if (hint_[i] == right) {
hint_[i] = left;
}
}
addr = left;
blk_sz <<= 1u;
} else {
break;
}
}
}
void BuddySpace::TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) {
DS_ASSERT(ask_sz < blk_sz);
uint32_t inx = Log2(blk_sz);
size_t remaining_sz = ask_sz;
for (int i = inx; i > 0; i--) {
size_t b_size = BitLeftShift(1, i);
size_t half_sz = BitRightShift(b_size, 1);
count_[i]--;
SetBuddySegState(addr, half_sz, STATE::kFree);
SetBuddySegState(addr + half_sz, half_sz, STATE::kFree);
if (remaining_sz >= half_sz) {
SetBuddySegState(addr, half_sz, STATE::kAlloc);
remaining_sz -= half_sz;
if (remaining_sz == 0) {
break;
}
addr += half_sz;
}
}
}
void BuddySpace::UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) {
DS_ASSERT(ask_sz < blk_sz);
uint32_t inx = Log2(blk_sz);
size_t remaining_sz = ask_sz;
for (int i = inx; i > 0; i--) {
size_t b_size = BitLeftShift(1, i);
size_t half_sz = BitRightShift(b_size, 1);
if (remaining_sz >= half_sz) {
#ifdef DEBUG
{
size_t sz = 0;
STATE st;
GetBuddySegState(addr, &sz, &st);
DS_ASSERT(sz == half_sz && st == STATE::kAlloc);
}
#endif
SetBuddySegState(addr, half_sz, STATE::kFree);
remaining_sz -= half_sz;
if (remaining_sz == 0) {
JoinBuddySeg(addr, half_sz);
break;
}
addr += half_sz;
}
}
}
rel_addr_t BuddySpace::AllocBuddySeg(uint32_t req_size) noexcept {
uint32_t blk_size = NextPowerOf2(req_size);
int start_inx = static_cast<int>(Log2(blk_size));
bool found = false;
rel_addr_t ask_addr = 0;
auto max_addr = static_cast<rel_addr_t>(BitLeftShift(1, num_lvl_ - 1));
STATE st;
size_t sz = 0;
for (int i = start_inx; !found && i < num_lvl_; i++) {
DS_ASSERT(count_[i] >= 0);
if (count_[i] == 0) {
continue;
}
auto blk_sz = static_cast<size_t>(BitLeftShift(1, i));
ask_addr = hint_[i];
while (ask_addr < max_addr && !found) {
GetBuddySegState(ask_addr, &sz, &st);
if (st == STATE::kFree && sz == blk_sz) {
found = true;
} else {
DS_ASSERT(st != STATE::kEmpty);
ask_addr += ((sz > blk_sz) ? sz : blk_sz);
}
}
}
if (found) {
if (sz > req_size) {
TrimBuddySeg(ask_addr, sz, req_size);
} else {
SetBuddySegState(ask_addr, sz, STATE::kAlloc);
hint_[start_inx] = ask_addr;
}
return ask_addr;
} else {
return static_cast<rel_addr_t>(NOSPACE);
}
}
void BuddySpace::FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size) {
if (req_size == blk_size) {
#ifdef DEBUG
{
size_t sz = 0;
STATE st;
GetBuddySegState(addr, &sz, &st);
}
#endif
SetBuddySegState(addr, blk_size, STATE::kFree);
JoinBuddySeg(addr, blk_size);
} else {
UnTrimBuddySeg(addr, blk_size, req_size);
}
}
int BuddySpace::PercentFree() const {
uint64_t total_free_sz = 0;
uint64_t max_sz_in_unit = BitLeftShift(1, num_lvl_ - 1);
// Go through the count array without lock
for (int i = 0; i < num_lvl_; i++) {
int cnt = count_[i];
if (cnt == 0) {
continue;
}
uint64_t blk_sz = BitLeftShift(1, i);
total_free_sz += (blk_sz * cnt);
}
return static_cast<int>(static_cast<float>(total_free_sz) / static_cast<float>(max_sz_in_unit) * 100);
}
BuddySpace::BuddySpace(int log_min, int num_lvl)
: hint_(nullptr),
count_(nullptr),
map_(nullptr),
log_min_(log_min),
num_lvl_(num_lvl),
min_(0),
max_(0),
ptr_(nullptr) {}
BuddySpace::~BuddySpace() {
if (ptr_ != nullptr) {
free(ptr_);
}
hint_ = nullptr;
count_ = nullptr;
map_ = nullptr;
}
Status BuddySpace::CreateBuddySpace(std::unique_ptr<BuddySpace> *out_bs, int log_min, int num_lvl) {
Status rc;
auto bs = new (std::nothrow) BuddySpace(log_min, num_lvl);
if (bs == nullptr) {
return Status(StatusCode::kOutOfMemory);
}
rc = bs->Init();
if (rc.IsOk()) {
(*out_bs).reset(bs);
} else {
delete bs;
}
return rc;
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_BUDDY_H_
#define DATASET_UTIL_BUDDY_H_
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include "dataset/util/status.h"
using addr_t = int64_t;
using rel_addr_t = int32_t;
using log_t = int;
#define ALLOC_BIT 0x80
#define ONE_BIT 0x40
#define TWO_BIT 0x20
#define MORE_BIT 0x10
#define NOSPACE ((addr_t)(-1))
namespace mindspore {
namespace dataset {
struct BSpaceDescriptor {
int32_t sig;
rel_addr_t addr;
size_t req_size;
size_t blk_size;
};
class BuddySpace {
public:
// C++11 feature: the 'class' keyword makes STATE a type-safe scoped enum.
// Don't take out the keyword 'class'.
enum class STATE { kFree, kAlloc, kEmpty };
BuddySpace(const BuddySpace &) = delete;
BuddySpace &operator=(const BuddySpace &) = delete;
virtual ~BuddySpace();
Status Alloc(uint64_t sz, BSpaceDescriptor *desc, addr_t *) noexcept;
void Free(const BSpaceDescriptor *desc);
uint64_t GetMinSize() const { return min_; }
uint64_t GetMaxSize() const { return max_; }
int PercentFree() const;
friend std::ostream &operator<<(std::ostream &os, const BuddySpace &s);
static uint64_t NextPowerOf2(uint64_t n) {
if (n <= 1) {
return 1;
}
n = n - 1;
while (n & (n - 1)) {
n = n & (n - 1);
}
return n << 1;
}
static uint32_t Log2(uint64_t n) {
uint32_t cnt = 0;
while (n >>= 1) {
cnt++;
}
return cnt;
}
static Status CreateBuddySpace(std::unique_ptr<BuddySpace> *out_bs, int log_min = 15, int num_lvl = 18);
private:
rel_addr_t *hint_;
int *count_;
char *map_;
int log_min_;
int num_lvl_;
uint64_t min_;
uint64_t max_;
void *ptr_;
std::mutex mutex_;
explicit BuddySpace(int log_min = 15, int num_lvl = 18);
Status Init();
addr_t AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept;
void FreeNoLock(const BSpaceDescriptor *desc);
uint32_t SizeToBlock(const uint64_t sz) const {
uint32_t reqSize = (sz / min_);
if (sz % min_) {
reqSize++;
}
return reqSize;
}
void GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const;
void SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st);
void JoinBuddySeg(rel_addr_t addr, size_t blk_sz);
void TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz);
void UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz);
rel_addr_t AllocBuddySeg(uint32_t req_size) noexcept;
void FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size);
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_UTIL_BUDDY_H_
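For reference, NextPowerOf2 above repeatedly clears the lowest set bit via n & (n - 1) until a single bit remains, then doubles it (e.g. NextPowerOf2(13) == 16), and Log2 counts right shifts (Log2(13) == 3). A minimal usage sketch of the public interface, inside a function returning Status (the 1 MB request size is illustrative):

std::unique_ptr<BuddySpace> bs;
RETURN_IF_NOT_OK(BuddySpace::CreateBuddySpace(&bs));  // defaults: log_min = 15, num_lvl = 18
BSpaceDescriptor bspd{};                              // filled in by Alloc; identifies the segment to Free
addr_t addr = 0;
RETURN_IF_NOT_OK(bs->Alloc(1 << 20, &bspd, &addr));   // returns kNoSpace when the space is full
bs->Free(&bspd);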
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "common/utils.h"
#include "dataset/util/cache_pool.h"
#include "dataset/util/services.h"
namespace mindspore {
namespace dataset {
CachePool::CachePool(const value_allocator &alloc, const std::string &root)
: alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {}
Status CachePool::DoServiceStart() {
tree_ = std::make_shared<data_index>();
// If we are given a disk path, set up the StorageManager
if (!root_.toString().empty()) {
Path spill = GetSpillPath();
RETURN_IF_NOT_OK(spill.CreateDirectories());
sm_ = std::make_shared<StorageManager>(spill);
RETURN_IF_NOT_OK(sm_->ServiceStart());
MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString());
}
return Status::OK();
}
Status CachePool::DoServiceStop() {
Status rc;
Status rc2;
if (sm_ != nullptr) {
rc = sm_->ServiceStop();
if (rc.IsError()) {
rc2 = rc;
}
}
sm_.reset();
for (auto &bl : *tree_) {
if (bl.ptr != nullptr) {
alloc_.deallocate(bl.ptr, bl.sz);
}
}
tree_.reset();
if (!root_.toString().empty()) {
Path spill = GetSpillPath();
auto it = Path::DirIterator::OpenDirectory(&spill);
while (it->hasNext()) {
rc = it->next().Remove();
if (rc.IsError() && rc2.IsOk()) {
rc2 = rc;
}
}
rc = spill.Remove();
if (rc.IsError() && rc2.IsOk()) {
rc2 = rc;
}
}
return rc2;
}
CachePool::~CachePool() noexcept { (void)ServiceStop(); }
Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_type *key) {
DataLocator bl;
Status rc;
size_t sz = 0;
// We will consolidate all the slices into one piece.
for (auto &v : buf) {
sz += v.GetSize();
}
bl.sz = sz;
try {
bl.ptr = alloc_.allocate(sz);
// We will do a piecewise copy.
WritableSlice dest(bl.ptr, bl.sz);
size_t pos = 0;
for (auto &v : buf) {
WritableSlice out(dest, pos);
rc = WritableSlice::Copy(&out, v);
if (rc.IsError()) {
break;
}
pos += v.GetSize();
}
if (rc.IsError()) {
alloc_.deallocate(bl.ptr, sz);
bl.ptr = nullptr;
return rc;
}
} catch (std::bad_alloc &e) {
if (sm_ != nullptr) {
RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf));
// By the design of AutoIndexObj, 0 is never a valid key, and the code below relies on that assumption.
// Make sure it is not 0.
if (bl.storage_key == 0) {
RETURN_STATUS_UNEXPECTED("Key 0 is returned which is unexpected");
}
} else {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
rc = tree_->insert(bl, key);
if (rc.IsError() && bl.ptr != nullptr) {
alloc_.deallocate(bl.ptr, sz);
}
return rc;
}
Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const {
RETURN_UNEXPECTED_IF_NULL(dest);
auto r = tree_->Search(key);
if (r.second) {
auto &it = r.first;
if (it->ptr != nullptr) {
ReadableSlice src(it->ptr, it->sz);
RETURN_IF_NOT_OK(WritableSlice::Copy(dest, src));
} else if (sm_ != nullptr) {
size_t expectedLength = 0;
RETURN_IF_NOT_OK(sm_->Read(it->storage_key, dest, &expectedLength));
if (expectedLength != it->sz) {
MS_LOG(ERROR) << "Unexpected length. Read " << expectedLength << ". Expected " << it->sz << "."
<< " Internal key: " << key << "\n";
RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
}
}
if (bytesRead != nullptr) {
*bytesRead = it->sz;
}
} else {
RETURN_STATUS_UNEXPECTED("Key not found");
}
return Status::OK();
}
const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; }
Path CachePool::GetSpillPath() const {
auto spill = Path(root_) / subfolder_;
return spill;
}
CachePool::CacheStat CachePool::GetStat() const {
CacheStat cs{0};
for (auto &it : *tree_) {
if (it.ptr != nullptr) {
++cs.num_mem_cached;
} else {
++cs.num_disk_cached;
}
}
return cs;
}
Status CachePool::Spill(CachePool::DataLocator *dl) {
if (sm_ == nullptr) {
RETURN_STATUS_UNEXPECTED("No disk storage to spill");
}
RETURN_UNEXPECTED_IF_NULL(dl);
RETURN_UNEXPECTED_IF_NULL(dl->ptr);
if (dl->storage_key == 0) {
ReadableSlice data(dl->ptr, dl->sz);
RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data}));
}
alloc_.deallocate(dl->ptr, dl->sz);
dl->ptr = nullptr;
return Status::OK();
}
Status CachePool::Locate(CachePool::DataLocator *dl) {
RETURN_UNEXPECTED_IF_NULL(dl);
if (dl->ptr == nullptr) {
if (sm_ == nullptr) {
RETURN_STATUS_UNEXPECTED("No disk storage to locate the data");
}
try {
dl->ptr = alloc_.allocate(dl->sz);
WritableSlice dest(dl->ptr, dl->sz);
Status rc = Read(dl->storage_key, &dest);
if (rc.IsError()) {
alloc_.deallocate(dl->ptr, dl->sz);
dl->ptr = nullptr;
return rc;
}
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
return Status::OK();
}
size_t CachePool::GetSize(CachePool::key_type key) const {
auto r = tree_->Search(key);
if (r.second) {
auto &it = r.first;
return it->sz;
} else {
return 0;
}
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_CACHE_POOL_H_
#define DATASET_UTIL_CACHE_POOL_H_
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "dataset/util/allocator.h"
#include "dataset/util/service.h"
#include "dataset/util/slice.h"
#include "dataset/util/storage_manager.h"
#include "dataset/util/auto_index.h"
namespace mindspore {
namespace dataset {
/// \brief A CachePool provides a service to back up and restore a buffer. A buffer is represented as a vector of
/// ReadableSlice whose memory blocks are all copied into one contiguous block, kept either in memory or spilled
/// to disk (if a disk directory is provided). Every buffer insert returns a generated key which can later be
/// used to restore the buffer.
/// \see ReadableSlice
class CachePool : public Service {
public:
using base_type = uint8_t;
using pointer = base_type *;
using const_pointer = const base_type *;
using reference = base_type &;
using const_reference = const base_type &;
using value_allocator = Allocator<base_type>;
// An internal class to locate the whereabouts of a backed-up buffer, which can be either in memory or on disk.
class DataLocator {
public:
DataLocator() : ptr(nullptr), sz(0), storage_key(0) {}
~DataLocator() = default;
DataLocator(const DataLocator &other) = default;
DataLocator &operator=(const DataLocator &other) = default;
DataLocator(DataLocator &&other) noexcept {
ptr = other.ptr;
sz = other.sz;
storage_key = other.storage_key;
other.ptr = nullptr;
other.sz = 0;
other.storage_key = 0;
}
DataLocator &operator=(DataLocator &&other) noexcept {
if (&other != this) {
ptr = other.ptr;
sz = other.sz;
storage_key = other.storage_key;
other.ptr = nullptr;
other.sz = 0;
other.storage_key = 0;
}
return *this;
}
pointer ptr;
size_t sz;
StorageManager::key_type storage_key;
};
using data_index = AutoIndexObj<DataLocator>;
using key_type = data_index::key_type;
using bl_alloc_type = typename value_allocator::template rebind<DataLocator>::other;
/// \brief Simple statistics returned from CachePool like how many elements are cached in memory and
/// how many elements are spilled to disk.
struct CacheStat {
int64_t num_mem_cached;
int64_t num_disk_cached;
};
/// \brief Constructor
/// \param alloc Allocator to allocate memory from
/// \param root Optional disk folder to spill
explicit CachePool(const value_allocator &alloc, const std::string &root = "");
CachePool(const CachePool &) = delete;
CachePool(CachePool &&) = delete;
CachePool &operator=(const CachePool &) = delete;
CachePool &operator=(CachePool &&) = delete;
~CachePool() noexcept;
Status DoServiceStart() override;
Status DoServiceStop() override;
Path GetSpillPath() const;
/// \brief Insert a sequence of ReadableSlice objects into the pool.
/// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk.
/// \param[in] buf A sequence of ReadableSlice objects.
/// \param[out] key Generated key
/// \return Error code
Status Insert(const std::vector<ReadableSlice> &buf, key_type *key);
/// \brief Restore a cached buffer (from memory or disk)
/// \param[in] key A previous key returned from Insert
/// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice
/// \param[out] bytesRead Optional. Number of bytes read.
/// \return Error code
Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const;
Status Spill(DataLocator *dl);
Status Locate(DataLocator *dl);
size_t GetSize(key_type key) const;
/// \brief Get statistics.
/// \return CacheStat object
CacheStat GetStat() const;
const value_allocator &get_allocator() const;
std::string MyName() const { return subfolder_; }
private:
value_allocator alloc_;
Path root_;
const std::string subfolder_;
std::shared_ptr<StorageManager> sm_;
std::shared_ptr<data_index> tree_;
};
} // namespace dataset
} // namespace mindspore
#endif  // DATASET_UTIL_CACHE_POOL_H_
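A minimal usage sketch of the interface above, inside a function returning Status and assuming ServiceStart/ServiceStop from the Service base class (the spill directory and payload are illustrative):

CachePool pool(SystemPool::GetAllocator<uint8_t>(), "/tmp/spill");
RETURN_IF_NOT_OK(pool.ServiceStart());
std::vector<uint8_t> payload(1024, 0xAB);
CachePool::key_type key;
RETURN_IF_NOT_OK(pool.Insert({ReadableSlice(payload.data(), payload.size())}, &key));
std::vector<uint8_t> out(pool.GetSize(key));
WritableSlice dest(out.data(), out.size());
RETURN_IF_NOT_OK(pool.Read(key, &dest));  // restores from memory, or from disk if spilled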
......@@ -106,6 +106,24 @@ struct List {
++count;
}
// Insert elem2 before elem1 in the list.
virtual void InsertBefore(pointer elem1, pointer elem2) {
DS_ASSERT(elem1 != elem2);
Node<T> &elem1_node = elem1->*node;
Node<T> &elem2_node = elem2->*node;
elem2_node.next = elem1;
elem2_node.prev = elem1_node.prev;
if (elem1_node.prev != nullptr) {
Node<T> &prev_node = elem1_node.prev->*node;
prev_node.next = elem2;
}
elem1_node.prev = elem2;
if (head == elem1) {
head = elem2;
}
++count;
}
// Remove an element in the list
virtual void Remove(pointer elem) noexcept {
Node<T> &elem_node = elem->*node;
......
......@@ -44,20 +44,6 @@ class MemoryPool {
virtual ~MemoryPool() {}
};
// Used by unique_ptr
template <typename T>
class Deleter {
public:
explicit Deleter(std::shared_ptr<MemoryPool> &mp) : mp_(mp) {}
~Deleter() = default;
void operator()(T *ptr) const { mp_->Deallocate(ptr); }
private:
std::shared_ptr<MemoryPool> mp_;
};
Status DeMalloc(std::size_t s, void **p, bool);
} // namespace dataset
} // namespace mindspore
......
......@@ -16,6 +16,8 @@
#include "dataset/util/path.h"
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <new>
#include <sstream>
#include <utility>
......@@ -26,7 +28,7 @@
namespace mindspore {
namespace dataset {
#ifdef _WIN32
#if defined(_WIN32) || defined(_WIN64)
char Path::separator_ = '\\';
#else
char Path::separator_ = '/';
......@@ -132,7 +134,7 @@ Status Path::CreateDirectory() {
#if defined(_WIN32) || defined(_WIN64)
int rc = mkdir(common::SafeCStr(path_));
#else
int rc = mkdir(common::SafeCStr(path_), 0700);
int rc = mkdir(common::SafeCStr(path_), S_IRUSR | S_IWUSR | S_IXUSR);
#endif
if (rc) {
std::ostringstream oss;
......@@ -182,6 +184,111 @@ Status Path::CreateDirectories() {
return Status::OK();
}
Status Path::Remove() {
if (Exists()) {
if (IsDirectory()) {
errno_t err = rmdir(common::SafeCStr(path_));
if (err == -1) {
std::ostringstream oss;
oss << "Unable to delete directory " << path_ << ". Errno = " << errno;
RETURN_STATUS_UNEXPECTED(oss.str());
}
} else {
errno_t err = unlink(common::SafeCStr(path_));
if (err == -1) {
std::ostringstream oss;
oss << "Unable to delete file " << path_ << ". Errno = " << errno;
RETURN_STATUS_UNEXPECTED(oss.str());
}
}
}
return Status::OK();
}
Status Path::CreateFile(int *file_descriptor) { return OpenFile(file_descriptor, true); }
Status Path::OpenFile(int *file_descriptor, bool create) {
int fd;
if (file_descriptor == nullptr) {
RETURN_STATUS_UNEXPECTED("null pointer");
}
if (IsDirectory()) {
std::ostringstream oss;
oss << "Unable to create file " << path_ << " which is a directory.";
RETURN_STATUS_UNEXPECTED(oss.str());
}
// Convert to canonical form.
if (strlen(common::SafeCStr(path_)) > PATH_MAX) {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
char canonical_path[PATH_MAX + 1] = {0x00};
#if defined(_WIN32) || defined(_WIN64)
if (_fullpath(canonical_path, common::SafeCStr(path_), PATH_MAX) == nullptr) {
#else
if (realpath(common::SafeCStr(path_), canonical_path) == nullptr) {
#endif
if (errno == ENOENT && create) {
// File doesn't exist and we are to create it. Let's break it down.
auto file_part = Basename();
auto parent_part = ParentPath();
#if defined(_WIN32) || defined(_WIN64)
if (_fullpath(canonical_path, common::SafeCStr(parent_part), PATH_MAX) == nullptr) {
#else
if (realpath(common::SafeCStr(parent_part), canonical_path) == nullptr) {
#endif
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
auto cur_inx = strlen(canonical_path);
if ((cur_inx + file_part.length() + 1) > PATH_MAX) {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
canonical_path[cur_inx++] = separator_;
if (strncpy_s(canonical_path + cur_inx, PATH_MAX - cur_inx, common::SafeCStr(file_part), file_part.length()) !=
EOK) {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
} else {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
}
if (create) {
fd = open(canonical_path, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP);
} else {
fd = open(canonical_path, O_RDWR);
}
if (fd == -1) {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
*file_descriptor = fd;
return Status::OK();
}
Status Path::CloseFile(int fd) const {
if (close(fd) < 0) {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
return Status::OK();
}
Status Path::TruncateFile(int fd) const {
int rc;
rc = ftruncate(fd, 0);
if (rc == 0) {
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
}
std::string Path::Basename() {
std::size_t found = path_.find_last_of(separator_);
if (found != std::string::npos) {
return path_.substr(found + 1);
} else {
return path_;
}
}
std::shared_ptr<Path::DirIterator> Path::DirIterator::OpenDirectory(Path *f) {
auto it = new (std::nothrow) DirIterator(f);
......@@ -208,7 +315,7 @@ Path::DirIterator::~DirIterator() {
Path::DirIterator::DirIterator(Path *f) : dir_(f), dp_(nullptr), entry_(nullptr) {
MS_LOG(DEBUG) << "Open directory " << f->toString() << ".";
dp_ = opendir(common::SafeCStr(f->toString()));
dp_ = opendir(f->toString().c_str());
}
bool Path::DirIterator::hasNext() {
......@@ -225,5 +332,10 @@ bool Path::DirIterator::hasNext() {
}
Path Path::DirIterator::next() { return (*(this->dir_) / Path(entry_->d_name)); }
std::ostream &operator<<(std::ostream &os, const Path &s) {
os << s.path_;
return os;
}
} // namespace dataset
} // namespace mindspore
......@@ -90,6 +90,20 @@ class Path {
std::string ParentPath();
Status Remove();
Status CreateFile(int *fd);
Status OpenFile(int *fd, bool create = false);
Status CloseFile(int fd) const;
Status TruncateFile(int fd) const;
std::string Basename();
friend std::ostream &operator<<(std::ostream &os, const Path &s);
private:
static char separator_;
std::string path_;
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/semaphore.h"
#include "dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
Status Semaphore::P() {
std::unique_lock<std::mutex> lck(mutex_);
RETURN_IF_NOT_OK(wait_cond_.Wait(&lck, [this]() { return value_ > 0; }));
--value_;
return Status::OK();
}
void Semaphore::V() {
std::unique_lock<std::mutex> lck(mutex_);
++value_;
wait_cond_.NotifyOne();
}
int Semaphore::Peek() {
std::unique_lock<std::mutex> lck(mutex_);
return value_;
}
Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); }
Status Semaphore::Deregister() { return (wait_cond_.Deregister()); }
void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); }
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_SEMAPHORE_H_
#define DATASET_UTIL_SEMAPHORE_H_
#include "dataset/util/cond_var.h"
namespace mindspore {
namespace dataset {
class TaskGroup;
/// \brief A counting semaphore. There are two external functions P and V. P decrements the internal count and will be
/// blocked if the count is 0 (zero). V increments the internal count and wakes up one of the waiters.
class Semaphore {
public:
/// \brief Constructor
/// \param init Initial value of the internal counter.
explicit Semaphore(int init) : value_(init) {}
virtual ~Semaphore() {}
/// \brief Decrement the internal counter. Will be blocked if the value is 0.
/// \return Error code. Can be interrupted.
Status P();
/// \brief Increment the internal counter. Wake up one of the waiters, if any.
void V();
/// \brief Peek the internal value
/// \return The internal value
int Peek();
Status Register(TaskGroup *vg);
Status Deregister();
void ResetIntrpState();
private:
int value_;
std::mutex mutex_;
CondVar wait_cond_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_UTIL_SEMAPHORE_H_
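A minimal sketch of the P/V protocol, with an initial count of 0 (illustrative):

Semaphore sem(0);                 // nothing available initially
// Consumer side, inside a function returning Status:
//   RETURN_IF_NOT_OK(sem.P());   // blocks until the count is > 0, then decrements
// Producer side:
//   sem.V();                     // increments the count and wakes one waiter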
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/slice.h"
namespace mindspore {
namespace dataset {
WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset, size_t len) : ReadableSlice(src, offset, len) {
mutable_data_ = static_cast<char *>(src.mutable_data_) + offset;
}
WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset)
: WritableSlice(src, offset, src.GetSize() - offset) {}
Status WritableSlice::Copy(WritableSlice *dest, const ReadableSlice &src) {
RETURN_UNEXPECTED_IF_NULL(dest);
RETURN_UNEXPECTED_IF_NULL(dest->GetMutablePointer());
if (dest->GetSize() <= 0) {
RETURN_STATUS_UNEXPECTED("Destination length is non-positive");
}
auto err = memcpy_s(dest->GetMutablePointer(), dest->GetSize(), src.GetPointer(), src.GetSize());
if (err) {
RETURN_STATUS_UNEXPECTED(std::to_string(err));
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_SLICE_H_
#define DATASET_UTIL_SLICE_H_
#include <unistd.h>
#include <cstddef>
#include <utility>
#include "./securec.h"
#include "dataset/util/allocator.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
/// \brief A ReadableSlice wraps a const pointer in memory and its size.
/// \see WritableSlice for a non-const version
///
class ReadableSlice {
public:
ReadableSlice() : ptr_(nullptr), sz_(0) {}
ReadableSlice(const void *ptr, size_t sz) : ptr_(ptr), sz_(sz) {}
ReadableSlice(const ReadableSlice &src, off64_t offset, size_t len) {
ptr_ = static_cast<const char *>(src.GetPointer()) + offset;
sz_ = len;
}
ReadableSlice(const ReadableSlice &src, off64_t offset) : ReadableSlice(src, offset, src.sz_ - offset) {}
ReadableSlice(const ReadableSlice &lhs) {
ptr_ = lhs.ptr_;
sz_ = lhs.sz_;
}
ReadableSlice &operator=(const ReadableSlice &lhs) {
if (this != &lhs) {
ptr_ = lhs.ptr_;
sz_ = lhs.sz_;
}
return *this;
}
ReadableSlice(ReadableSlice &&lhs) noexcept {
if (this != &lhs) {
ptr_ = lhs.ptr_;
sz_ = lhs.sz_;
lhs.ptr_ = nullptr;
lhs.sz_ = 0;
}
}
ReadableSlice &operator=(ReadableSlice &&lhs) noexcept {
if (this != &lhs) {
ptr_ = lhs.ptr_;
sz_ = lhs.sz_;
lhs.ptr_ = nullptr;
lhs.sz_ = 0;
}
return *this;
}
/// \brief Getter function
/// \return Const version of the pointer
const void *GetPointer() const { return ptr_; }
/// \brief Getter function
/// \return Size of the slice
size_t GetSize() const { return sz_; }
bool empty() const { return ptr_ == nullptr; }
private:
const void *ptr_;
size_t sz_;
};
/// \brief A WritableSlice inherits from ReadableSlice to allow
/// one to write to the address pointed to by the pointer.
///
class WritableSlice : public ReadableSlice {
public:
friend class StorageContainer;
/// \brief Default constructor
WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {}
/// \brief This form of a constructor takes a pointer and its size.
WritableSlice(void *ptr, size_t sz) : ReadableSlice(ptr, sz), mutable_data_(ptr) {}
WritableSlice(const WritableSlice &src, off64_t offset, size_t len);
WritableSlice(const WritableSlice &src, off64_t offset);
WritableSlice(const WritableSlice &lhs) : ReadableSlice(lhs) { mutable_data_ = lhs.mutable_data_; }
WritableSlice &operator=(const WritableSlice &lhs) {
if (this != &lhs) {
mutable_data_ = lhs.mutable_data_;
ReadableSlice::operator=(lhs);
}
return *this;
}
WritableSlice(WritableSlice &&lhs) noexcept : ReadableSlice(std::move(lhs)) {
if (this != &lhs) {
mutable_data_ = lhs.mutable_data_;
lhs.mutable_data_ = nullptr;
}
}
WritableSlice &operator=(WritableSlice &&lhs) noexcept {
if (this != &lhs) {
mutable_data_ = lhs.mutable_data_;
lhs.mutable_data_ = nullptr;
ReadableSlice::operator=(std::move(lhs));
}
return *this;
}
/// \brief Copy the content from one slice onto another.
static Status Copy(WritableSlice *dest, const ReadableSlice &src);
private:
void *mutable_data_;
void *GetMutablePointer() { return mutable_data_; }
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_UTIL_SLICE_H_
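A minimal sketch of WritableSlice::Copy (buffer contents and sizes are illustrative):

char src_buf[16] = "hello";
char dst_buf[16] = {0};
ReadableSlice src(src_buf, sizeof(src_buf));
WritableSlice dest(dst_buf, sizeof(dst_buf));
Status rc = WritableSlice::Copy(&dest, src);  // memcpy_s underneath; fails if dest is smaller than src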
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/storage_container.h"
#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>
#include <vector>
#include "common/utils.h"
#include "dataset/util/de_error.h"
#include "dataset/util/path.h"
#include "dataset/util/status.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
Status StorageContainer::Create() {
RETURN_IF_NOT_OK(BuddySpace::CreateBuddySpace(&bs_));
RETURN_IF_NOT_OK(cont_.CreateFile(&fd_));
is_open_ = true;
MS_LOG(INFO) << "Container " << cont_ << " created";
return Status::OK();
}
Status StorageContainer::Open() noexcept {
std::lock_guard<std::mutex> lck(mutex_);
// Check again
if (!is_open_) {
RETURN_IF_NOT_OK(cont_.OpenFile(&fd_));
is_open_ = true;
}
return Status::OK();
}
Status StorageContainer::Close() noexcept {
if (is_open_) {
std::lock_guard<std::mutex> lck(mutex_);
// Check again
if (is_open_) {
RETURN_IF_NOT_OK(cont_.CloseFile(fd_));
is_open_ = false;
fd_ = -1;
}
}
return Status::OK();
}
Status StorageContainer::Read(WritableSlice *dest, off64_t offset) const noexcept {
DS_ASSERT(is_open_);
RETURN_UNEXPECTED_IF_NULL(dest);
auto sz = dest->GetSize();
#if defined(_WIN32) || defined(_WIN64)
// There doesn't seem to be a pread64 on MinGW,
// so we do a seek and then a read under
// the protection of a mutex.
std::lock_guard<std::mutex> lck(mutex_);
auto seek_err = lseek(fd_, offset, SEEK_SET);
if (seek_err < 0) {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
auto r_sz = read(fd_, dest->GetMutablePointer(), sz);
#else
auto r_sz = pread64(fd_, dest->GetMutablePointer(), sz, offset);
#endif
if (r_sz != sz) {
errno_t err = (r_sz == 0) ? EOF : errno;
RETURN_STATUS_UNEXPECTED(strerror(err));
}
return Status::OK();
}
Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const noexcept {
DS_ASSERT(is_open_);
auto sz = dest.GetSize();
#if defined(_WIN32) || defined(_WIN64)
// There doesn't seem to be a pwrite64 on MinGW,
// so we do a seek and then a write under
// the protection of a mutex.
std::lock_guard<std::mutex> lck(mutex_);
auto seek_err = lseek(fd_, offset, SEEK_SET);
if (seek_err < 0) {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
auto r_sz = write(fd_, dest.GetPointer(), sz);
#else
auto r_sz = pwrite64(fd_, dest.GetPointer(), sz, offset);
#endif
if (r_sz != sz) {
errno_t err = (r_sz == 0) ? EOF : errno;
RETURN_STATUS_UNEXPECTED(strerror(err));
}
return Status::OK();
}
Status StorageContainer::Insert(const std::vector<ReadableSlice> &buf, off64_t *offset) noexcept {
size_t sz = 0;
for (auto &v : buf) {
sz += v.GetSize();
}
if (sz == 0) {
RETURN_STATUS_UNEXPECTED("Unexpected 0 length");
}
if (sz > bs_->GetMaxSize()) {
RETURN_STATUS_UNEXPECTED("Request size too big");
}
BSpaceDescriptor bspd{0};
addr_t addr = 0;
RETURN_IF_NOT_OK(bs_->Alloc(sz, &bspd, &addr));
*offset = static_cast<off64_t>(addr);
// We will do piecewise copy of the data to disk.
for (auto &v : buf) {
RETURN_IF_NOT_OK(Write(v, addr));
addr += v.GetSize();
}
return Status::OK();
}
Status StorageContainer::Truncate() const noexcept {
if (is_open_) {
RETURN_IF_NOT_OK(cont_.TruncateFile(fd_));
MS_LOG(INFO) << "Container " << cont_ << " truncated";
}
return Status::OK();
}
StorageContainer::~StorageContainer() noexcept {
(void)Truncate();
(void)Close();
}
std::ostream &operator<<(std::ostream &os, const StorageContainer &s) {
os << "File path : " << s.cont_ << "\n" << *(s.bs_.get());
return os;
}
Status StorageContainer::CreateStorageContainer(std::shared_ptr<StorageContainer> *out_sc, const std::string &path) {
Status rc;
auto sc = new (std::nothrow) StorageContainer(path);
if (sc == nullptr) {
return Status(StatusCode::kOutOfMemory);
}
rc = sc->Create();
if (rc.IsOk()) {
(*out_sc).reset(sc);
} else {
delete sc;
}
return rc;
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_STORAGE_CONTAINER_H_
#define DATASET_UTIL_STORAGE_CONTAINER_H_
#include <limits.h>
#include <unistd.h>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "dataset/util/system_pool.h"
#include "dataset/util/buddy.h"
#include "dataset/util/path.h"
#include "dataset/util/slice.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
class StorageManager;
class StorageContainer {
public:
friend class StorageManager;
~StorageContainer() noexcept;
StorageContainer(const StorageContainer &) = delete;
StorageContainer &operator=(const StorageContainer &) = delete;
friend std::ostream &operator<<(std::ostream &os, const StorageContainer &s);
Status Open() noexcept;
Status Close() noexcept;
Status Insert(const std::vector<ReadableSlice> &buf, off64_t *offset) noexcept;
Status Write(const ReadableSlice &dest, off64_t offset) const noexcept;
Status Read(WritableSlice *dest, off64_t offset) const noexcept;
Status Truncate() const noexcept;
bool IsOpen() const { return is_open_; }
static Status CreateStorageContainer(std::shared_ptr<StorageContainer> *out_sc, const std::string &path);
private:
mutable std::mutex mutex_;
Path cont_;
int fd_;
bool is_open_;
std::unique_ptr<BuddySpace> bs_;
// Use the default values of BuddySpace,
// which can map up to 4G of space.
explicit StorageContainer(const std::string &path) : cont_(path), fd_(-1), is_open_(false), bs_(nullptr) {}
Status Create();
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_UTIL_STORAGE_CONTAINER_H_
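A minimal usage sketch, inside a function returning Status (the container path follows the IMG*.LB naming used by StorageManager but is illustrative here):

std::shared_ptr<StorageContainer> sc;
RETURN_IF_NOT_OK(StorageContainer::CreateStorageContainer(&sc, "/tmp/IMG00000.LB"));
char data[64] = {0};
off64_t offset = 0;
RETURN_IF_NOT_OK(sc->Insert({ReadableSlice(data, sizeof(data))}, &offset));  // offset comes from BuddySpace
WritableSlice dest(data, sizeof(data));
RETURN_IF_NOT_OK(sc->Read(&dest, offset));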
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/storage_manager.h"
#include <iomanip>
#include <sstream>
#include <stdexcept>
#include <utility>
#include "common/utils.h"
#include "dataset/util/path.h"
#include "dataset/util/services.h"
#include "dataset/util//de_error.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
std::string StorageManager::GetBaseName(const std::string &prefix, int32_t file_id) {
std::ostringstream oss;
oss << prefix << std::setfill('0') << std::setw(5) << file_id;
return oss.str();
}
std::string StorageManager::ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix) {
std::string base_name = GetBaseName(prefix, file_id);
return (base_name + "." + suffix);
}
Status StorageManager::AddOneContainer() {
const std::string kPrefix = "IMG";
const std::string kSuffix = "LB";
Path container_name = root_ / ConstructFileName(kPrefix, file_id_, kSuffix);
std::shared_ptr<StorageContainer> sc;
RETURN_IF_NOT_OK(StorageContainer::CreateStorageContainer(&sc, container_name.toString()));
containers_.push_back(sc);
file_id_++;
return Status::OK();
}
Status StorageManager::DoServiceStart() {
containers_.reserve(1000);
if (root_.IsDirectory()) {
RETURN_IF_NOT_OK(AddOneContainer());
} else {
RETURN_STATUS_UNEXPECTED("Not a directory");
}
return Status::OK();
}
Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &buf) {
RETURN_UNEXPECTED_IF_NULL(key);
size_t sz = 0;
for (auto &v : buf) {
sz += v.GetSize();
}
if (sz == 0) {
RETURN_STATUS_UNEXPECTED("Unexpected 0 length");
}
std::shared_ptr<StorageContainer> cont;
key_type out_key;
value_type out_value;
bool create_new_container = false;
do {
SharedLock lock_s(&rw_lock_);
size_t num_containers = containers_.size();
if (create_new_container) {
// Upgrade to exclusive lock.
lock_s.Upgrade();
create_new_container = false;
// Check again if someone has already added a
// new container after we acquired the exclusive lock.
if (containers_.size() == num_containers) {
RETURN_IF_NOT_OK(AddOneContainer());
}
// Refresh how many containers there are.
num_containers = containers_.size();
// Downgrade back to shared lock
lock_s.Downgrade();
}
if (num_containers == 0) {
RETURN_STATUS_UNEXPECTED("num_containers is zero");
}
// Go to the last container to insert.
cont = containers_.at(num_containers - 1);
off64_t offset;
Status rc = cont->Insert(buf, &offset);
if (rc.IsNoSpace()) {
create_new_container = true;
} else if (rc.IsOk()) {
out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz));
RETURN_IF_NOT_OK(index_.insert(out_value, &out_key));
*key = out_key;
break;
} else {
return rc;
}
} while (true);
return Status::OK();
}
Status StorageManager::Read(StorageManager::key_type key, WritableSlice *dest, size_t *bytesRead) const {
RETURN_UNEXPECTED_IF_NULL(dest);
auto r = index_.Search(key);
if (r.second) {
auto &it = r.first;
value_type v = *it;
int container_inx = v.first;
off_t offset = v.second.first;
size_t sz = v.second.second;
if (dest->GetSize() < sz) {
std::string errMsg = "Destination buffer too small. Expect at least " + std::to_string(sz) +
" but length = " + std::to_string(dest->GetSize());
RETURN_STATUS_UNEXPECTED(errMsg);
}
if (bytesRead != nullptr) {
*bytesRead = sz;
}
auto cont = containers_.at(container_inx);
RETURN_IF_NOT_OK(cont->Read(dest, offset));
} else {
RETURN_STATUS_UNEXPECTED("Key not found");
}
return Status::OK();
}
Status StorageManager::DoServiceStop() noexcept {
Status rc;
Status rc1;
for (auto const &p : containers_) {
// The destructor of StorageContainer is only called once the use
// count drops to 0, which is not guaranteed here, so we do it ourselves.
rc = p.get()->Truncate();
if (rc.IsError()) {
rc1 = rc;
}
}
containers_.clear();
file_id_ = 0;
return rc1;
}
StorageManager::StorageManager(const Path &root) : root_(root), file_id_(0), index_() {}
StorageManager::~StorageManager() { (void)StorageManager::DoServiceStop(); }
std::ostream &operator<<(std::ostream &os, const StorageManager &s) {
os << "Dumping all containers ..."
<< "\n";
for (auto const &p : s.containers_) {
os << *(p.get());
}
return os;
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_STORAGE_MANAGER_H_
#define DATASET_UTIL_STORAGE_MANAGER_H_
#include <unistd.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/util/allocator.h"
#include "dataset/util/auto_index.h"
#include "dataset/util/lock.h"
#include "dataset/util/memory_pool.h"
#include "dataset/util/path.h"
#include "dataset/util/service.h"
#include "dataset/util/slice.h"
#include "dataset/util/storage_container.h"
using ListOfContainers = std::vector<std::shared_ptr<mindspore::dataset::StorageContainer>>;
namespace mindspore {
namespace dataset {
class StorageManager : public Service {
public:
using storage_index = AutoIndexObj<std::pair<int, std::pair<off_t, size_t>>>;
using key_type = storage_index::key_type;
using value_type = storage_index::value_type;
explicit StorageManager(const Path &);
~StorageManager() override;
StorageManager(const StorageManager &) = delete;
StorageManager &operator=(const StorageManager &) = delete;
Status Write(key_type *out_key, const std::vector<ReadableSlice> &buf);
Status Read(key_type key, WritableSlice *dest, size_t *bytesRead) const;
Status DoServiceStart() override;
Status DoServiceStop() noexcept override;
friend std::ostream &operator<<(std::ostream &os, const StorageManager &s);
private:
Path root_;
ListOfContainers containers_;
int file_id_;
RWLock rw_lock_;
storage_index index_;
std::string GetBaseName(const std::string &prefix, int32_t file_id);
std::string ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix);
Status AddOneContainer();
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_UTIL_STORAGE_MANAGER_H_
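A minimal usage sketch, inside a function returning Status (the root directory is illustrative and must already exist, since DoServiceStart rejects non-directories):

StorageManager sm(Path("/tmp/spill"));
RETURN_IF_NOT_OK(sm.ServiceStart());
char data[64] = {0};
StorageManager::key_type key;
RETURN_IF_NOT_OK(sm.Write(&key, {ReadableSlice(data, sizeof(data))}));
size_t bytes_read = 0;
WritableSlice dest(data, sizeof(data));
RETURN_IF_NOT_OK(sm.Read(key, &dest, &bytes_read));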
......@@ -19,8 +19,10 @@
#include <cstddef>
#include <cstdlib>
#include <limits>
#include <memory>
#include <new>
#include "./securec.h"
#include "dataset/util/allocator.h"
#include "dataset/util/memory_pool.h"
namespace mindspore {
......@@ -61,6 +63,11 @@ class SystemPool : public MemoryPool {
uint64_t get_max_size() const override { return std::numeric_limits<uint64_t>::max(); }
int PercentFree() const override { return 100; }
template <typename T>
static Allocator<T> GetAllocator() {
return Allocator<T>(std::make_shared<SystemPool>());
}
};
} // namespace dataset
} // namespace mindspore
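The new GetAllocator helper makes SystemPool pluggable into any standard container through the Allocator adaptor. A small sketch of the intended use (illustrative only; the element type is arbitrary):
// A std::vector whose allocations are routed through SystemPool.
auto alloc = SystemPool::GetAllocator<int>();
std::vector<int, Allocator<int>> values(alloc);
values.push_back(42);  // allocate()/deallocate() forward to the pool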
......
......@@ -30,6 +30,7 @@
#include "kernel/common_utils.h"
#include "kernel/oplib/oplib.h"
#include "ir/value.h"
#include "pre_activate/common/helper.h"
using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;
......@@ -632,7 +633,7 @@ void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
}
}
void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const mindspore::AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
AddressPtrList *kernel_outputs) {
MS_EXCEPTION_IF_NULL(kernel);
......@@ -644,9 +645,15 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
}
auto is_all_nop_node = opt::IsAllNopNode(&graph);
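// If all nop nodes in this graph are launched, the producer's address is taken without
// looking through nop nodes (third argument false); otherwise nop nodes are skipped to
// reach the real producer.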
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input);
DeviceAddressPtr device_address;
if (is_all_nop_node) {
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, false);
} else {
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, true);
}
MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr input = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(input);
......@@ -656,8 +663,16 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
kernel_inputs->emplace_back(input);
}
for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
auto device_address = AnfAlgo::GetOutputAddr(kernel, i);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
DeviceAddressPtr device_address;
if (is_all_nop_node) {
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
} else {
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, true);
}
MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr output = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(output);
output->addr = device_address->ptr_;
......@@ -666,7 +681,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
kernel_outputs->emplace_back(output);
}
for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(workspace);
......@@ -721,7 +736,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
GenLaunchArgs(graph, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";
......
......@@ -96,8 +96,8 @@ class KernelRuntime {
private:
void AssignStaticMemoryOutput(const session::KernelGraph *graph);
void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs);
void GenLaunchArgs(const session::KernelGraph &graph, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs);
bool LaunchKernelMod(const session::KernelGraph &graph);
void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs);
size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index);
......
......@@ -17,13 +17,23 @@
#ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#include <memory>
#include "ir/anf.h"
#include "optimizer/opt.h"
namespace mindspore {
namespace opt {
class Optimizer;
using OptimizerPtr = std::shared_ptr<Optimizer>;
using OptimizerWeakPtr = std::weak_ptr<Optimizer>;
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
} // namespace opt
class OptimizerCaller {
public:
virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; }
};
using OptimizerCallerPtr = std::shared_ptr<OptimizerCaller>;
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
......@@ -23,6 +23,7 @@
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
#include "kernel/akg/akg_kernel_metadata.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/context/ms_context.h"
namespace mindspore {
namespace kernel {
......@@ -97,6 +98,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) {
kernel_type = KernelType::AKG_KERNEL;
}
switch (kernel_type) {
case KernelType::AKG_KERNEL:
AkgMetadataInfo(kernel_node, kernel_info_list);
......
......@@ -89,15 +89,28 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
return changed;
}
// Ops like Print or Summary produce no real output and always appear as an input of a Depend node.
static bool HasSideEffect(const AnfNodePtr &node) {
auto prim = GetCNodePrimitive(node);
if (prim == nullptr) {
return false;
}
auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT);
if (side_effect_v != nullptr && side_effect_v->isa<BoolImm>()) {
return GetValue<bool>(side_effect_v);
}
return false;
}
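HasSideEffect consults only the primitive's GRAPH_FLAG_SIDE_EFFECT attribute, so marking an op is enough to shield its CNodes from CSE. A hypothetical illustration (the op name is made up):
// Any CNode built from this primitive will now be rejected for merging by CheckReplace below.
auto prim = std::make_shared<Primitive>("MyPrintLikeOp");
prim->set_attr(GRAPH_FLAG_SIDE_EFFECT, MakeValue(true));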
// If this returns true, do not merge the node.
bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
bool has_random_effect = false;
auto prim_main = GetCNodePrimitive(main);
auto prim_node = GetCNodePrimitive(node);
if (prim_main == prim_node) {
return false;
}
// If the nodes have random effect and were generated by different op objects (not the same object), do not merge.
if (prim_main != nullptr) {
if (prim_main == prim_node) {
return false;
}
auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
has_random_effect = GetValue<bool>(effect_val);
......@@ -106,45 +119,58 @@ bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) cons
return has_random_effect;
}
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const {
MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node);
bool replace = false;
if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
auto main_value = GetValueNode(main);
auto node_value = GetValueNode(node);
replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
} else if (main->isa<CNode>() && node->isa<CNode>()) {
auto c_main = main->cast<CNodePtr>();
auto c_node = node->cast<CNodePtr>();
// Check for side effects up front: nodes with side effects must never be merged.
if (check_side_effect && HasSideEffect(main)) {
return false;
}
const auto &inp1 = c_main->inputs();
const auto &inp2 = c_node->inputs();
if (inp1.size() == inp2.size()) {
bool appsame = true;
for (size_t j = 0; j < inp1.size(); j++) {
MS_EXCEPTION_IF_NULL(inp1[j]);
MS_EXCEPTION_IF_NULL(inp2[j]);
if (!(*inp1[j] == *inp2[j])) {
// Handle the case of two different Tensor, but with the same value
if (IsValueNode<tensor::Tensor>(inp1[j]) && IsValueNode<tensor::Tensor>(inp2[j])) {
auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1[j]);
auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2[j]);
if (tensor1->ValueEqual(*tensor2)) {
continue;
}
if (inp1.size() != inp2.size()) {
return false;
}
for (size_t j = 0; j < inp1.size(); j++) {
auto inp1_j = inp1[j];
auto inp2_j = inp2[j];
MS_EXCEPTION_IF_NULL(inp1_j);
MS_EXCEPTION_IF_NULL(inp2_j);
if (!(*inp1_j == *inp2_j)) {
// Handle the case of two different Tensors that hold the same value
if (IsValueNode<tensor::Tensor>(inp1_j) && IsValueNode<tensor::Tensor>(inp2_j)) {
auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1_j);
auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2_j);
if (tensor1->ValueEqual(*tensor2)) {
continue;
}
} else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) {
// When the same side-effect node appears as an input of both candidate nodes, we can
// still merge them: such a node can only be an input of a Depend, so when the Depend
// itself is duplicated we merge the duplicates.
if (CheckReplace(inp1_j, inp2_j, false)) {
continue;
}
appsame = false;
break;
}
return false;
}
if (CheckRandomEffect(c_main, c_node)) {
appsame = false;
}
replace = appsame;
}
// All inputs match; still refuse to merge if the nodes carry random effect.
if (CheckRandomEffect(c_main, c_node)) {
return false;
}
return true;
}
return replace;
// Anything else (e.g. a Parameter node) is not merged.
return false;
}
bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,
......
......@@ -41,7 +41,7 @@ class CSE {
return chg && report_changes_;
}
virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const;
virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const;
virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const;
......
......@@ -14,140 +14,154 @@
* limitations under the License.
*/
#include "optimizer/irpass.h"
#include <string>
#include "optimizer/irpass/symbol_resolver.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/arithmetic_simplify.h"
#include "optimizer/irpass/special_op_eliminate.h"
#include "optimizer/irpass/item_tuple_eliminate.h"
#include "optimizer/irpass/env_item_eliminate.h"
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/cast_eliminate.h"
#include "optimizer/irpass/reshape_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/irpass/reduce_eliminate.h"
#include "optimizer/irpass/partial_eliminate.h"
#include "optimizer/irpass/ref_eliminate.h"
#include "optimizer/irpass/merge_addn.h"
#include "optimizer/irpass/branch_culling.h"
#include "optimizer/irpass/cast_eliminate.h"
#include "optimizer/irpass/convert.h"
#include "optimizer/irpass/env_item_eliminate.h"
#include "optimizer/irpass/grad_var_prepare.h"
#include "optimizer/irpass/gradient_eliminate.h"
#include "optimizer/irpass/minmax_grad.h"
#include "optimizer/irpass/inline.h"
#include "optimizer/irpass/convert.h"
#include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/incorporate_getitem.h"
#include "optimizer/irpass/incorporate_call.h"
#include "optimizer/irpass/grad_var_prepare.h"
#include "optimizer/irpass/param_replace.h"
#include "optimizer/irpass/incorporate_getitem.h"
#include "optimizer/irpass/item_tuple_eliminate.h"
#include "optimizer/irpass/mark_interface_fusion.h"
#include "optimizer/irpass/merge_addn.h"
#include "optimizer/irpass/minmax_grad.h"
#include "optimizer/irpass/param_replace.h"
#include "optimizer/irpass/partial_eliminate.h"
#include "optimizer/irpass/reduce_eliminate.h"
#include "optimizer/irpass/ref_eliminate.h"
#include "optimizer/irpass/reshape_eliminate.h"
#include "optimizer/irpass/special_op_eliminate.h"
#include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/symbol_resolver.h"
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/opt.h"
namespace mindspore {
namespace opt {
namespace irpass {
OptimizeIRPassLib::OptimizeIRPassLib() {
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify",
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul});
arithmetic_simplify2_ =
MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul});
special_op_eliminate_ =
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike);
adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
zero_like_fill_zero_ =
MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);
adjust_all_reduce_mul_add_ =
MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
// ops eliminate
item_tuple_eliminate_ =
MakeSubstitution(ItemTupleEliminater(), "item_tuple_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem});
tile_eliminate_ = MakeSubstitution(TileMultiplyByOne(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(CastEliminater(), "cast_eliminate", prim::kPrimCast);
reshape_eliminate_ = MakeSubstitution(ReshapeEliminater(), "reshape_eliminate", prim::kPrimReshape);
transpose_eliminate_ = MakeSubstitution(TransposeSameIOEliminater(), "transpose_eliminate", prim::kPrimTranspose);
item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate",
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem});
tile_eliminate_ = MakeSubstitution(std::make_shared<TileMultiplyByOne>(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);
transpose_eliminate_ =
MakeSubstitution(std::make_shared<TransposeSameIOEliminater>(), "transpose_eliminate", prim::kPrimTranspose);
reduce_eliminate_ = MakeSubstitution(
ReduceOneEliminater(), "reduce_eliminate",
std::make_shared<ReduceOneEliminater>(), "reduce_eliminate",
{prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup);
same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape);
check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>);
depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend);
partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
check_bprop_eliminate_ =
MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ =
MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);
// Env Item Eliminate
env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem);
env_get_item_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_ =
MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ =
MakeSubstitution(IncorporateEnvGetitemSwitch(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
"incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
// Ref eliminate
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate",
make_ref_eliminate_ =
MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate",
{prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate",
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
replace_refkey_by_param_ =
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam);
replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
IsValueNode<RefKey>, opt::FORCE_RENORM);
replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
// Gradient transforms
expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ);
minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem);
expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ);
minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);
// branch culling
switch_simplify_ = MakeSubstitution(SwitchSimplify(), "switch_simplify", prim::kPrimSwitch);
float_tuple_getitem_switch_ =
MakeSubstitution(FloatTupleGetItemSwitch(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem);
switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch);
float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(),
"float_tuple_getitem_switch", prim::kPrimTupleGetItem);
float_env_getitem_switch_ =
MakeSubstitution(FloatEnvGetItemSwitch(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
convert_switch_replacement_ = MakeSubstitution(ConvertSwitchReplacement(), "convert_switch_replacement", IsCNodeDup);
MakeSubstitution(std::make_shared<FloatEnvGetItemSwitch>(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
convert_switch_replacement_ =
MakeSubstitution(std::make_shared<ConvertSwitchReplacement>(), "convert_switch_replacement", IsCNodeDup);
// Addn
merge_addn_ = MakeSubstitution(MergeAddN(), "merge_addn", prim::kPrimAddN);
addn_zero_filter_ = MakeSubstitution(AddNZeroFilter(), "addn_zero_filter", prim::kPrimAddN);
merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);
// inline
inline_ = MakeSubstitution(Inliner(), "inline", IsCNodeGraph);
replace_applicator_ = MakeSubstitution(ReplaceApplicator(), "replace_applicator", IsValueNode<FuncGraph>);
specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph);
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
replace_applicator_ =
MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
specialize_transform_ =
MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph);
// Incorporation
incorporate_getitem_set_ =
MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
incorporate_getitem_from_param_ =
MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel);
incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup);
incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup);
MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared<IncorporateGetitemFromParam>(),
"incorporate_getitem_from_param", IsCNodeGraphKernel);
incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
incorporate_call_switch_ =
MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);
// Virtual Dataset
virtual_dataset_eliminate_ =
MakeSubstitution(VirtualDatasetEliminater(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset);
virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
"virtual_dataset_eliminate", prim::kPrimVirtualDataset);
// Convert
print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint);
print_tuple_wrapper_ =
MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);
// Unused parameter eliminate
unused_parameter_eliminate_ =
MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel);
unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel);
MakeSubstitution(std::make_shared<UnusedParasEliminater>(), "unused_parameter_eliminate", IsCNodeGraphKernel);
unused_output_eliminate_ =
MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel);
// AddN eliminate
addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel);
addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);
// Mark interface fusion
mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect);
mark_interface_fusion_ =
MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
}
ResolveIRPassLib::ResolveIRPassLib() {
resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr);
resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr);
}
InferenceOptPrepareLib::InferenceOptPrepareLib() {
grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode);
grad_var_prepare_ = MakeSubstitution(std::make_shared<GradVarPrepare>(), "grad_var_prepare", IsCNode);
}
} // namespace irpass
} // namespace opt
......
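With this change every pass is an OptimizerCaller held by shared_ptr rather than a bare functor, so adding a pass follows one uniform recipe. A hypothetical skeleton (class name, member, and target primitive are made up):
class DummyAddNEliminater : public OptimizerCaller {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    // Return the replacement node, or nullptr to leave `node` untouched.
    return nullptr;
  }
};
// Registered exactly like the substitutions above:
//   dummy_ = MakeSubstitution(std::make_shared<DummyAddNEliminater>(), "dummy_addn_eliminate", prim::kPrimAddN);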
......@@ -17,15 +17,16 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#include <vector>
#include <memory>
#include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {
......@@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
FuncGraphPtr all_reduce_fg_{nullptr};
};
class ArithmeticSimplify {
class ArithmeticSimplify : public OptimizerCaller {
public:
ArithmeticSimplify()
: multiply_by_zero_or_one_(),
tensor_multiply_by_one_(),
add_by_zero_(),
tensor_add_by_zero_(),
identity_(prim::kPrimIdentity),
opt_update_zero_tensor_(),
constant_duplicate_mul_(),
power_one_() {
: multiply_by_zero_or_one_(std::make_shared<MultiplyByZeroOrOne>()),
tensor_multiply_by_one_(std::make_shared<TensorMultiplyByOne>()),
add_by_zero_(std::make_shared<AddByZero>()),
tensor_add_by_zero_(std::make_shared<TensorAddByZero>()),
identity_(std::make_shared<PrimEliminater>(prim::kPrimIdentity)),
opt_update_zero_tensor_(std::make_shared<OptUpdateZeroTensor>()),
constant_duplicate_mul_(std::make_shared<ConstantDuplicateMul>()),
power_one_(std::make_shared<PowerOneEliminate>()) {
eliminaters_.emplace_back(multiply_by_zero_or_one_);
eliminaters_.emplace_back(tensor_multiply_by_one_);
eliminaters_.emplace_back(add_by_zero_);
......@@ -761,10 +762,10 @@ class ArithmeticSimplify {
}
~ArithmeticSimplify() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
......@@ -773,15 +774,9 @@ class ArithmeticSimplify {
}
private:
MultiplyByZeroOrOne multiply_by_zero_or_one_;
TensorMultiplyByOne tensor_multiply_by_one_;
AddByZero add_by_zero_;
TensorAddByZero tensor_add_by_zero_;
PrimEliminater identity_;
OptUpdateZeroTensor opt_update_zero_tensor_;
ConstantDuplicateMul constant_duplicate_mul_;
PowerOneEliminate power_one_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr multiply_by_zero_or_one_, tensor_multiply_by_one_, add_by_zero_, tensor_add_by_zero_, identity_,
opt_update_zero_tensor_, constant_duplicate_mul_, power_one_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
// Arithmetic Simplifications should be done after step_parallel.
......@@ -789,15 +784,17 @@ class ArithmeticSimplify {
// with shape(weight), but after step_parallel, shape of weight may be changed, so the
// shape of the constant tensor should also be changed. So this pass is separated from
// ArithmeticSimplify and deferred until step_parallel.
class ArithmeticSimplify2 {
class ArithmeticSimplify2 : public OptimizerCaller {
public:
ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); }
ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) {
eliminaters_.emplace_back(tensor_multiply_by_zero_);
}
~ArithmeticSimplify2() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
......@@ -806,8 +803,8 @@ class ArithmeticSimplify2 {
}
private:
TensorMultiplyByZero tensor_multiply_by_zero_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr tensor_multiply_by_zero_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
} // namespace irpass
} // namespace opt
......
......@@ -17,9 +17,9 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#include "ir/visitor.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
namespace mindspore {
namespace opt {
......@@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor {
AnfNodePtr x_{nullptr}, t_{nullptr};
};
class CastEliminater {
class CastEliminater : public OptimizerCaller {
public:
CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {}
~CastEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
auto new_node = cast_same_type_eliminater_(optimizer, node);
if (new_node != nullptr) {
return new_node;
......
......@@ -17,18 +17,19 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#include <vector>
#include <utility>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "utils/symbolic.h"
namespace mindspore {
......@@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor {
bool is_match_{false};
};
class EnvGetItemEliminater {
class EnvGetItemEliminater : public OptimizerCaller {
public:
EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), env_get_set_item_() {
EnvGetItemEliminater()
: new_env_get_item_(std::make_shared<NewEnvGetItem>()),
add_env_get_item_(std::make_shared<AddEnvGetItem>()),
env_get_set_item_(std::make_shared<EnvGetSetItem>()) {
eliminaters_.emplace_back(new_env_get_item_);
eliminaters_.emplace_back(add_env_get_item_);
eliminaters_.emplace_back(env_get_set_item_);
}
~EnvGetItemEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
......@@ -246,10 +250,8 @@ class EnvGetItemEliminater {
}
private:
NewEnvGetItem new_env_get_item_;
AddEnvGetItem add_env_get_item_;
EnvGetSetItem env_get_set_item_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
......
......@@ -17,18 +17,20 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#include <vector>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {
namespace irpass {
......@@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor {
internal::GetitemTransform getitem_transform_;
};
class IncorporateGetitemSet {
class IncorporateGetitemSet : public OptimizerCaller {
public:
IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() {
IncorporateGetitemSet()
: incorporate_getitem_(std::make_shared<IncorporateGetitem>()),
incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()) {
eliminaters_.emplace_back(incorporate_getitem_);
eliminaters_.emplace_back(incorporate_getitem_switch_);
}
~IncorporateGetitemSet() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
......@@ -403,9 +407,8 @@ class IncorporateGetitemSet {
}
private:
IncorporateGetitem incorporate_getitem_;
IncorporateGetitemSwitch incorporate_getitem_switch_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
} // namespace irpass
} // namespace opt
......
......@@ -17,13 +17,15 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#include <vector>
#include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {
......@@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor {
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
};
class ItemTupleEliminater {
class ItemTupleEliminater : public OptimizerCaller {
public:
ItemTupleEliminater()
: get_item_eliminater_(),
get_item_const_eliminater_(),
set_item_eliminater_(),
get_set_item_eliminater_(),
get_item_depend_reorder_() {
: get_item_eliminater_(std::make_shared<GetitemEliminater>()),
get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()),
set_item_eliminater_(std::make_shared<SetitemEliminater>()),
get_set_item_eliminater_(std::make_shared<GetSetitemEliminater>()),
get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()) {
eliminaters_.emplace_back(get_item_eliminater_);
eliminaters_.emplace_back(get_item_const_eliminater_);
eliminaters_.emplace_back(set_item_eliminater_);
......@@ -277,10 +279,10 @@ class ItemTupleEliminater {
}
~ItemTupleEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
......@@ -289,12 +291,9 @@ class ItemTupleEliminater {
}
private:
GetitemEliminater get_item_eliminater_;
GetitemConstEliminater get_item_const_eliminater_;
SetitemEliminater set_item_eliminater_;
GetSetitemEliminater get_set_item_eliminater_;
GetitemDependReorder get_item_depend_reorder_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_,
get_item_depend_reorder_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
} // namespace irpass
} // namespace opt
......
......@@ -19,9 +19,9 @@
#include <memory>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/pattern_matcher.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {
......
......@@ -19,11 +19,12 @@
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "pipeline/static_analysis/dshape.h"
namespace mindspore {
......@@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor {
AnfNodePtr x_{nullptr}, shape_{nullptr};
};
class ReshapeEliminater {
class ReshapeEliminater : public OptimizerCaller {
public:
ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {}
~ReshapeEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
auto new_node = reshape_same_shape_eliminater_(optimizer, node);
if (new_node != nullptr) {
return new_node;
......
......@@ -18,31 +18,31 @@
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_
#include <securec.h>
#include <vector>
#include <memory>
#include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/optimizer_caller.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "ir/pattern_matcher.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "ir/pattern_matcher.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {
namespace irpass {
class SpecialOpEliminater {
class SpecialOpEliminater : public OptimizerCaller {
public:
SpecialOpEliminater()
: insert_gradient_of_(prim::kPrimInsertGradientOf),
stop_gradient_(prim::kPrimStopGradient),
hook_backward_(prim::kPrimHookBackward),
print_shape_type_(prim::kPrimPrintShapeType),
get_ref_value_(prim::kPrimGetRefValue),
mirror_(prim::kPrimMirror),
virtual_div_(prim::kPrimVirtualDiv) {
: insert_gradient_of_(std::make_shared<PrimEliminater>(prim::kPrimInsertGradientOf)),
stop_gradient_(std::make_shared<PrimEliminater>(prim::kPrimStopGradient)),
hook_backward_(std::make_shared<PrimEliminater>(prim::kPrimHookBackward)),
print_shape_type_(std::make_shared<PrimEliminater>(prim::kPrimPrintShapeType)),
get_ref_value_(std::make_shared<PrimEliminater>(prim::kPrimGetRefValue)),
mirror_(std::make_shared<PrimEliminater>(prim::kPrimMirror)),
virtual_div_(std::make_shared<PrimEliminater>(prim::kPrimVirtualDiv)) {
eliminaters_.emplace_back(insert_gradient_of_);
eliminaters_.emplace_back(stop_gradient_);
eliminaters_.emplace_back(hook_backward_);
......@@ -53,10 +53,10 @@ class SpecialOpEliminater {
}
~SpecialOpEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
......@@ -65,9 +65,9 @@ class SpecialOpEliminater {
}
private:
PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_,
OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_,
virtual_div_;
std::vector<TransformFuncType> eliminaters_{};
std::vector<OptimizerCallerPtr> eliminaters_{};
};
// {PrimVirtualDataset, X} -> X
......
......@@ -16,28 +16,27 @@
#include "optimizer/opt.h"
#include <algorithm>
#include <deque>
#include <memory>
#include <unordered_set>
#include <deque>
#include <algorithm>
#include "ir/anf.h"
#include "ir/manager.h"
#include "utils/ordered_set.h"
#include "utils/log_adapter.h"
#include "optimizer/optimizer.h"
#include "utils/log_adapter.h"
#include "utils/ordered_set.h"
namespace mindspore {
/* namespace to support opt */
namespace opt {
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &renorm_action) {
auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
}
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) {
auto fn = [prims](const AnfNodePtr &node) -> bool {
if (!node->isa<CNode>()) {
......@@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
}
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &renorm_action) {
return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
}
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const {
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
#ifdef ENABLE_PROFILE
double t = GetTime();
#endif
AnfNodePtr result = transform_(optimizer, node);
AnfNodePtr result = (*transform_)(optimizer, node);
#ifdef ENABLE_PROFILE
if (optimizer != nullptr) {
auto time = GetTime();
......
......@@ -17,24 +17,18 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#include <vector>
#include <string>
#include <memory>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "operator/ops.h"
namespace mindspore {
/* namespace to support opt */
namespace opt {
class Optimizer;
using OptimizerPtr = std::shared_ptr<Optimizer>;
using OptimizerWeakPtr = std::weak_ptr<Optimizer>;
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>;
// Define the interaction mode between an Optimize pass and Renormalize pass
// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed
......@@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM };
class Substitution {
public:
TransformFuncType transform_{nullptr};
OptimizerCallerPtr transform_;
std::string name_;
PredicateFuncType predicate_{nullptr};
// An enum marking how this Substitution relates to the Renormalize pass.
RenormAction renorm_action_;
Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate,
Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate,
const RenormAction &renorm_action)
: transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
~Substitution() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node);
};
using SubstitutionPtr = std::shared_ptr<Substitution>;
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &action_renorm = CHECK_RENORM);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims,
const RenormAction &action_renorm = CHECK_RENORM);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
class SubstitutionList {
......
......@@ -465,7 +465,7 @@ double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, co
CheckGlobalDeviceManager();
MS_EXCEPTION_IF_NULL(g_device_manager);
RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
TensorRedistribution tensor_redistribution;
TensorRedistribution tensor_redistribution(false, true);
if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) {
MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
}
......@@ -503,7 +503,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inp
CheckGlobalDeviceManager();
MS_EXCEPTION_IF_NULL(g_device_manager);
RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
TensorRedistribution tensor_redistribution;
TensorRedistribution tensor_redistribution(false, true);
if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) {
MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
}
......
......@@ -62,6 +62,7 @@ void ParallelContext::Reset() {
enable_all_reduce_fusion_ = false;
strategy_ckpt_load_file_ = "";
strategy_ckpt_save_file_ = "";
enable_parallel_optimizer_ = false;
}
void ParallelContext::set_device_num(int32_t device_num) {
......
......@@ -100,6 +100,11 @@ class ParallelContext {
void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
enable_parallel_optimizer_ = enable_parallel_optimizer;
}
bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
void Reset();
private:
......@@ -123,6 +128,7 @@ class ParallelContext {
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
std::string strategy_ckpt_load_file_;
std::string strategy_ckpt_save_file_;
bool enable_parallel_optimizer_;
};
void ParallelParameterContextInit(const FuncGraphPtr &func_graph);
......
......@@ -205,6 +205,10 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
.def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
.def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
.def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
"Set enable/disable parallel optimizer.")
.def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
"Get enable/disable parallel optimizer.")
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
......
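On the C++ side these bindings forward to a plain getter/setter pair on the ParallelContext singleton. A sketch, assuming the usual GetInstance accessor on ParallelContext:
auto parallel_context = parallel::ParallelContext::GetInstance();
parallel_context->set_enable_parallel_optimizer(true);
MS_LOG(INFO) << "parallel optimizer enabled: " << parallel_context->enable_parallel_optimizer();
// Note: Reset() now also restores enable_parallel_optimizer_ to false (see the change above).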
......@@ -35,7 +35,7 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) {
}
} // namespace
bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const {
MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node);
......
......@@ -31,7 +31,7 @@ class BackendCSE : public CSE {
public:
BackendCSE() = default;
~BackendCSE() override = default;
bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const override;
bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override;
};
} // namespace opt
} // namespace mindspore
......
......@@ -33,5 +33,6 @@ const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect";
} // namespace mindspore
......@@ -34,7 +34,7 @@ extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
extern const char GRAPH_FLAG_HAS_EFFECT[];
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
extern const char GRAPH_FLAG_RANDOM_EFFECT[];
extern const char GRAPH_FLAG_SIDE_EFFECT[];
} // namespace mindspore
#endif // PYBIND_API_EXPORT_FLAGS_H_
......@@ -33,6 +33,21 @@ static constexpr size_t kCNodeSwitchLayerLength = 3;
namespace mindspore {
namespace session {
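// Find the LabelGoto/LabelSwitch node in the parent graph's execution order that jumps to the
// child graph's start label, so that inserted Assign nodes can be forced to run before the jump.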
static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) {
auto &nodes = parent_graph->execution_order();
for (auto &node : nodes) {
if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) {
return node;
} else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) &&
(child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) {
return node;
}
}
MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
return nullptr;
}
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
const NotNull<std::set<KernelGraphPtr> *> memo) {
if (memo->find(kg.get()) != memo->end()) {
......@@ -200,7 +215,8 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
if (target_graph_iter == graph_id_map.end()) {
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
}
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter));
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg),
NOT_NULL(parameter));
}
}
}
......@@ -263,7 +279,7 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
}
}
kg->SetExecOrderByDefault();
MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString();
return NOT_NULL(start_label);
}
......@@ -433,7 +449,8 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
return {partial_cnode, branch_kg};
}
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph,
NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> from,
NotNull<AnfNodePtr> to) {
std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem});
std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem});
......@@ -443,18 +460,24 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg
<< to_outputs.size() << "]";
}
for (size_t i = 0; i < from_outputs.size(); i++) {
InsertAssignToGraph(kg, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
if (assign_node != nullptr) {
auto jump_node = GetJumpNode(from_graph, to_graph);
if (jump_node != nullptr) {
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
}
}
}
}
void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
NotNull<AnfNodePtr> to) {
AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
NotNull<AnfNodePtr> to) {
if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
return;
return nullptr;
}
if (from.get() == to.get()) {
return;
return nullptr;
}
MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to "
<< to->DebugString();
......@@ -466,6 +489,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
assign_node->set_abstract(to->abstract());
// append the assign at the end of from graph
InsertDependToGraph(kg, NOT_NULL(assign_node));
return assign_node;
}
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
......
......@@ -52,8 +52,9 @@ class AscendControlParser {
const CNodePtr &last_label);
static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node);
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph,
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
// root graph order
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
......
......@@ -521,6 +521,47 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
return output_nodes;
}
// Find control_depend real input nodes.
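// E.g. for ControlDepend(prior = Depend(A, B), depend = MakeTuple(C, D)), the prior side
// resolves to the real kernels reached through A and B, and the depend side to C and D.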
void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, std::set<AnfNodePtr> *visited) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(result);
MS_EXCEPTION_IF_NULL(visited);
if (visited->find(anf_node) != visited->end()) {
MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
return;
}
visited->insert(anf_node);
if (AnfAlgo::IsRealKernel(anf_node)) {
result->emplace_back(anf_node);
return;
}
if (!anf_node->isa<CNode>()) {
return;
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString();
}
auto input0 = cnode->input(0);
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
GetAllFatherRealNode(cnode->input(i), result, visited);
}
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
if (cnode->inputs().size() != kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
}
GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
} else if (IsPrimitive(input0, prim::kPrimDepend)) {
if (cnode->inputs().size() != kDependInputSize) {
MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
}
GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
}
}
// update the depend relations of control depend
void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
for (const auto &node : depends) {
......@@ -551,11 +592,24 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
if (depend_node->isa<Parameter>() && depend_mode == 1) {
depend_nodes = GetOutputNodes(depend_node);
}
for (auto &first_node : prior_nodes) {
std::vector<AnfNodePtr> real_prior_nodes;
std::set<AnfNodePtr> prior_visited;
for (const auto &tmp : prior_nodes) {
GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
}
std::vector<AnfNodePtr> real_depend_nodes;
std::set<AnfNodePtr> depend_visited;
for (const auto &tmp : depend_nodes) {
GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
}
for (auto &first_node : real_prior_nodes) {
if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) {
continue;
}
for (auto &second_node : depend_nodes) {
for (auto &second_node : real_depend_nodes) {
if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) {
continue;
}
......
......@@ -33,9 +33,14 @@
namespace py = pybind11;
namespace mindspore::inference {
std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device) {
inference::Session::RegAllOp();
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
return anf_graph;
try {
inference::Session::RegAllOp();
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
return anf_graph;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference LoadModel failed";
return nullptr;
}
}
void ExitInference() {
......@@ -51,12 +56,17 @@ void ExitInference() {
}
std::shared_ptr<MSSession> MSSession::CreateSession(const std::string &device, uint32_t device_id) {
auto session = std::make_shared<inference::Session>();
auto ret = session->Init(device, device_id);
if (ret != 0) {
try {
auto session = std::make_shared<inference::Session>();
auto ret = session->Init(device, device_id);
if (ret != 0) {
return nullptr;
}
return session;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference CreatSession failed";
return nullptr;
}
return session;
}
void Session::RegAllOp() {
......@@ -113,47 +123,71 @@ void Session::RegAllOp() {
uint32_t Session::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) {
MS_ASSERT(session_impl_ != nullptr);
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
py::gil_scoped_release gil_release;
return graph_id;
try {
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
py::gil_scoped_release gil_release;
return graph_id;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference CompileGraph failed";
return static_cast<uint32_t>(-1);
}
}
MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) {
std::vector<tensor::TensorPtr> inTensors;
inTensors.resize(inputs.size());
bool has_error = false;
std::transform(inputs.begin(), inputs.end(), inTensors.begin(),
[&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr {
if (tensor_ptr == nullptr) {
MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr";
has_error = true;
return nullptr;
}
auto tensor = static_cast<inference::Tensor *>(tensor_ptr.get());
if (tensor == nullptr) {
MS_LOG(ERROR) << "Can not cast input MSTensor to tensor";
has_error = true;
return nullptr;
}
return tensor->tensor();
});
if (has_error) {
MS_LOG(ERROR) << "Init Tensor failed, returning empty result";
std::vector<std::shared_ptr<inference::MSTensor>> multiTensor;
return multiTensor;
}
VectorRef outputs;
session_impl_->RunGraph(graph_id, inTensors, &outputs);
try {
std::vector<tensor::TensorPtr> inTensors;
inTensors.resize(inputs.size());
bool has_error = false;
std::transform(inputs.begin(), inputs.end(), inTensors.begin(),
[&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr {
if (tensor_ptr == nullptr) {
MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr";
has_error = true;
return nullptr;
}
auto tensor = static_cast<inference::Tensor *>(tensor_ptr.get());
if (tensor == nullptr) {
MS_LOG(ERROR) << "Can not cast input MSTensor to tensor";
has_error = true;
return nullptr;
}
return tensor->tensor();
});
if (has_error) {
MS_LOG(ERROR) << "Init Tensor failed, returning empty result";
std::vector<std::shared_ptr<inference::MSTensor>> multiTensor;
return multiTensor;
}
VectorRef outputs;
session_impl_->RunGraph(graph_id, inTensors, &outputs);
return TransformVectorRefToMultiTensor(outputs);
return TransformVectorRefToMultiTensor(outputs);
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference Rungraph failed";
return MultiTensor();
}
}
namespace {
string AjustTargetName(const std::string &device) {
if (device == kAscendDevice) {
return std::string(kAscendDevice) + "Inference";
} else {
MS_LOG(ERROR) << "Only support device Ascend right now";
return "";
}
}
} // namespace
int Session::Init(const std::string &device, uint32_t device_id) {
RegAllOp();
auto ms_context = MsContext::GetInstance();
ms_context->set_execution_mode(kGraphMode);
ms_context->set_device_target(kAscendDevice);
session_impl_ = session::SessionFactory::Get().Create(device);
ms_context->set_device_id(device_id);
auto ajust_device = AjustTargetName(device);
if (ajust_device == "") {
return -1;
}
ms_context->set_device_target(device);
session_impl_ = session::SessionFactory::Get().Create(ajust_device);
if (session_impl_ == nullptr) {
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
return -1;
......
......@@ -81,7 +81,15 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
}
}
// If the process reaches here, it means item_with_index is a real node (Parameter, or executable CNode).
auto address = AnfAlgo::GetOutputAddr(node, output_index);
DeviceAddressPtr address;
auto is_all_nop_node = opt::IsAllNopNode(&graph);
if (is_all_nop_node) {
// The graph does not remove the nop node.
address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
} else {
// The graph removes the nop node.
address = AnfAlgo::GetMutableOutputAddr(node, output_index, true);
}
MS_EXCEPTION_IF_NULL(address);
auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
TypeId type_id = kNumberTypeFloat32;
......@@ -93,7 +101,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
tensor->set_device_address(address);
tensor->set_dirty(false);
} else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) {
......
......@@ -1646,7 +1646,7 @@ bool DfGraphConvertor::GetControlDependList(const CNodePtr &node,
dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end());
}
if (src_ops_list->empty() || dst_ops_list->empty()) {
MS_LOG(WARNING) << "Control depend node's src or dest node is not a apply node, ignore it";
MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it";
error_ = SUCCESS;
}
return true;
......@@ -1690,6 +1690,8 @@ void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) {
});
} else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) {
control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]});
} else if (src_ops_list->empty() || dst_ops_list->empty()) {
MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it";
} else {
MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size()
<< " -> dst:" << dst_ops_list->size();
......
......@@ -463,7 +463,7 @@ void InitSubModulesLogLevel() {
// set submodule's log level
auto submodule = GetEnv("MS_SUBMODULE_LOG_v");
MS_LOG(INFO) << "MS_SUBMODULE_LOG_v=`" << submodule << "`";
MS_LOG(DEBUG) << "MS_SUBMODULE_LOG_v=`" << submodule << "`";
LogConfigParser parser(submodule);
auto configs = parser.Parse();
for (const auto &cfg : configs) {
......@@ -489,22 +489,14 @@ void InitSubModulesLogLevel() {
} // namespace mindspore
extern "C" {
// shared lib init hook
#if defined(_WIN32) || defined(_WIN64)
__attribute__((constructor)) void mindspore_log_init(void) {
__attribute__((constructor)) void common_log_init(void) {
#else
void mindspore_log_init(void) {
void common_log_init(void) {
#endif
#ifdef USE_GLOG
// do not use glog predefined log prefix
FLAGS_log_prefix = false;
static bool is_glog_initialzed = false;
if (!is_glog_initialzed) {
#if !defined(_WIN32) && !defined(_WIN64)
google::InitGoogleLogging("mindspore");
#endif
is_glog_initialzed = true;
}
// set default log level to WARNING
if (mindspore::GetEnv("GLOG_v").empty()) {
FLAGS_v = mindspore::WARNING;
......@@ -525,4 +517,22 @@ void mindspore_log_init(void) {
#endif
mindspore::InitSubModulesLogLevel();
}
// shared lib init hook
#if defined(_WIN32) || defined(_WIN64)
__attribute__((constructor)) void mindspore_log_init(void) {
#else
void mindspore_log_init(void) {
#endif
#ifdef USE_GLOG
static bool is_glog_initialzed = false;
if (!is_glog_initialzed) {
#if !defined(_WIN32) && !defined(_WIN64)
google::InitGoogleLogging("mindspore");
#endif
is_glog_initialzed = true;
}
#endif
common_log_init();
}
}
......@@ -252,6 +252,7 @@ constexpr auto kControlDependMode = "depend_mode";
// index define of depend
constexpr auto kRealInputIndexInDepend = 1;
constexpr auto kDependAttachNodeIndex = 2;
constexpr auto kDependInputSize = 3;
// format
constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0";
......
......@@ -22,6 +22,10 @@ from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry
__all__ = ['Tensor', 'MetaTensor']
np_types = (np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
np.float32, np.float64, np.bool_)
class Tensor(Tensor_):
......@@ -54,6 +58,10 @@ class Tensor(Tensor_):
"""
def __init__(self, input_data, dtype=None):
# If input data is numpy number, convert it to np array
if isinstance(input_data, np_types):
input_data = np.array(input_data)
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
check_type('tensor input_data', input_data, (Tensor_, float, int))
if dtype is not None:
......
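A quick illustration of what this hunk enables: numpy scalar types are converted to arrays before check_type runs, so the scalar form below no longer raises:

import numpy as np
from mindspore import Tensor

t1 = Tensor(np.float32(0.1))   # numpy scalar, now wrapped by np.array first
t2 = Tensor(np.array([1, 2]))  # ndarray path, unchanged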
......@@ -1040,7 +1040,7 @@ class Dataset:
Args:
columns (list[str], optional): List of columns to be used to specify the order of columns
(defaults=None, means all columns).
(default=None, means all columns).
Returns:
Iterator, list of ndarray.
......@@ -3382,7 +3382,7 @@ class ManifestDataset(MappableDataset):
class_indexing (dict, optional): A str-to-int mapping from label name to index
(default=None, the folder names will be sorted alphabetically and each
class will be given a unique index starting from 0).
decode (bool, optional): decode the images after reading (defaults=False).
decode (bool, optional): decode the images after reading (default=False).
num_shards (int, optional): Number of shards that the dataset should be divided
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
......@@ -4760,7 +4760,7 @@ class _NumpySlicesDataset:
def process_dict(self, input_data):
"""
Convert the dict like data into tuple format, when input is a tuple of dict then compose it into a dict first.
Convert dict-like data into tuple format; when the input is a tuple of dicts, compose it into one dict first.
"""
# Convert pandas like dict(has "values" column) into General dict
data_keys = list(input_data.keys())
......
......@@ -202,7 +202,7 @@ class RandomHorizontalFlip(cde.RandomHorizontalFlipOp):
Flip the input image horizontally, randomly with a given probability.
Args:
prob (float): Probability of the image being flipped (default=0.5).
prob (float, optional): Probability of the image being flipped (default=0.5).
"""
@check_prob
......@@ -217,7 +217,7 @@ class RandomHorizontalFlipWithBBox(cde.RandomHorizontalFlipWithBBoxOp):
Maintains data integrity by also flipping bounding boxes in an object detection pipeline.
Args:
prob (float): Probability of the image being flipped (default=0.5).
prob (float, optional): Probability of the image being flipped (default=0.5).
"""
@check_prob
......@@ -231,7 +231,7 @@ class RandomVerticalFlip(cde.RandomVerticalFlipOp):
Flip the input image vertically, randomly with a given probability.
Args:
prob (float): Probability of the image being flipped (default=0.5).
prob (float, optional): Probability of the image being flipped (default=0.5).
"""
@check_prob
......
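For reference, a minimal usage sketch of the ops documented above; the import path is assumed from this codebase's layout and may differ in other versions:

import mindspore.dataset.transforms.vision.c_transforms as C

flip = C.RandomHorizontalFlip(prob=0.9)       # prob is optional, default 0.5
flip_bbox = C.RandomHorizontalFlipWithBBox()  # flips bounding boxes together with the image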
......@@ -29,8 +29,9 @@ from .optimizer import Optimizer
_adam_opt = C.MultitypeFuncGraph("adam_opt")
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag, optim_filter):
"""
Update parameters.
......@@ -44,38 +45,44 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
m (Tensor): m value of parameters.
v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters.
decay_flag (bool): Applies weight decay or not.
optim_filter (bool): Applies parameter update or not.
Returns:
Tensor, the new value of v after updating.
"""
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
if optim_filter:
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta1, gradient_fp32)
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta2, op_square(gradient_fp32))
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta2, op_square(gradient_fp32))
update = next_m / (eps + op_sqrt(next_v))
if decay_flag:
update = op_mul(weight_decay_tensor, param_fp32) + update
update_with_lr = op_mul(lr, update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
update = next_m / (eps + op_sqrt(next_v))
if decay_flag:
update = op_mul(weight_decay_tensor, param_fp32) + update
next_v = F.depend(next_v, F.assign(param, op_cast(next_param, F.dtype(param))))
next_v = F.depend(next_v, F.assign(m, op_cast(next_m, F.dtype(m))))
next_v = F.depend(next_v, F.assign(v, op_cast(next_v, F.dtype(v))))
return next_v
update_with_lr = op_mul(lr, update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
return next_param
return gradient
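In the docstring's notation, the guarded branch computes the standard AdamWeightDecay step in float32 (a sketch; the \lambda w_{t-1} term is added only when decay_flag holds, and when optim_filter is False the parameter belongs to another device and the gradient is passed through unchanged):

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
u_t = m_t / (\epsilon + \sqrt{v_t}) + \lambda w_{t-1}
w_t = w_{t-1} - \eta u_t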
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
......@@ -300,7 +307,7 @@ class AdamWeightDecay(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
tuple[bool], all elements are True.
Examples:
>>> net = Net()
......@@ -328,11 +335,13 @@ class AdamWeightDecay(Optimizer):
def construct(self, gradients):
lr = self.get_lr()
updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
return updated_velocity
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients,
self.decay_flag, self.optim_filter)
if self.use_parallel:
optim_result = self.broadcast_params(optim_result)
return optim_result
class AdamWeightDecayDynamicLR(Optimizer):
......@@ -363,7 +372,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
tuple[bool], all elements are True.
Examples:
>>> net = Net()
......@@ -424,12 +433,14 @@ class AdamWeightDecayDynamicLR(Optimizer):
warmup_lr = self.start_learning_rate * warmup_percent
is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients,
self.decay_flag, self.optim_filter)
if self.use_parallel:
optim_result = self.broadcast_params(optim_result)
added_global_step = self.global_step + self.one
F.control_depend(lr, added_global_step)
self.global_step = added_global_step
return updated_velocity
return optim_result
......@@ -32,11 +32,10 @@ num_one = Tensor(np.ones([1]), mstype.float32)
_lamb_opt = C.MultitypeFuncGraph("lamb_opt")
@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Bool")
@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v,
gradient, decay_flag):
gradient, decay_flag, optim_filter):
"""
Update parameters.
......@@ -52,66 +51,66 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters.
decay_flag (bool): Specifies whether param update with weight decay.
optim_filter (bool): Applies parameter update or not.
Returns:
Tensor, the new value of v after updating.
"""
op_mul = P.Mul()
op_sqrt = P.Sqrt()
op_rsqrt = P.Rsqrt()
op_square = P.Square()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
op_pow = P.Pow()
op_norm = layer.Norm()
op_select = P.Select()
op_greater = P.Greater()
op_fill = P.Fill()
op_dtype = P.DType()
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one,
mstype.float32) - beta1, gradient_fp32)
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one,
mstype.float32) - beta2, op_square(gradient_fp32))
next_mm = next_m / (op_cast(num_one, mstype.float32)
- op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
next_vv = next_v / (op_cast(num_one, mstype.float32) -
op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
w_norm = op_norm(param_fp32)
g_norm = op_norm(gradient_fp32)
g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(
next_vv + eps)) + weight_decay_tensor * param_fp32)
zeros = F.zeros_like(w_norm)
ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
trust_ratio = op_select(
op_greater(w_norm, zeros),
op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
ones)
tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
update = next_mm / (op_sqrt(next_vv) + eps)
if decay_flag:
update = update + op_mul(weight_decay_tensor, param_fp32)
update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
next_v = F.depend(next_v, F.assign(param, next_param))
next_v = F.depend(next_v, F.assign(m, next_m))
next_v = F.depend(next_v, F.assign(v, next_v))
return next_v
if optim_filter:
op_mul = P.Mul()
op_sqrt = P.Sqrt()
op_rsqrt = P.Rsqrt()
op_square = P.Square()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
op_pow = P.Pow()
op_norm = layer.Norm()
op_select = P.Select()
op_greater = P.Greater()
op_fill = P.Fill()
op_dtype = P.DType()
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))
next_mm = next_m / (op_cast(num_one, mstype.float32)
- op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
next_vv = next_v / (op_cast(num_one, mstype.float32) -
op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
w_norm = op_norm(param_fp32)
g_norm = op_norm(gradient_fp32)
g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32)
zeros = F.zeros_like(w_norm)
ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
trust_ratio = op_select(
op_greater(w_norm, zeros),
op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
ones)
tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
update = next_mm / (op_sqrt(next_vv) + eps)
if decay_flag:
update = update + op_mul(weight_decay_tensor, param_fp32)
update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
next_param = F.depend(next_param, F.assign(param, next_param))
next_param = F.depend(next_param, F.assign(m, next_m))
next_param = F.depend(next_param, F.assign(v, next_v))
return next_param
return gradient
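The guarded branch implements the LAMB trust-ratio step on top of the same moments (a sketch; note that \epsilon enters the trust-ratio norm inside the square root, via Rsqrt(\hat{v}_t + \epsilon), but the update outside it, exactly as in the code):

\hat{m}_t = m_t / (1 - \beta_1^{t+1}),  \hat{v}_t = v_t / (1 - \beta_2^{t+1})   with t = global_step
u_t = \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) + \lambda w_{t-1}   (decay term only when decay_flag)
r_t = clip(\lVert w_{t-1} \rVert / \lVert \hat{m}_t / \sqrt{\hat{v}_t + \epsilon} + \lambda w_{t-1} \rVert, 0, 10),   with r_t = 1 when \lVert w \rVert or \lVert g \rVert is zero
w_t = w_{t-1} - \eta r_t u_t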
lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel")
......@@ -238,7 +237,7 @@ class Lamb(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
tuple[bool], all elements are True.
Examples:
>>> net = Net()
......@@ -311,18 +310,21 @@ class Lamb(Optimizer):
self.warmup_steps, self.global_step), mstype.float32)
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
if self.enable_graph_kernel:
updated_velocity = self.hyper_map(F.partial(lamb_opt_graph_kernel,
self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel,
self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
else:
updated_velocity = self.hyper_map(F.partial(_lamb_opt,
self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
optim_result = self.hyper_map(F.partial(_lamb_opt,
self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients,
self.decay_flag, self.optim_filter)
if self.use_parallel:
optim_result = self.broadcast_params(optim_result)
added_global_step = self.global_step + self.one
F.control_depend(lr, added_global_step)
self.global_step = added_global_step
return updated_velocity
return optim_result
......@@ -22,11 +22,14 @@ from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common.tensor import Tensor
from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.parallel_utils import ParallelMode
__all__ = ['Optimizer']
......@@ -155,6 +158,27 @@ class Optimizer(Cell):
self.param_length = len(self.parameters)
self.map_ = C.Map()
use_parallel = auto_parallel_context().get_enable_parallel_optimizer()
self.use_parallel = use_parallel
if use_parallel:
if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]:
raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL,
ParallelMode.AUTO_PARALLEL]:
raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format
(_get_parallel_mode()))
self.dev_num = _get_device_num()
if self.dev_num > self.param_length:
raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
" less than the number of devices {}".format(self.param_length, self.dev_num))
self.param_rank = self._get_parameter_group_id()
self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
self.param_names = []
for param in self.parameters:
self.param_names.append(param.name)
else:
self.optim_filter = (True,) * self.param_length
def decay_weight(self, gradients):
"""
Weight decay.
......@@ -219,8 +243,32 @@ class Optimizer(Cell):
raise TypeError("Learning rate should be float, Tensor or Iterable.")
return lr
def _check_group_params(self, parameters):
"""Check group params."""
parse_keys = ['params', 'lr', 'weight_decay', 'order_params']
for group_param in parameters:
invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
if invalid_key:
raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.')
if 'order_params' in group_param.keys():
if len(group_param.keys()) > 1:
raise ValueError("The order params dict in group parameters should "
"only include the 'order_params' key.")
if not isinstance(group_param['order_params'], Iterable):
raise TypeError("The value of 'order_params' should be an Iterable type.")
continue
if not group_param['params']:
raise ValueError("Optimizer got an empty group parameter list.")
for param in group_param['params']:
if not isinstance(param, Parameter):
raise TypeError("The group param should be an iterator of Parameter type.")
def _parse_group_params(self, parameters, learning_rate):
"""Parse group params."""
self._check_group_params(parameters)
if self.dynamic_lr:
dynamic_lr_length = learning_rate.size()
else:
......@@ -250,9 +298,6 @@ class Optimizer(Cell):
if dynamic_lr_length not in (lr_length, 0):
raise ValueError("The dynamic learning rate in group should be the same size.")
if not group_param['params']:
raise ValueError("Optimizer got an empty group parameter list.")
dynamic_lr_length = lr_length
self.dynamic_lr_length = dynamic_lr_length
......@@ -384,6 +429,51 @@ class Optimizer(Cell):
lr = self.learning_rate
return lr
def _get_parameter_group_id(self):
"""
Get the parameter partition group id, which is less than the number of devices.
Returns:
tuple, the group id tuple of parameters.
"""
rank_list = ()
count = 0
for _ in range(self.param_length):
rank_list = rank_list + (count,)
count = count + 1
if count == self.dev_num:
count = 0
return rank_list
def broadcast_params(self, optim_result):
"""
Apply Broadcast operations in the sequential order of parameter groups.
Returns:
bool, the status flag.
"""
param_group = []
key_group = []
for _ in range(self.dev_num):
param_group.append(F.make_tuple())
key_group.append(F.make_tuple())
for i in range(self.param_length):
param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],)
key = P.MakeRefKey(self.param_names[i])()
key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
new_param_group = []
for root in range(self.dev_num):
ops = P.Broadcast(root)
next_params = ops(param_group[root])
new_param_group.append(next_params)
for i in range(F.tuple_len(next_params)):
F.assign(key_group[root][i], next_params[i])
status = True
for i in range(self.dev_num - 1):
status = F.control_depend(new_param_group[i][0], new_param_group[i+1])
return status
def construct(self, *hyper_params):
raise NotImplementedError
......
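A standalone sketch of the round-robin partition implemented by _get_parameter_group_id and consumed by optim_filter (illustrative names: parameter i is owned by device i mod dev_num, and a rank only runs the optimizer on the parameters it owns; broadcast_params then re-synchronizes each group from its owning rank):

def param_rank(param_length, dev_num):
    # Device assignment cycles 0, 1, ..., dev_num - 1, 0, 1, ...
    return tuple(i % dev_num for i in range(param_length))

ranks = param_rank(param_length=5, dev_num=2)         # (0, 1, 0, 1, 0)
local_rank = 0                                        # stands in for _get_global_rank()
optim_filter = tuple(r == local_rank for r in ranks)  # (True, False, True, False, True)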
......@@ -220,7 +220,9 @@ class DataWrapper(Cell):
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
# Also copy the flag in `network` construct
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
self.add_flags(**flags)
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
self.network = network
......
......@@ -47,6 +47,7 @@ from .gather_v2 import _gather_v2_akg
from .less import _less_akg
from .log import _log_akg
from .matmul import _matmul_akg
from .batchmatmul import _batchmatmul_akg
from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_akg
from .max_pool_with_argmax import _max_pool_with_argmax_akg
from .max import _max_akg
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BatchMatMul op"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "BatchMatMul",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"attr": [
{
"name": "transpose_a",
"param_type": "optional",
"type": "bool"
},
{
"name": "transpose_b",
"param_type": "optional",
"type": "bool"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "x1"
},
{
"index": 1,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "x2"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "output"
}
]
}""")
def _batchmatmul_akg():
"""BatchMatMul AKG register"""
return
......@@ -28,26 +28,8 @@ confusion_transpose_d_op_info = TBERegOp("ConfusionTransposeD") \
.attr("transpose_first", "required", "bool", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_FracNZ, DataType.I8_FracNZ) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_FracNZ, DataType.U8_FracNZ) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I16_FracNZ, DataType.I16_FracNZ) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_FracNZ, DataType.U16_FracNZ) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_FracNZ, DataType.U32_FracNZ) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.I64_FracNZ, DataType.I64_FracNZ) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_FracNZ, DataType.U64_FracNZ) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()
......
......@@ -85,6 +85,22 @@ def _list_setitem_with_List(data, number_index, value):
return F.list_setitem(data, number_index, value)
@setitem.register("List", "Number", "Tuple")
def _list_setitem_with_Tuple(data, number_index, value):
"""
Assigns value to list.
Inputs:
data (list): Data of type list.
number_index (Number): Index of data.
value (tuple): Value given.
Outputs:
list, type is the same as the element type of data.
"""
return F.list_setitem(data, number_index, value)
@setitem.register("Dictionary", "String", "Tensor")
def _dict_setitem_with_tensor(data, key, value):
"""
......
......@@ -98,6 +98,7 @@ class AllReduce(PrimitiveWithInfer):
self.op = op
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
self.add_prim_attr('index', 0)
def vm_impl(self, x):
"""Implement by vm mode."""
......
......@@ -309,12 +309,6 @@ class Print(PrimitiveWithInfer):
Output tensor or string to stdout.
Note:
The print operation cannot support the following cases currently.
1. The type of tensor is float64 or bool.
2. The data of tensor is a scalar type.
In pynative mode, please use python print function.
Inputs:
......@@ -334,7 +328,7 @@ class Print(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
pass
self.add_prim_attr("_side_effect", True)
def __call__(self, *args):
for arg in args:
......
......@@ -888,7 +888,8 @@ class Neg(PrimitiveWithInfer):
def infer_value(self, input_x):
if input_x is not None:
input_x = input_x.asnumpy()
return Tensor(-input_x)
out = np.array(-input_x, input_x.dtype)
return Tensor(out)
return None
......@@ -1667,7 +1668,8 @@ class Div(_MathBinaryOp):
if x is not None and y is not None:
x = x.asnumpy()
y = y.asnumpy()
return Tensor(x / y)
out = np.array(x / y, x.dtype)
return Tensor(out)
return None
......
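A small numpy illustration of why these hunks pin the dtype during constant folding: plain Tensor(x / y) would inherit numpy's promotion (true division of int32 arrays yields float64), while np.array(out, x.dtype) keeps the inferred constant aligned with the op's input dtype:

import numpy as np

x = np.array([2, 7], np.int32)
y = np.array([2, 2], np.int32)
print((x / y).dtype)                    # float64: true division promotes
print(np.array(x / y, x.dtype).dtype)   # int32, matching the input dtype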
......@@ -59,8 +59,7 @@ class Assign(PrimitiveWithInfer):
return variable
def infer_dtype(self, variable, value):
args = {"variable": variable, "value": value}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
# Add a type validation later when we don't have to assign a value to RefKey.
return variable
......
......@@ -400,6 +400,23 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_global_rank_is_set()
def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
"""
Set enable/disable parallel optimizer.
Args:
enable_parallel_optimizer (bool): Enable/disable parallel optimizer.
"""
self.check_context_handle()
if not isinstance(enable_parallel_optimizer, bool):
raise TypeError('enable_parallel_optimizer must be a bool value')
self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
def get_enable_parallel_optimizer(self):
"""Get parallel optimizer flag."""
self.check_context_handle()
return self._context_handle.get_enable_parallel_optimizer()
def reset(self):
"""Reset all settings."""
self.check_context_handle()
......@@ -433,7 +450,8 @@ _set_auto_parallel_context_func_map = {
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
"full_batch": auto_parallel_context().set_full_batch}
"full_batch": auto_parallel_context().set_full_batch,
"enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer}
_get_auto_parallel_context_func_map = {
......@@ -447,13 +465,15 @@ _get_auto_parallel_context_func_map = {
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
"full_batch": auto_parallel_context().get_full_batch}
"full_batch": auto_parallel_context().get_full_batch,
"enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool)
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
def _set_auto_parallel_context(**kwargs):
"""
Set auto parallel context.
......@@ -493,6 +513,7 @@ def _set_auto_parallel_context(**kwargs):
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer (bool): Enable optimizer segmentation or not. Default: False.
Raises:
ValueError: If input key is not attribute in auto parallel context.
......@@ -535,5 +556,6 @@ def _reset_auto_parallel_context():
- parameter_broadcast: False.
- strategy_ckpt_load_file: ""
- strategy_ckpt_save_file: ""
- enable_parallel_optimizer: False
"""
auto_parallel_context().reset()
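A short sketch of toggling the new flag through the singleton shown above (the same setter is also reachable through the kwargs-based _set_auto_parallel_context via the func map registered in this hunk):

from mindspore.parallel._auto_parallel_context import auto_parallel_context

auto_parallel_context().set_enable_parallel_optimizer(True)
assert auto_parallel_context().get_enable_parallel_optimizer()
auto_parallel_context().reset()  # restores the default, False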
......@@ -166,8 +166,11 @@ class SummaryCollector(Callback):
self._has_saved_custom_data = False
self._is_parse_loss_success = True
self._first_step = True
self._dataset_sink_mode = True
def __enter__(self):
self._first_step = True
self._dataset_sink_mode = True
self._record = SummaryRecord(log_dir=self._summary_dir)
return self
......@@ -279,15 +282,15 @@ class SummaryCollector(Callback):
def step_end(self, run_context):
cb_params = run_context.original_args()
if self._first_step:
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num)
if cb_params.mode == ModeEnum.TRAIN.value:
# Make sure the first step data is recorded
if not self._first_step and cb_params.cur_step_num % self._collect_freq:
if not self._is_collect_this_step(cb_params):
return
self._first_step = False
if not self._has_saved_train_network:
self._collect_graphs(cb_params)
......@@ -295,6 +298,7 @@ class SummaryCollector(Callback):
self._collect_metric(cb_params)
self._collect_histogram(cb_params)
self._first_step = False
self._record.record(cb_params.cur_step_num)
def end(self, run_context):
......@@ -320,6 +324,18 @@ class SummaryCollector(Callback):
raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list,"
f"but expected only one {self.__class__.__name__} instance.")
def _is_collect_this_step(self, cb_params):
"""Decide whether to collect data for the current step."""
# Make sure the first step data is recorded
if not self._first_step:
if self._dataset_sink_mode:
if cb_params.cur_epoch_num % self._collect_freq:
return False
else:
if cb_params.cur_step_num % self._collect_freq:
return False
return True
@staticmethod
def _package_custom_lineage_data(custom_lineage_data):
"""
......
......@@ -318,10 +318,6 @@ def preprocess_fn(image, box, is_training):
else:
input_data = resize_column(*input_data)
photo = (np.random.rand() < config.photo_ratio)
if photo:
input_data = photo_crop_column(*input_data)
input_data = image_bgr_rgb(*input_data)
output_data = input_data
......@@ -432,19 +428,19 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fast
writer.write_raw_data([row])
writer.commit()
def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0,
is_training=True, num_parallel_workers=8):
is_training=True, num_parallel_workers=4):
"""Creatr FasterRcnn dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,
num_parallel_workers=num_parallel_workers, shuffle=is_training)
num_parallel_workers=1, shuffle=is_training)
decode = C.Decode()
ds = ds.map(input_columns=["image"], operations=decode)
ds = ds.map(input_columns=["image"], operations=decode, num_parallel_workers=1)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
hwc_to_chw = C.HWC2CHW()
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
horizontally_op = C.RandomHorizontalFlip(1)
type_cast0 = CC.TypeCast(mstype.float32)
type_cast1 = CC.TypeCast(mstype.float16)
type_cast2 = CC.TypeCast(mstype.int32)
type_cast3 = CC.TypeCast(mstype.bool_)
......@@ -453,17 +449,18 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
columns_order=["image", "image_shape", "box", "label", "valid_num"],
operations=compose_map_func, num_parallel_workers=4)
ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0],
num_parallel_workers=num_parallel_workers)
operations=compose_map_func, num_parallel_workers=num_parallel_workers)
flip = (np.random.rand() < config.flip_ratio)
if flip:
ds = ds.map(input_columns=["image"], operations=[horizontally_op],
num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=[normalize_op, horizontally_op, hwc_to_chw, type_cast1],
num_parallel_workers=24)
ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"],
operations=flipped_generation, num_parallel_workers=4)
operations=flipped_generation, num_parallel_workers=num_parallel_workers)
else:
ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1],
num_parallel_workers=24)
else:
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
......@@ -471,11 +468,10 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
operations=compose_map_func,
num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0],
num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1],
num_parallel_workers=24)
# transpose_column from python to c
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1])
ds = ds.map(input_columns=["image_shape"], operations=[type_cast1])
ds = ds.map(input_columns=["box"], operations=[type_cast1])
ds = ds.map(input_columns=["label"], operations=[type_cast2])
......
......@@ -19,7 +19,9 @@ from easydict import EasyDict as edict
cifar_cfg = edict({
'num_classes': 10,
'lr_init': 0.05,
'lr_init': 0.01,
'lr_max': 0.1,
'warmup_epochs': 5,
'batch_size': 64,
'epoch_size': 70,
'momentum': 0.9,
......
......@@ -38,20 +38,25 @@ random.seed(1)
np.random.seed(1)
def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""Set learning rate."""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
warmup_steps = steps_per_epoch * warmup_epochs
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < decay_epoch_index[0]:
lr_each_step.append(lr_max)
elif i < decay_epoch_index[1]:
lr_each_step.append(lr_max * 0.1)
elif i < decay_epoch_index[2]:
lr_each_step.append(lr_max * 0.01)
if i < warmup_steps:
lr_value = float(lr_init) + inc_each_step * float(i)
else:
lr_each_step.append(lr_max * 0.001)
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr_value = float(lr_max) * base * base
if lr_value < 0.0:
lr_value = 0.0
lr_each_step.append(lr_value)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
......@@ -86,7 +91,8 @@ if __name__ == '__main__':
if args_opt.pre_trained:
load_param_into_net(net, load_checkpoint(args_opt.pre_trained))
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
lr = lr_steps(0, lr_init=cfg.lr_init, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs,
total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum,
weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
......
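The new schedule replaces the three-milestone step decay with linear warmup followed by quadratic decay. With W = warmup_epochs * steps_per_epoch warmup steps and T total steps, step i gets

\eta_i = \eta_{init} + i (\eta_{max} - \eta_{init}) / W,   for i < W
\eta_i = \eta_{max} (1 - (i - W) / (T - W))^2, floored at 0,   for i >= W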
......@@ -22,6 +22,7 @@
#include <vector>
#include <utility>
#include <memory>
#include <future>
#include "mindspore/ccsrc/utils/log_adapter.h"
#include "serving/ms_service.grpc.pb.h"
......@@ -40,7 +41,7 @@ namespace serving {
using MSTensorPtr = std::shared_ptr<inference::MSTensor>;
Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) {
session_ = inference::MSSession::CreateSession(device + "Inference", device_id);
session_ = inference::MSSession::CreateSession(device, device_id);
if (session_ == nullptr) {
MS_LOG(ERROR) << "Creat Session Failed";
return FAILED;
......@@ -67,6 +68,7 @@ Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::Multi
MS_LOG(INFO) << "run Predict";
*outputs = session_->RunGraph(graph_id_, inputs);
MS_LOG(INFO) << "run Predict finished";
return SUCCESS;
}
......@@ -80,12 +82,16 @@ Status Session::Warmup(const MindSporeModelPtr model) {
std::string file_name = model->GetModelPath() + '/' + model->GetModelName();
char *graphBuf = ReadFile(file_name.c_str(), &size);
if (graphBuf == nullptr) {
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
return FAILED;
}
last_graph_ = inference::LoadModel(graphBuf, size, device_type_);
if (last_graph_ == nullptr) {
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
return FAILED;
}
graph_id_ = session_->CompileGraph(last_graph_);
MS_LOG(INFO) << "Session Warmup";
MS_LOG(INFO) << "Session Warmup finished";
return SUCCESS;
}
......@@ -95,6 +101,9 @@ Status Session::Clear() {
}
namespace {
static const uint32_t uint32max = 0x7FFFFFFF;  // INT32_MAX; passed to SetMaxMessageSize as an effectively unlimited gRPC message size
std::promise<void> exit_requested;
const std::map<ms_serving::DataType, TypeId> type2id_map{
{ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool},
{ms_serving::MS_INT8, TypeId::kNumberTypeInt8}, {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8},
......@@ -141,7 +150,7 @@ MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) {
}
TypeId type = iter->second;
auto ms_tensor = std::shared_ptr<inference::MSTensor>(inference::MSTensor::CreateTensor(type, shape));
memcpy_s(ms_tensor->MutableData(), tensor.data().size(), tensor.data().data(), tensor.data().size());
memcpy_s(ms_tensor->MutableData(), ms_tensor->Size(), tensor.data().data(), tensor.data().size());
return ms_tensor;
}
......@@ -166,10 +175,7 @@ void ClearEnv() {
Session::Instance().Clear();
inference::ExitInference();
}
void HandleSignal(int sig) {
ClearEnv();
exit(0);
}
void HandleSignal(int sig) { exit_requested.set_value(); }
#ifdef ENABLE_D
static rtContext_t g_ctx = nullptr;
......@@ -247,6 +253,7 @@ Status Server::BuildAndStart() {
rtError_t rt_ret = rtCtxGetCurrent(&ctx);
if (rt_ret != RT_ERROR_NONE || ctx == nullptr) {
MS_LOG(ERROR) << "the ascend device context is null";
ClearEnv();
return FAILED;
}
g_ctx = ctx;
......@@ -258,6 +265,7 @@ Status Server::BuildAndStart() {
auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
grpc::ServerBuilder builder;
builder.SetOption(std::move(option));
builder.SetMaxMessageSize(uint32max);
// Listen on the given address without any authentication mechanism.
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
// Register "service" as the instance through which we'll communicate with
......@@ -265,13 +273,20 @@ Status Server::BuildAndStart() {
builder.RegisterService(&service);
// Finally assemble the server.
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
if (server == nullptr) {
MS_LOG(ERROR) << "The serving server create failed";
ClearEnv();
return FAILED;
}
auto grpc_server_run = [&server]() { server->Wait(); };
std::thread serving_thread(grpc_server_run);
MS_LOG(INFO) << "Server listening on " << server_address << std::endl;
// Wait for the server to shutdown. Note that some other thread must be
// responsible for shutting down the server for this call to ever return.
server->Wait();
auto exit_future = exit_requested.get_future();
exit_future.wait();
ClearEnv();
server->Shutdown();
serving_thread.join();
return SUCCESS;
}
} // namespace serving
} // namespace mindspore
......@@ -29,7 +29,6 @@
namespace mindspore {
namespace serving {
char *ReadFile(const char *file, size_t *size) {
if (file == nullptr) {
MS_LOG(ERROR) << "file is nullptr";
......@@ -70,8 +69,8 @@ bool DirOrFileExist(const std::string &file_path) {
}
std::vector<std::string> GetAllSubDirs(const std::string &dir_path) {
DIR *dir;
struct dirent *ptr;
DIR *dir = nullptr;
struct dirent *ptr = nullptr;
std::vector<std::string> SubDirs;
if ((dir = opendir(dir_path.c_str())) == NULL) {
......
......@@ -36,17 +36,16 @@ bool RemovePrefix(std::string *str, const std::string &prefix) {
bool Option::ParseInt32(std::string *arg) {
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
char extra;
int32_t parsed_value;
if (sscanf(arg->data(), "%d%c", &parsed_value, &extra) != 1) {
std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl;
try {
parsed_value = std::stoi(arg->data());
} catch (const std::exception &) {  // also covers std::out_of_range from std::stoi
std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl;
return false;
} else {
*int32_default_ = parsed_value;
}
*int32_default_ = parsed_value;
return true;
}
return false;
}
......@@ -76,17 +75,16 @@ bool Option::ParseString(std::string *arg) {
bool Option::ParseFloat(std::string *arg) {
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
char extra;
float parsed_value;
if (sscanf(arg->data(), "%f%c", &parsed_value, &extra) != 1) {
std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl;
try {
parsed_value = std::stof(arg->data());
} catch (const std::exception &) {  // also covers std::out_of_range from std::stof
std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl;
return false;
} else {
*float_default_ = parsed_value;
}
*float_default_ = parsed_value;
return true;
}
return false;
}
......@@ -159,10 +157,11 @@ Options::Options() : args_(nullptr) { CreateOptions(); }
void Options::CreateOptions() {
args_ = std::make_shared<Arguments>();
std::vector<Option> options = {
Option("port", &args_->grpc_port, "Port to listen on for gRPC API, default is 5500"),
Option("model_name", &args_->model_name, "model name "),
Option("model_path", &args_->model_path, "the path of the model files"),
Option("device_id", &args_->device_id, "the device id, default is 0"),
Option("port", &args_->grpc_port,
"[Optional] Port to listen on for gRPC API, default is 5500, range from 1 to 65535"),
Option("model_name", &args_->model_name, "[Required] model name "),
Option("model_path", &args_->model_path, "[Required] the path of the model files"),
Option("device_id", &args_->device_id, "[Optional] the device id, default is 0, range from 0 to 7"),
};
options_ = options;
}
......@@ -176,6 +175,14 @@ bool Options::CheckOptions() {
std::cout << "device_type only support Ascend right now" << std::endl;
return false;
}
if (args_->device_id > 7) {
std::cout << "the device_id should be in [0~7]" << std::endl;
return false;
}
if (args_->grpc_port < 1 || args_->grpc_port > 65535) {
std::cout << "the port should be in [1~65535]" << std::endl;
return false;
}
return true;
}
......@@ -238,6 +245,5 @@ void Options::Usage() {
<< option.usage_ << std::endl;
}
}
} // namespace serving
} // namespace mindspore
......@@ -22,7 +22,6 @@
namespace mindspore {
namespace serving {
struct Arguments {
int32_t grpc_port = 5500;
std::string grpc_socket_path;
......@@ -40,6 +39,7 @@ class Option {
Option(const std::string &name, bool *default_point, const std::string &usage);
Option(const std::string &name, std::string *default_point, const std::string &usage);
Option(const std::string &name, float *default_point, const std::string &usage);
~Option() = default;
private:
friend class Options;
......@@ -77,7 +77,6 @@ class Options {
std::vector<Option> options_;
std::shared_ptr<Arguments> args_;
};
} // namespace serving
} // namespace mindspore
......
......@@ -19,7 +19,6 @@
namespace mindspore {
namespace serving {
MindSporeModel::MindSporeModel(const std::string &model_name, const std::string &model_path,
const std::string &model_version, const time_t &last_update_time)
: model_name_(model_name),
......
......@@ -25,7 +25,6 @@
namespace mindspore {
namespace serving {
volatile bool stop_poll = false;
std::string GetVersionFromPath(const std::string &path) {
......@@ -102,10 +101,10 @@ Status VersionController::CreateInitModels() {
}
std::vector<std::string> SubDirs = GetAllSubDirs(models_path_);
if (version_control_strategy_ == kLastest) {
auto path = SubDirs.empty() ? models_path_ : SubDirs.back();
std::string model_version = GetVersionFromPath(path);
time_t last_update_time = GetModifyTime(path);
MindSporeModelPtr model_ptr = std::make_shared<MindSporeModel>(model_name_, path, model_version, last_update_time);
std::string model_version = GetVersionFromPath(models_path_);
time_t last_update_time = GetModifyTime(models_path_);
MindSporeModelPtr model_ptr =
std::make_shared<MindSporeModel>(model_name_, models_path_, model_version, last_update_time);
valid_models_.emplace_back(model_ptr);
} else {
for (auto &dir : SubDirs) {
......@@ -119,8 +118,8 @@ Status VersionController::CreateInitModels() {
MS_LOG(ERROR) << "There is no valid model for serving";
return FAILED;
}
Session::Instance().Warmup(valid_models_.back());
return SUCCESS;
auto ret = Session::Instance().Warmup(valid_models_.back());
return ret;
}
void VersionController::StartPollModelPeriodic() {
......@@ -129,6 +128,5 @@ void VersionController::StartPollModelPeriodic() {
}
void VersionController::StopPollModelPeriodic() {}
} // namespace serving
} // namespace mindspore
......@@ -64,7 +64,6 @@ class PeriodicFunction {
VersionController::VersionControllerStrategy version_control_strategy_;
std::vector<MindSporeModelPtr> valid_models_;
};
} // namespace serving
} // namespace mindspore
......
......@@ -214,6 +214,7 @@ PredictRequest ReadBertInput() {
class MSClient {
public:
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
~MSClient() = default;
std::string Predict(const std::string &type) {
// Data we are sending to the server.
......@@ -310,7 +311,6 @@ int main(int argc, char **argv) {
type = "add";
}
}
} else {
target_str = "localhost:5500";
type = "add";
......
......@@ -81,7 +81,7 @@ function checkopts()
checkopts "$@"
# switch to project root path, which contains clang-format config file '.clang-format'
cd "${SCRIPTS_PATH}/.." || exit 1
cd "${SCRIPTS_PATH}/../.." || exit 1
FMT_FILE_LIST='__format_files_list__'
......
......@@ -161,6 +161,7 @@ setup(
description='MindSpore is a new open source deep learning training/inference '
'framework that could be used for mobile, edge and cloud scenarios.',
long_description="\n\n".join([readme, release]),
long_description_content_type="text/markdown",
packages=find_packages(),
package_data=package_data,
include_package_data=True,
......
......@@ -190,9 +190,9 @@ TEST_F(MindDataTestBPlusTree, Test3) {
EXPECT_TRUE(rc.IsOk());
uint64_t min = ai.min_key();
uint64_t max = ai.max_key();
EXPECT_EQ(min, 1);
EXPECT_EQ(max, 4);
auto r = ai.Search(3);
EXPECT_EQ(min, 0);
EXPECT_EQ(max, 3);
auto r = ai.Search(2);
auto &it = r.first;
EXPECT_EQ(it.value(), "b");
MS_LOG(INFO) << "Dump all the values using [] operator.";
......
......@@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common {
};
void SetUp() {
elim_Z = MakeSubstitution(irpass::AddByZero(), "elim_Z", prim::kPrimScalarAdd);
elim_R = MakeSubstitution(irpass::PrimEliminater(R), "elim_R", R);
idempotent_P = MakeSubstitution(IdempotentEliminater(), "idempotent_P", P);
Qct_to_P = MakeSubstitution(QctToP(), "Qct_to_P", Q);
elim_Z = MakeSubstitution(std::make_shared<irpass::AddByZero>(), "elim_Z", prim::kPrimScalarAdd);
elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);
}
bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) {
......
......@@ -327,6 +327,9 @@ TEST_F(TestStepParallel, CreatOpInstance) {
} else if (name == "instance_name") {
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
ASSERT_EQ(converted_ret->ToString(), "test");
} else if (name == "index") {
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
ASSERT_EQ(converted_ret->ToString(), "0");
} else {
MS_LOG(EXCEPTION) << "Test failed";
}
......
......@@ -4,6 +4,7 @@
"numParallelWorkers": 4,
"workerConnectorSize": 16,
"opConnectorSize": 16,
"seed": 5489
"seed": 5489,
"monitor_sampling_interval": 15
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.