提交 2445ffdf 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1170 fix memreuse to support large batchsize

Merge pull request !1170 from yangjie159/fix_memreuse_to_support_large_batchsize
......@@ -21,8 +21,8 @@
namespace mindspore {
namespace device {
namespace ascend {
// Sizes, in GB, used for the Ascend device memory and the Ascend memory pool (per the constant names).
// NOTE(review): each GB constant appears twice in this listing (24/6, then 26/4). This looks like the
// before/after lines of a diff with the +/- markers stripped — confirm that only one pair exists in
// the real header, since a duplicate definition would not compile.
const uint64_t kAscendDeviceMemGB = 24;
const uint64_t kAscendMemPoolGB = 6;
const uint64_t kAscendDeviceMemGB = 26;
const uint64_t kAscendMemPoolGB = 4;
// The same sizes in bytes: the GB count shifted left by 30 bits, i.e. multiplied by 2^30.
const uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << 30);
const uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << 30);
......
......@@ -401,6 +401,15 @@ bool BestFitMemReuse::IsReusableStream(uint32_t curr_stream_id, uint32_t target_
return curr_parallel_set.find(target_stream_id) == curr_parallel_set.end();
}
bool BestFitMemReuse::IsRelease(const std::string &kernel_name) {
// unable_used_node include the node type that output tensor cannot be released,
// even if its refcount is equal to zero.
std::unordered_set<std::string> unable_used_node = {prim::kPrimBatchNorm->name(), prim::kPrimBatchNormGrad->name(),
prim::kPrimFusedBatchNorm->name(),
prim::kPrimFusedBatchNormGrad->name()};
return unable_used_node.find(kernel_name) == unable_used_node.end();
}
void BestFitMemReuse::CheckTensorIndex(int tensor_index) const {
if (tensor_index < 0) {
MS_LOG(EXCEPTION) << "warning, please check tensor info.";
......@@ -437,6 +446,9 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) {
// update node input tensor refcount, and membuf list status
UpdateNodeInputAndMembuf(op_def_ptr.get());
// release the node's output tensors whose refcount is equal to zero, unless IsRelease() forbids it for this kernel
if (IsRelease(op_def_ptr->kernel_name())) {
ReleaseNodeUnusedOutput(op_def_ptr.get());
}
#ifdef MEM_REUSE_DEBUG
MemReuseChecker::GetInstance().SetMembuInfos(op_def_ptr.get(), membuf_ptr_list_);
++op_num;
......
......@@ -102,6 +102,8 @@ class BestFitMemReuse {
size_t GetAllocatedSize();
// Returns true if the target stream can be reused by the current stream.
bool IsReusableStream(uint32_t curr_stream_id, uint32_t target_stream_id);
// Returns false when the output of the kernel named kernel_name cannot be released
// (the definition checks the name against a fixed set of batch-norm primitive names);
// returns true otherwise.
bool IsRelease(const std::string &kernel_name);
// set tensor_def and op_def
void set_tensor_ptr_list(const std::vector<KernelRefCountPtr> &tensor_ptr_list) {
tensor_ptr_list_ = tensor_ptr_list;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册