提交 65eacc95 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1787 optimize transdata for pynative mode

Merge pull request !1787 from chujinjin/optimize_transdata_for_pynative
......@@ -16,11 +16,13 @@
#include "pre_activate/ascend/format_type/insert_trans_op.h"
#include <memory>
#include <vector>
#include "utils/utils.h"
#include "pre_activate/ascend/ascend_helper.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "kernel/oplib/oplib.h"
#include "utils/context/ms_context.h"
namespace mindspore {
namespace opt {
......@@ -30,6 +32,15 @@ const BaseRef InsertTransOp::DefinePattern() const {
return VectorRef({V, Xs});
}
bool IsGraphOutput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &outputs) {
auto iter = std::find(outputs.begin(), outputs.end(), node);
if (iter != outputs.end()) {
return true;
}
return false;
}
const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !AnfAlgo::IsRealKernel(node)) {
......@@ -38,6 +49,13 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
MS_LOG(DEBUG) << "====process op: " << node->DebugString();
AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) {
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
return new_node;
}
}
return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
}
} // namespace opt
......
......@@ -21,6 +21,7 @@
#include "pre_activate/common/pass_manager.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "utils/context/ms_context.h"
#define private public
#define protected public
......@@ -103,6 +104,9 @@ TEST_F(TestHWInsertTransOp, test_insert_trans_op_for_single_output) {
* return output
*
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
auto fg = GetSingleOutputGraph("test_insert_trans_op_for_single_output", "before", "NC1HWC0");
// Do insert_trans_op_ pass of hardware opt
auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
......
......@@ -20,6 +20,8 @@
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
#include "debug/anf_ir_dump.h"
#include "utils/context/ms_context.h"
#define private public
#define protected public
#include "pre_activate/ascend/format_type/insert_trans_op.h"
......@@ -91,6 +93,9 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
* transdata = Transdata(transpose)
* return transdata
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transdata_split_fraz_nchw", "before");
std::vector<int> shp{2, 4, 8, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
......
......@@ -19,6 +19,7 @@
#include "device/kernel_info.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
#include "utils/context/ms_context.h"
#define private public
#define protected public
#include "pre_activate/ascend/format_type/insert_trans_op.h"
......@@ -76,6 +77,9 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
* transdata = Transdata(transpose)
* return transdata
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_transdata_fusion", "before");
std::vector<int> shp{2, 4, 8, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
......
......@@ -30,6 +30,7 @@
#include "utils/context/ms_context.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "utils/context/ms_context.h"
#define private public
#define protected public
......@@ -71,6 +72,9 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
* output = make_tuple(res)
* return output
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_eliminate_5to4_4to5", "before");
// Renormalize func_graph to infer and set shape and type information.
std::vector<int> shp{2, 32, 224, 224};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册