提交 39b9e831 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!291 disable memory reuse for GetNext op

Merge pull request !291 from caifubi/dev-getnext-mem-reuse-off
...@@ -355,6 +355,10 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in ...@@ -355,6 +355,10 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in
AssignCommunicationNodeOutputMem(flag, node); AssignCommunicationNodeOutputMem(flag, node);
return; return;
} }
if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
MS_LOG(INFO) << "GetNext disable mem_reuse";
flag = kDynamicMem;
}
auto kernel_mod = AnfAlgo::GetKernelMod(node); auto kernel_mod = AnfAlgo::GetKernelMod(node);
MS_EXCEPTION_IF_NULL(kernel_mod); MS_EXCEPTION_IF_NULL(kernel_mod);
auto output_sizes = kernel_mod->GetOutputSizeList(); auto output_sizes = kernel_mod->GetOutputSizeList();
......
...@@ -857,5 +857,10 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) { ...@@ -857,5 +857,10 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
} }
return false; return false;
} }
bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
auto kernel_name = AnfAlgo::GetCNodeName(node);
return kernel_name == kGetNextOpName;
}
} // namespace session } // namespace session
} // namespace mindspore } // namespace mindspore
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "kernel/kernel.h" #include "kernel/kernel.h"
#include "kernel/kernel_build_info.h" #include "kernel/kernel_build_info.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "utils/contract.h"
namespace mindspore { namespace mindspore {
namespace session { namespace session {
...@@ -175,6 +176,7 @@ class AnfRuntimeAlgorithm { ...@@ -175,6 +176,7 @@ class AnfRuntimeAlgorithm {
// get real input index for some tbe ops which input order is different between me and tbe impl // get real input index for some tbe ops which input order is different between me and tbe impl
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
static bool IsCommunicationOp(const AnfNodePtr &node); static bool IsCommunicationOp(const AnfNodePtr &node);
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
}; };
} // namespace session } // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm; using AnfAlgo = session::AnfRuntimeAlgorithm;
......
...@@ -42,6 +42,7 @@ constexpr auto kBNGrad2OpName = "BNGrad2"; ...@@ -42,6 +42,7 @@ constexpr auto kBNGrad2OpName = "BNGrad2";
constexpr auto kBNGrad3OpName = "BNGrad3"; constexpr auto kBNGrad3OpName = "BNGrad3";
constexpr auto kClearZeroOpName = "ClearZero"; constexpr auto kClearZeroOpName = "ClearZero";
constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean"; constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean";
constexpr auto kGetNextOpName = "GetNext";
constexpr auto kAllReduceOpName = "AllReduce"; constexpr auto kAllReduceOpName = "AllReduce";
constexpr auto kAllGatherOpName = "AllGather"; constexpr auto kAllGatherOpName = "AllGather";
constexpr auto kBroadcastOpName = "Broadcast"; constexpr auto kBroadcastOpName = "Broadcast";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册