提交 734f8a7f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!244 fix ref pass visit graph bug

Merge pull request !244 from dinghao/master
......@@ -90,6 +90,7 @@ void RunOpAscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
......@@ -126,6 +127,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
......
......@@ -22,6 +22,7 @@
#include "kernel/oplib/oplib.h"
#include "session/anf_runtime_algorithm.h"
#include "session/kernel_graph.h"
#include "pre_activate/common/helper.h"
namespace mindspore {
namespace opt {
......@@ -168,11 +169,18 @@ AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn
}
} // namespace
const BaseRef DealRefTransAndCast::DefinePattern() const {
VarPtr V = std::make_shared<CondVar>(UnVisited);
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({V, Xs});
}
const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>()) {
return nullptr;
}
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
......
......@@ -28,6 +28,7 @@ class DealRefTransAndCast : public PatternProcessPass {
public:
explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {}
~DealRefTransAndCast() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
......
......@@ -45,6 +45,7 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
bool change = (new_node != nullptr);
if (new_node != nullptr && new_node != node) {
(void)manager->Replace(node, new_node);
(void)seen_node.erase(node);
} else if (new_node == nullptr) {
new_node = node;
}
......
......@@ -46,11 +46,13 @@ from mindspore.ops.op_info_register import op_info_register
"dtype": [
"bool",
"float","float","float","float","float","float","float","float","float","float",
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16"
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16",
"uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16"
],
"format": [
"DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ",
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ",
"DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ"
],
"name": "src",
......@@ -65,11 +67,13 @@ from mindspore.ops.op_info_register import op_info_register
"dtype": [
"bool",
"float","float","float","float","float","float","float","float","float","float",
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16"
"float16","float16","float16","float16","float16","float16","float16","float16","float16","float16",
"uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16"
],
"format": [
"NC1HWC0",
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN",
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN",
"NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN"
],
"name": "dst",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册