diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 432d88e7a4fefb0814b4d85b1224e4b776fb91d4..023838c3a5ef4f7d99d3701ab6834e0c73f47080 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -90,6 +90,7 @@ void RunOpAscendMixPrecision(const std::shared_ptr &kernel mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); @@ -126,6 +127,7 @@ void AscendMixPrecision(const std::shared_ptr &kernel_grap mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc index fd2061141505ca579ad18298d2752786bacf66f9..81e5c4b48603c5f4b6908267a715ff586bf6ebf5 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc @@ -22,6 +22,7 @@ #include "kernel/oplib/oplib.h" #include "session/anf_runtime_algorithm.h" #include "session/kernel_graph.h" +#include "pre_activate/common/helper.h" namespace mindspore { namespace opt { @@ -168,11 +169,18 @@ AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn } } // namespace +const BaseRef DealRefTransAndCast::DefinePattern() const { + VarPtr V = std::make_shared(UnVisited); + VarPtr Xs = std::make_shared(); + return VectorRef({V, Xs}); +} + const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { if (node == nullptr || !node->isa()) { return nullptr; } + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (!AnfAlgo::IsRealCNodeKernel(cnode)) { diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.h b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.h index 9ed55d8b297858698f288b4995f0d752f42af8ad..1b54a7b111d17ed84d4c0fbfc64f1141f37c0b88 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.h +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.h @@ -28,6 +28,7 @@ class DealRefTransAndCast : public PatternProcessPass { public: explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {} ~DealRefTransAndCast() override = default; + const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; }; } // namespace opt diff --git a/mindspore/ccsrc/pre_activate/common/node_pass.cc b/mindspore/ccsrc/pre_activate/common/node_pass.cc index cd213f8263d0b2aea42b2c8b7f3954309793752d..a6e93d2f0742a369f3d467fba8112c5a0f9f8632 100644 --- a/mindspore/ccsrc/pre_activate/common/node_pass.cc +++ b/mindspore/ccsrc/pre_activate/common/node_pass.cc @@ -45,6 +45,7 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) { bool change = (new_node != nullptr); if (new_node != nullptr && new_node != node) { (void)manager->Replace(node, new_node); + (void)seen_node.erase(node); } else if (new_node == nullptr) { new_node = node; } diff --git a/mindspore/ops/_op_impl/tbe/trans_data.py b/mindspore/ops/_op_impl/tbe/trans_data.py index 1b7c8fa25df3455c9000511687b4b930d48566ca..c6628c7638174f5c5e3540ab07d803ff06cbe775 100644 --- a/mindspore/ops/_op_impl/tbe/trans_data.py +++ b/mindspore/ops/_op_impl/tbe/trans_data.py @@ -46,11 +46,13 @@ from mindspore.ops.op_info_register import op_info_register "dtype": [ "bool", "float","float","float","float","float","float","float","float","float","float", - "float16","float16","float16","float16","float16","float16","float16","float16","float16","float16" + "float16","float16","float16","float16","float16","float16","float16","float16","float16","float16", + "uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16" ], "format": [ "DefaultFormat", "DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ", + "DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ", "DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ" ], "name": "src", @@ -65,11 +67,13 @@ from mindspore.ops.op_info_register import op_info_register "dtype": [ "bool", "float","float","float","float","float","float","float","float","float","float", - "float16","float16","float16","float16","float16","float16","float16","float16","float16","float16" + "float16","float16","float16","float16","float16","float16","float16","float16","float16","float16", + "uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16" ], "format": [ "NC1HWC0", "NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN", + "NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN", "NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN" ], "name": "dst",