Commit 65eacc95 authored by mindspore-ci-bot, committed by Gitee

!1787 optimize transdata for pynative mode

Merge pull request !1787 from chujinjin/optimize_transdata_for_pynative
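In short, this change makes the InsertTransOp pass skip output-side TransData insertion for nodes that are graph outputs when the execution mode is PyNative, and pins the affected unit tests to kGraphMode so their expected IR is unchanged. Below is a minimal, standalone sketch of the membership check that the new IsGraphOutput helper performs (a plain std::find over the flattened graph outputs). It is not the real pass code: std::string stands in for AnfNodePtr and the values in main() are made-up examples.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for AnfNodePtr in this sketch only; the real helper works on ANF node pointers.
using Node = std::string;

// Same shape as the IsGraphOutput helper added in the diff: a linear membership test.
bool IsGraphOutput(const Node &node, const std::vector<Node> &outputs) {
  return std::find(outputs.begin(), outputs.end(), node) != outputs.end();
}

int main() {
  std::vector<Node> graph_outputs = {"conv_out", "add_out"};
  // In PyNative mode, Process() returns early (no InsertTransOpForOutput) when this is true.
  std::cout << std::boolalpha << IsGraphOutput("conv_out", graph_outputs) << "\n";  // true
  std::cout << std::boolalpha << IsGraphOutput("relu_out", graph_outputs) << "\n";  // false
  return 0;
}

The effect shown in the diff is that, in PyNative mode, no TransData is inserted after a node that is itself a graph output; all other nodes keep the existing behaviour.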
@@ -16,11 +16,13 @@
#include "pre_activate/ascend/format_type/insert_trans_op.h"
#include <memory>
#include <vector>
#include "utils/utils.h" #include "utils/utils.h"
#include "pre_activate/ascend/ascend_helper.h" #include "pre_activate/ascend/ascend_helper.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h" #include "device/kernel_info.h"
#include "kernel/oplib/oplib.h" #include "kernel/oplib/oplib.h"
#include "utils/context/ms_context.h"
namespace mindspore {
namespace opt {
@@ -30,6 +32,15 @@ const BaseRef InsertTransOp::DefinePattern() const {
  return VectorRef({V, Xs});
}
bool IsGraphOutput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &outputs) {
auto iter = std::find(outputs.begin(), outputs.end(), node);
if (iter != outputs.end()) {
return true;
}
return false;
}
const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const EquivPtr &) const {
  if (node == nullptr || !AnfAlgo::IsRealKernel(node)) {
@@ -38,6 +49,13 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  MS_LOG(DEBUG) << "====process op: " << node->DebugString();
  AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) {
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
return new_node;
}
}
  return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
}
}  // namespace opt
...
@@ -21,6 +21,7 @@
#include "pre_activate/common/pass_manager.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "utils/context/ms_context.h"
#define private public
#define protected public
@@ -103,6 +104,9 @@ TEST_F(TestHWInsertTransOp, test_insert_trans_op_for_single_output) {
   * return output
   *
   */
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
  auto fg = GetSingleOutputGraph("test_insert_trans_op_for_single_output", "before", "NC1HWC0");
  // Do insert_trans_op_ pass of hardware opt
  auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
...
@@ -20,6 +20,8 @@
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
#include "debug/anf_ir_dump.h"
#include "utils/context/ms_context.h"
#define private public
#define protected public
#include "pre_activate/ascend/format_type/insert_trans_op.h"
@@ -91,6 +93,9 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
   * transdata = Transdata(transpose)
   * return transdata
   */
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transdata_split_fraz_nchw", "before");
  std::vector<int> shp{2, 4, 8, 16};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
...
@@ -19,6 +19,7 @@
#include "device/kernel_info.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
#include "utils/context/ms_context.h"
#define private public
#define protected public
#include "pre_activate/ascend/format_type/insert_trans_op.h"
@@ -76,6 +77,9 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
   * transdata = Transdata(transpose)
   * return transdata
   */
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_transdata_fusion", "before");
  std::vector<int> shp{2, 4, 8, 16};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
...
@@ -30,6 +30,7 @@
#include "utils/context/ms_context.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "utils/context/ms_context.h"
#define private public
#define protected public
@@ -71,6 +72,9 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
   * output = make_tuple(res)
   * return output
   */
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
  FuncGraphPtr g = getPyFun_.CallAndParseRet("test_eliminate_5to4_4to5", "before");
  // Renormalize func_graph to infer and set shape and type information.
  std::vector<int> shp{2, 32, 224, 224};
...