提交 d9197b59 编写于 作者: L limingqi107

gpu optimize the use of reference count

上级 cae8a921
......@@ -184,6 +184,10 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph
bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
auto graph_id = graph->graph_id();
auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id];
MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
// Reset the reference count.
mem_reuse_util_ptr->ResetDynamicUsedRefCount();
// The input and output memory of a communication kernel must be contiguous, so it is handled separately.
AllocCommunicationOpDynamicRes(graph);
......@@ -360,16 +364,13 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
if (kernel_ref_count_ptr == nullptr) {
continue;
}
// Can't free the output of graph.
if (kernel_ref_count_ptr->ref_count_dynamic_use_ == memreuse::kMaxRefCount) {
continue;
}
kernel_ref_count_ptr->ref_count_dynamic_use_--;
if (kernel_ref_count_ptr->ref_count_dynamic_use_ < 0) {
MS_LOG(EXCEPTION) << "Check dynamic reference count failed.";
}
if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
mem_manager_->FreeMemFromMemPool(device_address);
// Reset the reference count.
kernel_ref_count_ptr->ref_count_dynamic_use_ = kernel_ref_count_ptr->ref_count_;
}
}
// Free the output of kernel, if output has no reference.
......
......@@ -288,6 +288,14 @@ void MemReuseUtil::SetGraphOutputRefCount() {
#endif
}
// Restore the dynamic-use counter of every kernel output reference to its
// static reference count: for each entry in kernel_output_refs_, every
// reference's ref_count_dynamic_use_ is overwritten with its ref_count_.
void MemReuseUtil::ResetDynamicUsedRefCount() {
  for (auto &kernel_entry : kernel_output_refs_) {
    auto &output_refs = kernel_entry.second;
    for (auto &output_ref : output_refs) {
      output_ref->ref_count_dynamic_use_ = output_ref->ref_count_;
    }
  }
}
void MemReuseUtil::SetAllInfo(KernelGraph *graph) {
if (!InitDynamicKernelRef(graph)) {
MS_LOG(EXCEPTION) << "Init ReuseAssignDynamicMemory Fault";
......
......@@ -64,6 +64,8 @@ class MemReuseUtil {
void SetReuseRefCount();
// Set the reference count of graph output specially.
void SetGraphOutputRefCount();
// Reset every dynamic-use reference count (ref_count_dynamic_use_) to its ref_count_ value.
void ResetDynamicUsedRefCount();
KernelRefCountPtr GetRef(const AnfNodePtr &node, int output_idx);
KernelRefCountPtr GetKernelInputRef(const CNodePtr &kernel, size_t input_idx);
......
......@@ -161,7 +161,8 @@ void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_li
total_ori_value_size_ = CalculOriValue(graph);
total_ori_dy_size_ = CalculOriDy(graph);
total_ori_wkspace_size_ = CalculOriWk(graph);
std::string filename = "./memreuse.ir";
std::string graph_id = std::to_string(graph->graph_id());
std::string filename = "./memreuse_" + graph_id + ".ir";
std::ofstream ofs(filename);
if (!ofs.is_open()) {
MS_LOG(ERROR) << "Open file [" << filename << "] failed!";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册