提交 539c3ad5 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2660 optimize is all nop node detect in mem reuse

Merge pull request !2660 from laiyongqiang/gpu_opt
...@@ -103,6 +103,7 @@ bool MemReuseUtil::InitDynamicWorkspaceKernelRef() { ...@@ -103,6 +103,7 @@ bool MemReuseUtil::InitDynamicWorkspaceKernelRef() {
bool MemReuseUtil::InitDynamicKernelRef(const KernelGraph *graph) { bool MemReuseUtil::InitDynamicKernelRef(const KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
graph_ = graph; graph_ = graph;
is_all_nop_node_ = opt::IsAllNopNode(graph);
if (!InitDynamicOutputKernelRef()) { if (!InitDynamicOutputKernelRef()) {
MS_LOG(INFO) << "InitDynamicOutputKernelRef fail"; MS_LOG(INFO) << "InitDynamicOutputKernelRef fail";
return false; return false;
...@@ -223,7 +224,6 @@ KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) { ...@@ -223,7 +224,6 @@ KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) {
} }
KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) { KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) {
auto is_all_nop_node = opt::IsAllNopNode(graph_);
if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) { if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) {
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number "
<< AnfAlgo::GetInputTensorNum(kernel); << AnfAlgo::GetInputTensorNum(kernel);
...@@ -231,7 +231,7 @@ KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t ...@@ -231,7 +231,7 @@ KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t
auto input_node = kernel->input(input_idx + 1); auto input_node = kernel->input(input_idx + 1);
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node. // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
session::KernelWithIndex kernel_input; session::KernelWithIndex kernel_input;
if (is_all_nop_node) { if (is_all_nop_node_) {
// The graph does not remove the nop node. // The graph does not remove the nop node.
kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
} else { } else {
...@@ -265,7 +265,6 @@ void MemReuseUtil::SetKernelDefMap() { ...@@ -265,7 +265,6 @@ void MemReuseUtil::SetKernelDefMap() {
} }
void MemReuseUtil::SetKernelDefInputs() { void MemReuseUtil::SetKernelDefInputs() {
auto is_all_nop_node = opt::IsAllNopNode(graph_);
for (const auto &kernel : graph_->execution_order()) { for (const auto &kernel : graph_->execution_order()) {
MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(kernel);
auto key = kernel.get(); auto key = kernel.get();
...@@ -282,7 +281,7 @@ void MemReuseUtil::SetKernelDefInputs() { ...@@ -282,7 +281,7 @@ void MemReuseUtil::SetKernelDefInputs() {
auto input_node = AnfAlgo::GetInputNode(kernel, i); auto input_node = AnfAlgo::GetInputNode(kernel, i);
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node. // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
session::KernelWithIndex input; session::KernelWithIndex input;
if (is_all_nop_node) { if (is_all_nop_node_) {
// The graph does not remove the nop node. // The graph does not remove the nop node.
input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
} else { } else {
...@@ -349,11 +348,10 @@ void MemReuseUtil::SetSummaryNodesRefCount() { ...@@ -349,11 +348,10 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
} }
void MemReuseUtil::SetGraphOutputRefCount() { void MemReuseUtil::SetGraphOutputRefCount() {
auto is_all_nop_node = opt::IsAllNopNode(graph_);
auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem});
for (const auto &node : nodes) { for (const auto &node : nodes) {
session::KernelWithIndex kernel_input; session::KernelWithIndex kernel_input;
if (is_all_nop_node) { if (is_all_nop_node_) {
// The graph does not remove the nop node. // The graph does not remove the nop node.
kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, false); kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, false);
} else { } else {
......
...@@ -42,7 +42,7 @@ class MemReuseUtil { ...@@ -42,7 +42,7 @@ class MemReuseUtil {
KernelRefCountPtrList total_refs_list_; KernelRefCountPtrList total_refs_list_;
KernelRefCountPtrList total_wk_ref_list_; KernelRefCountPtrList total_wk_ref_list_;
KernelRefs kernel_workspace_refs_; KernelRefs kernel_workspace_refs_;
MemReuseUtil() : util_index_(kInitIndex), graph_(nullptr) {} MemReuseUtil() : util_index_(kInitIndex), graph_(nullptr), is_all_nop_node_(false) {}
~MemReuseUtil() { ~MemReuseUtil() {
if (graph_ != nullptr) { if (graph_ != nullptr) {
graph_ = nullptr; graph_ = nullptr;
...@@ -87,6 +87,7 @@ class MemReuseUtil { ...@@ -87,6 +87,7 @@ class MemReuseUtil {
private: private:
int util_index_; int util_index_;
const KernelGraph *graph_; const KernelGraph *graph_;
bool is_all_nop_node_;
KernelRefCountPtrList ref_list_; KernelRefCountPtrList ref_list_;
KernelDefPtrMaps kernel_def_ptr_list_; KernelDefPtrMaps kernel_def_ptr_list_;
KernelRefCountPtrList last_ref_list_; KernelRefCountPtrList last_ref_list_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册