magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 4d831966
Authored Jul 17, 2020 by mindspore-ci-bot; committed by Gitee on Jul 17, 2020
!3115 Remove transdata and cast for internal outputs
Merge pull request !3115 from YuJianfeng/master
Parents: 36b0ca61, 188d74f1
Showing 15 changed files with 564 additions and 30 deletions (+564 -30)
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc  +3 -0
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc  +12 -2
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc  +12 -0
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc  +5 -11
mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc  +83 -0
mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.h  +51 -0
mindspore/ccsrc/backend/session/kernel_graph.cc  +31 -13
mindspore/ccsrc/backend/session/kernel_graph.h  +3 -3
mindspore/ccsrc/backend/session/session_basic.cc  +6 -1
tests/st/host_device/test_host_device_lenet.py  +89 -0
tests/st/ops/cpu/test_sparse_apply_adam_op.py  +4 -0
tests/st/ops/cpu/test_sparse_apply_ftrl_op.py  +4 -0
tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py  +4 -0
tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc  +174 -0
tests/ut/cpp/python_input/gtest_input/pre_activate/remove_internal_output_test.py  +83 -0
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc

@@ -96,6 +96,7 @@
 #include "backend/optimizer/ascend/format_type/modify_ops_attrs.h"
 #include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
 #include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
+#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
 #include "utils/context/ms_context.h"
 #include "utils/config_manager.h"
 #include "debug/anf_ir_dump.h"
@@ -199,6 +200,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
   data_layout_pm->AddPass(std::make_shared<OptimizeDependence>());
   data_layout_pm->AddPass(std::make_shared<TransDataSplit>());
   data_layout_pm->AddPass(std::make_shared<EraseVisitAttr>());
+  data_layout_pm->AddPass(std::make_shared<RemoveInternalOutputTransOp>());
   optimizer->AddPassManager(data_layout_pm);
   (void)optimizer->Optimize(kernel_graph);
   kernel_graph->SetExecOrderByDefault();
@@ -220,6 +222,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
   mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
   mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
   mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>());
+  mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>());
   optimizer->AddPassManager(mixed_precision_pm);
   (void)optimizer->Optimize(kernel_graph);
   kernel_graph->SetExecOrderByDefault();
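Note: the hunks above only include the new header and register the new passes. RemoveInternalOutputTransOp is appended to the data-layout pass manager and RemoveInternalOutputCast to the mixed-precision one, so each runs after the passes that insert TransData/Cast nodes. A minimal, self-contained C++ sketch of this registration pattern follows; the Pass and PassManager types here are simplified stand-ins, not the MindSpore classes.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Stand-in for a graph-rewrite pass (MindSpore's PatternProcessPass plays this role).
class Pass {
 public:
  explicit Pass(const std::string &name) : name_(name) {}
  virtual ~Pass() = default;
  virtual void Run() { std::cout << "run pass: " << name_ << "\n"; }

 private:
  std::string name_;
};

// Stand-in for a pass manager: runs its passes in registration order.
class PassManager {
 public:
  void AddPass(const std::shared_ptr<Pass> &pass) { passes_.push_back(pass); }
  void Run() {
    for (auto &pass : passes_) {
      pass->Run();
    }
  }

 private:
  std::vector<std::shared_ptr<Pass>> passes_;
};

int main() {
  PassManager data_layout_pm;
  data_layout_pm.AddPass(std::make_shared<Pass>("OptimizeDependence"));
  data_layout_pm.AddPass(std::make_shared<Pass>("TransDataSplit"));
  data_layout_pm.AddPass(std::make_shared<Pass>("EraseVisitAttr"));
  // Appended last, mirroring the hunk above: it sees the TransData nodes inserted earlier.
  data_layout_pm.AddPass(std::make_shared<Pass>("RemoveInternalOutputTransOp"));
  data_layout_pm.Run();
  return 0;
}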
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc

@@ -142,6 +142,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
   MS_EXCEPTION_IF_NULL(node);
   std::vector<AnfNodePtr> make_tuple_inputs;
   make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
   for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) {
     std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
     if (output_format == kOpFormat_NC1KHKWHWC0) {
@@ -151,7 +152,11 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
     auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
     std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
     if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
-      make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false));
+      auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
+      if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
+        kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
+      }
+      make_tuple_inputs.emplace_back(trans_op);
     } else {
       // No need insert trans op.
       make_tuple_inputs.push_back(tuple_getitem);
@@ -249,9 +254,14 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP
   if (outputs_num == 0) {
     return node;
   }
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
   // Single output
   if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) {
-    return InsertTransOpForSingleOutput(func_graph, node, kernel_select);
+    auto new_node = InsertTransOpForSingleOutput(func_graph, node, kernel_select);
+    if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
+      kernel_graph->ReplaceInternalOutput(node, new_node);
+    }
+    return new_node;
   }
   // Multiple output
   return InsertTransOpForMultipleOutput(func_graph, node, kernel_select);
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc

@@ -40,6 +40,7 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
   std::vector<AnfNodePtr> make_tuple_inputs;
   AbstractBasePtrList abstract_list;
   make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
   for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) {
     AnfNodePtr replace_node = nullptr;
     const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
@@ -64,6 +65,9 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
       MS_EXCEPTION_IF_NULL(replace_node);
       replace_node->set_scope(cnode->scope());
       AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
+      if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode)) {
+        kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
+      }
     } else {
       replace_node = getitem;
     }
@@ -87,6 +91,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
     return cnode;
   }
   MS_EXCEPTION_IF_NULL(cnode->Type());
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
   // Single output
   if (!cnode->Type()->isa<Tuple>()) {
     if (!need_insert_cast[0]) {
@@ -109,6 +114,9 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
     MS_EXCEPTION_IF_NULL(replace_node);
     replace_node->set_scope(cnode->scope());
     AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
+    if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode)) {
+      kernel_graph->ReplaceInternalOutput(cnode, replace_node);
+    }
   }
   return replace_node;
 }
@@ -188,6 +196,10 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
   CNodePtr cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   auto new_node = InsertCastForInput(func_graph, cnode);
+  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
+  if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
+    kernel_graph->ReplaceInternalOutput(node, new_node);
+  }
   // process output
   return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
 }
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc

@@ -46,14 +46,13 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
   if (node == nullptr || !AnfAlgo::IsRealKernel(node)) {
     return nullptr;
   }
-  AnfNodePtr front_node;
-  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
-  MS_LOG(DEBUG) << "process op: " << node->DebugString();
-  AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
-  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
-  if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
-    front_node = kernel_graph->GetFrontNodeByInternalOutput(node);
-    kernel_graph->ReplaceInternalOutput(node, new_node);
-  }
+  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
+  MS_LOG(DEBUG) << "====process op: " << node->DebugString();
+  AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) {
@@ -61,12 +60,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
       return new_node;
     }
   }
-  auto final_node = InsertTransOpForOutput(func_graph, new_node, kernel_select_);
-  if (kernel_graph != nullptr && front_node != nullptr) {
-    auto old_node = kernel_graph->GetInternalOutputByFrontNode(front_node);
-    kernel_graph->ReplaceInternalOutput(old_node, final_node);
-  }
-  return final_node;
+  return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
 }
 }  // namespace opt
 }  // namespace mindspore
mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
namespace {
bool UsedForOutputOnly(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto &node_users = manager->node_users();
  auto iter = node_users.find(node);
  if (iter == node_users.end()) {
    return false;
  }
  const auto &node_set = iter->second;
  for (const auto &node_index : node_set) {
    if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimMakeTuple)) {
      return false;
    }
  }
  return true;
}
}  // namespace

const BaseRef RemoveInternalOutputTransOp::DefinePattern() const {
  VarPtr X = std::make_shared<Var>();
  auto prim = std::make_shared<Primitive>(kTransDataOpName);
  return VectorRef({prim, X});
}

const BaseRef RemoveInternalOutputCast::DefinePattern() const {
  VarPtr X = std::make_shared<Var>();
  return VectorRef({prim::kPrimCast, X});
}

const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                               const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  if (kernel_graph == nullptr) {
    return nullptr;
  }
  if (!kernel_graph->IsInternalOutput(node)) {
    return nullptr;
  }
  if (!UsedForOutputOnly(func_graph, node)) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  CheckCNodeInputSize(cnode, kTransOpInputNum);
  auto input_node = cnode->input(1);
  if (!AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimTupleGetItem)) {
    kernel_graph->ReplaceInternalOutput(node, input_node);
  } else {
    auto tuple_getitem = input_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(tuple_getitem);
    int idx = AnfAlgo::GetTupleGetItemOutIndex(tuple_getitem);
    AnfNodePtr real_input_node = AnfAlgo::GetTupleGetItemRealInput(tuple_getitem);
    kernel_graph->ReplaceInternalOutput(node, real_input_node, 0, idx);
  }
  return input_node;
}
}  // namespace opt
}  // namespace mindspore
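The pass above fires only when the matched TransData/Cast node is tracked as an internal output of the kernel graph and is consumed solely by the graph's output MakeTuple; it then hands the internal-output slot back to the node's input and returns that input, so the pattern framework splices the redundant conversion out of the output path. A simplified, self-contained sketch of that idea follows; the Node type, the map, and the node names are hypothetical stand-ins, not the MindSpore API.

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical node type standing in for MindSpore's AnfNode.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

// Simplified internal-output table: internal node -> (output index -> front node name).
std::map<NodePtr, std::map<int, std::string>> internal_outputs;

// Sketch of the pass body: if `trans` owns an internal-output slot, remap the slot to its
// producer input so the conversion node can be dropped from the output path.
void RemoveInternalOutputTransOp(const NodePtr &trans) {
  auto iter = internal_outputs.find(trans);
  if (iter == internal_outputs.end()) {
    return;  // not an internal output; leave the node alone
  }
  NodePtr producer = trans->inputs.at(0);
  internal_outputs[producer] = iter->second;  // bookkeeping now tracks the producer
  internal_outputs.erase(trans);
  std::cout << trans->name << " bypassed; internal output now tracks " << producer->name << "\n";
}

int main() {
  auto add = std::make_shared<Node>(Node{"Add", {}});
  auto trans = std::make_shared<Node>(Node{"TransData", {add}});
  internal_outputs[trans][0] = "front_add";  // the TransData currently owns the slot
  RemoveInternalOutputTransOp(trans);
  return 0;
}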
mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.h (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_
#include <string>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class RemoveInternalOutput : public PatternProcessPass {
 public:
  explicit RemoveInternalOutput(const std::string &name, bool multigraph = true)
      : PatternProcessPass(name, multigraph) {}
  ~RemoveInternalOutput() override = default;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};

class RemoveInternalOutputTransOp : public RemoveInternalOutput {
 public:
  explicit RemoveInternalOutputTransOp(bool multigraph = true)
      : RemoveInternalOutput("remove_internal_output_trans_op", multigraph) {}
  ~RemoveInternalOutputTransOp() override = default;
  const BaseRef DefinePattern() const override;
};

class RemoveInternalOutputCast : public RemoveInternalOutput {
 public:
  explicit RemoveInternalOutputCast(bool multigraph = true)
      : RemoveInternalOutput("remove_internal_output_cast", multigraph) {}
  ~RemoveInternalOutputCast() override = default;
  const BaseRef DefinePattern() const override;
};
}  // namespace opt
}  // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_
mindspore/ccsrc/backend/session/kernel_graph.cc

@@ -929,10 +929,15 @@ void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodeP
   }
   MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
   front_to_internal_outputs_map_[front_node] = node;
-  internal_outputs_to_front_map_[node] = front_node;
+  int output_idx = 0;
+  if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
+    output_idx = AnfAlgo::GetTupleGetItemOutIndex(front_node->cast<CNodePtr>());
+  }
+  internal_outputs_to_front_map_[node][output_idx] = front_node;
 }

-void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) {
+void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx,
+                                        int dst_output_idx) {
   if (new_node == nullptr || node == nullptr) {
     MS_LOG(INFO) << "New node or node is nullptr";
     return;
@@ -947,9 +952,30 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr
     return;
   }
   MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString();
-  internal_outputs_to_front_map_[new_node] = iter->second;
-  front_to_internal_outputs_map_[iter->second] = new_node;
-  internal_outputs_to_front_map_.erase(iter);
+  auto &front_nodes = iter->second;
+  // Move all front nodes to new node mapping
+  if (src_output_idx == -1) {
+    internal_outputs_to_front_map_[new_node] = front_nodes;
+    for (const auto &front_node_iter : front_nodes) {
+      front_to_internal_outputs_map_[front_node_iter.second] = new_node;
+    }
+    internal_outputs_to_front_map_.erase(iter);
+    return;
+  }
+  // Move specified front node to new node mapping
+  int index = SizeToInt(src_output_idx);
+  auto front_node_iter = front_nodes.find(index);
+  if (front_node_iter == front_nodes.end()) {
+    MS_LOG(INFO) << "The output " << src_output_idx << " of node " << node->DebugString() << " is not an internal node";
+    return;
+  }
+  auto front_node = front_node_iter->second;
+  internal_outputs_to_front_map_[new_node][dst_output_idx] = front_node;
+  front_to_internal_outputs_map_[front_node] = new_node;
+  front_nodes.erase(index);
+  if (front_nodes.empty()) {
+    internal_outputs_to_front_map_.erase(iter);
+  }
 }

 AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
@@ -967,14 +993,6 @@ bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const {
   return false;
 }

-AnfNodePtr KernelGraph::GetFrontNodeByInternalOutput(const AnfNodePtr &node) const {
-  auto iter = internal_outputs_to_front_map_.find(node);
-  if (iter != internal_outputs_to_front_map_.end()) {
-    return iter->second;
-  }
-  return nullptr;
-}
-
 void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) {
   if (node == nullptr) {
     return;
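The heart of this file's change is that the internal-output bookkeeping becomes per output index: internal_outputs_to_front_map_ now maps an internal node to a map from output index to front node, and ReplaceInternalOutput can move either every tracked slot (src_output_idx == -1) or a single slot to the new node. A self-contained sketch of that map maintenance follows, using plain strings as stand-ins for AnfNodePtr; the names and values are illustrative only.

#include <iostream>
#include <string>
#include <unordered_map>

// Simplified mirror of the bookkeeping above: node name -> (output index -> front node name).
using FrontNodes = std::unordered_map<int, std::string>;
std::unordered_map<std::string, FrontNodes> internal_outputs_to_front;

// Sketch of ReplaceInternalOutput: move all tracked outputs (src_output_idx == -1) or a
// single output slot from `node` to `new_node`.
void ReplaceInternalOutput(const std::string &node, const std::string &new_node,
                           int src_output_idx = -1, int dst_output_idx = -1) {
  auto iter = internal_outputs_to_front.find(node);
  if (iter == internal_outputs_to_front.end()) {
    return;  // `node` is not an internal output
  }
  FrontNodes &front_nodes = iter->second;
  if (src_output_idx == -1) {  // move every slot to the new node
    FrontNodes moved = front_nodes;
    internal_outputs_to_front.erase(node);
    internal_outputs_to_front[new_node] = moved;
    return;
  }
  auto front_iter = front_nodes.find(src_output_idx);  // move one slot only
  if (front_iter == front_nodes.end()) {
    return;  // this particular output is not tracked
  }
  std::string front = front_iter->second;
  front_nodes.erase(front_iter);
  bool now_empty = front_nodes.empty();
  internal_outputs_to_front[new_node][dst_output_idx] = front;
  if (now_empty) {
    internal_outputs_to_front.erase(node);
  }
}

int main() {
  // A MaxPool with two tracked outputs; only output 0 gets a TransData inserted in front of it.
  internal_outputs_to_front["MaxPool"] = {{0, "getitem0"}, {1, "getitem1"}};
  ReplaceInternalOutput("MaxPool", "TransData0", 0, 0);
  for (const auto &entry : internal_outputs_to_front) {
    std::cout << entry.first << " tracks " << entry.second.size() << " front node(s)\n";
  }
  return 0;
}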
mindspore/ccsrc/backend/session/kernel_graph.h

@@ -148,10 +148,10 @@ class KernelGraph : public FuncGraph {
   const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
   void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
   void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node);
-  void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node);
+  void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1, int dst_output_idx = -1);
   AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
   bool IsInternalOutput(const AnfNodePtr &node) const;
   AnfNodePtr GetFrontNodeByInternalOutput(const AnfNodePtr &node) const;
   void AddFinalOutputKernel(const AnfNodePtr &node);
   bool IsFinalOutputKernel(const AnfNodePtr &node) const;
   uint32_t current_epoch() const { return current_epoch_; }
@@ -223,7 +223,7 @@ class KernelGraph : public FuncGraph {
   CNodePtr end_goto_;
   bool null_output_;
   std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
-  std::unordered_map<AnfNodePtr, AnfNodePtr> internal_outputs_to_front_map_;
+  std::unordered_map<AnfNodePtr, std::unordered_map<int, AnfNodePtr>> internal_outputs_to_front_map_;
   std::set<AnfNodePtr> final_output_kernels_;
   uint32_t current_epoch_;
 };
mindspore/ccsrc/backend/session/session_basic.cc

@@ -300,7 +300,11 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
     MS_LOG(INFO) << "No corresponding internal output for output node";
     return;
   }
-  auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0);
+  size_t output_idx = 0;
+  if (AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
+    output_idx = AnfAlgo::GetTupleGetItemOutIndex(out_node->cast<CNodePtr>());
+  }
+  auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx);
   auto ref_real_node = real_kernel.first;
   auto ref_real_node_index = real_kernel.second;
   if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) &&
@@ -325,6 +329,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
       builder.SetOutputsFormat({format});
       d_kernel_info->set_select_kernel_build_info(builder.Build());
       AnfAlgo::SetOutputAddr(address, 0, parameter.get());
+      AnfAlgo::SetOutputInferTypeAndShape({type}, {AnfAlgo::GetOutputInferShape(parameter, 0)}, parameter.get());
     }
   }
tests/st/host_device/test_host_device_lenet.py (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class LeNet(nn.Cell):
    def __init__(self):
        super(LeNet, self).__init__()
        self.relu = P.ReLU()
        self.batch_size = 32

        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.fc1 = nn.Dense(400, 120)
        self.fc1.matmul.add_prim_attr("primitive_target", "CPU")
        self.fc1.bias_add.add_prim_attr("primitive_target", "CPU")
        self.fc2 = nn.Dense(120, 84)
        self.fc2.matmul.add_prim_attr("primitive_target", "CPU")
        self.fc2.bias_add.add_prim_attr("primitive_target", "CPU")
        self.fc3 = nn.Dense(84, 10)
        self.fc3.matmul.add_prim_attr("primitive_target", "CPU")
        self.fc3.bias_add.add_prim_attr("primitive_target", "CPU")

    def construct(self, input_x):
        output = self.conv1(input_x)
        output = self.relu(output)
        output = self.pool(output)
        output = self.conv2(output)
        output = self.relu(output)
        output = self.pool(output)
        output = self.reshape(output, (self.batch_size, -1))
        output = self.fc1(output)
        output = self.relu(output)
        output = self.fc2(output)
        output = self.relu(output)
        output = self.fc3(output)
        return output


def train(net, data, label):
    learning_rate = 0.01
    momentum = 0.9

    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
    train_network.set_train()
    res = train_network(data, label)
    print("+++++++++Loss+++++++++++++")
    print(res)
    print("+++++++++++++++++++++++++++")
    diff = res.asnumpy()[0] - 2.3025851
    assert np.all(diff < 1.e-7)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lenet():
    data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([32]).astype(np.int32))
    net = LeNet()
    train(net, data, label)
tests/st/ops/cpu/test_sparse_apply_adam_op.py

@@ -14,6 +14,7 @@
 # ============================================================================
 import numpy as np
+import pytest

 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
@@ -43,6 +44,9 @@ class Net(nn.Cell):
         return out


+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
 def test_net():
     gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32))
     indices = Tensor([0, 1, 2], mstype.int32)
tests/st/ops/cpu/test_sparse_apply_ftrl_op.py

@@ -14,6 +14,7 @@
 # ============================================================================
 import numpy as np
+import pytest

 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
@@ -35,6 +36,9 @@ class Net(nn.Cell):
         return out


+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
 def test_net():
     gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32))
     indices = Tensor([0, 1, 2], mstype.int32)
tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py

@@ -14,6 +14,7 @@
 # ============================================================================
 import numpy as np
+import pytest

 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
@@ -37,6 +38,9 @@ class Net(nn.Cell):
         return out


+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
 def test_net():
     gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32))
     indices = Tensor([0, 1, 2], mstype.int32)
tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "debug/anf_ir_dump.h"
#include "common/py_func_graph_fetcher.h"
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
#define private public
#define protected public
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
#undef private
#undef protected
namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;

class TestHWRemoveInternalOutput : public BackendCommon {
 public:
  TestHWRemoveInternalOutput() : getPyFun_("gtest_input.pre_activate.remove_internal_output_test", true) {}
  ~TestHWRemoveInternalOutput() override = default;

  AnfNodePtr GetMakeTuple(const KernelGraphPtr &kg) {
    auto ret = kg->get_return();
    MS_EXCEPTION_IF_NULL(ret);
    auto make_tuple = ret->input(1);
    return make_tuple;
  }

  KernelGraphPtr GetSingleOutputGraph(const std::string &func_name, const std::string &sub_func_name) {
    FuncGraphPtr g = getPyFun_.CallAndParseRet(func_name, sub_func_name);
    std::vector<int> shp{2, 32, 224, 224};
    auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
    AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
    auto kg = GetKernelGraph(g, args_spec_list);
    auto make_tuple = GetMakeTuple(kg);
    auto add = make_tuple->cast<CNodePtr>()->input(1);
    MS_EXCEPTION_IF_NULL(add);
    kg->AddInternalOutput(add, add);
    KernelBuildInfoBuilder builder;
    builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
    builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
    builder.SetOutputsFormat({kOpFormat_NC1HWC0});
    builder.SetOutputsDeviceType({kFloat16->type_id()});
    add->set_kernel_info(std::make_shared<device::KernelInfo>());
    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), add.get());
    return kg;
  }

  KernelGraphPtr GetMutilpleOutputGraph(const std::string &func_name, const std::string &sub_func_name) {
    FuncGraphPtr g = getPyFun_.CallAndParseRet(func_name, sub_func_name);
    std::vector<int> shp{2, 32, 224, 224};
    auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
    AbstractBasePtrList args_spec_list{x_abstract};
    auto kg = GetKernelGraph(g, args_spec_list);
    auto output_make_tuple = GetMakeTuple(kg);
    auto make_tuple = output_make_tuple->cast<CNodePtr>()->input(1);
    MS_EXCEPTION_IF_NULL(make_tuple);
    auto tuple_getitem1 = make_tuple->cast<CNodePtr>()->input(1);
    MS_EXCEPTION_IF_NULL(tuple_getitem1);
    auto tuple_getitem2 = make_tuple->cast<CNodePtr>()->input(2);
    MS_EXCEPTION_IF_NULL(tuple_getitem2);
    auto max_pool = tuple_getitem1->cast<CNodePtr>()->input(1);
    MS_EXCEPTION_IF_NULL(max_pool);
    kg->AddInternalOutput(tuple_getitem1, max_pool);
    kg->AddInternalOutput(tuple_getitem2, max_pool);
    KernelBuildInfoBuilder builder;
    builder.SetInputsFormat({kOpFormat_DEFAULT});
    builder.SetInputsDeviceType({kFloat32->type_id()});
    builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
    builder.SetOutputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
    max_pool->set_kernel_info(std::make_shared<device::KernelInfo>());
    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), max_pool.get());
    return kg;
  }

  UT::PyFuncGraphFetcher getPyFun_;
};

class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
 public:
  MockRemoveInternalOutputTransOpKernelSelect() = default;
  ~MockRemoveInternalOutputTransOpKernelSelect() override = default;
  void SelectKernel(const CNodePtr &cnode) override {
    KernelBuildInfoBuilder builder;
    builder.SetInputsFormat({kOpFormat_NC1HWC0});
    builder.SetInputsDeviceType({kFloat16->type_id()});
    builder.SetOutputsFormat({kOpFormat_DEFAULT});
    builder.SetOutputsDeviceType({kFloat32->type_id()});
    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
  }
};

TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_single_output) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->set_execution_mode(kGraphMode);
  auto kg = GetSingleOutputGraph("test_remove_internal_output_trans_op_for_single_output", "before");
  // insert trans op for output
  auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();
  auto insert_trans_op_pass = std::make_shared<opt::InsertTransOp>();
  insert_trans_op_pass->kernel_select_ = std::make_shared<MockRemoveInternalOutputTransOpKernelSelect>();
  pass_manager->AddPass(insert_trans_op_pass);
  graph_optimizer->AddPassManager(pass_manager);
  auto new_g = graph_optimizer->Optimize(kg);
  FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_single_output",
                                                   "after_insert_trans_op");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_g));
  auto make_tuple = GetMakeTuple(kg);
  auto trans_data = make_tuple->cast<CNodePtr>()->input(1);
  EXPECT_TRUE(kg->IsInternalOutput(trans_data));

  // remove trans op for internal output
  auto graph_optimizer1 = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager1 = std::make_shared<opt::PassManager>();
  auto remove_internal_output_trans_op_pass = std::make_shared<opt::RemoveInternalOutputTransOp>();
  pass_manager1->AddPass(remove_internal_output_trans_op_pass);
  graph_optimizer1->AddPassManager(pass_manager1);
  auto new_g1 = graph_optimizer1->Optimize(new_g);
  FuncGraphPtr g_after1 = getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_single_output",
                                                    "after_remove_internal_output_trans_op");
  EXPECT_TRUE(CheckEqualGraph(g_after1, new_g1));
}

TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_multiple_output) {
  auto kg = GetMutilpleOutputGraph("test_remove_internal_output_trans_op_for_multiple_output", "before");
  // insert trans op for output
  auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();
  auto insert_trans_op_pass = std::make_shared<opt::InsertTransOp>();
  insert_trans_op_pass->kernel_select_ = std::make_shared<MockRemoveInternalOutputTransOpKernelSelect>();
  pass_manager->AddPass(insert_trans_op_pass);
  graph_optimizer->AddPassManager(pass_manager);
  auto new_g = graph_optimizer->Optimize(kg);
  FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_multiple_output",
                                                   "after_insert_trans_op");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_g));
  auto output_make_tuple = GetMakeTuple(kg);
  auto make_tuple = output_make_tuple->cast<CNodePtr>()->input(1);
  auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1);
  auto make_tuple1 = tuple_getitem->cast<CNodePtr>()->input(1);
  auto trans_data1 = make_tuple1->cast<CNodePtr>()->input(1);
  auto trans_data2 = make_tuple1->cast<CNodePtr>()->input(2);
  EXPECT_TRUE(kg->IsInternalOutput(trans_data1));
  EXPECT_TRUE(kg->IsInternalOutput(trans_data2));

  // remove trans op for internal output
  auto graph_optimizer1 = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager1 = std::make_shared<opt::PassManager>();
  auto remove_internal_output_trans_op_pass = std::make_shared<opt::RemoveInternalOutputTransOp>();
  pass_manager1->AddPass(remove_internal_output_trans_op_pass);
  graph_optimizer1->AddPassManager(pass_manager1);
  auto new_g1 = graph_optimizer1->Optimize(new_g);
  FuncGraphPtr g_after1 = getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_multiple_output",
                                                    "after_remove_internal_output_trans_op");
  EXPECT_TRUE(CheckEqualGraph(g_after1, new_g1));
}
}  // namespace opt
}  // namespace mindspore
tests/ut/cpp/python_input/gtest_input/pre_activate/remove_internal_output_test.py (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

tuple_getitem = Primitive('tuple_getitem')
add = P.TensorAdd()
max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
make_tuple = Primitive('make_tuple')
trans_data = Primitive("TransData")


class FnDict:
    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        self.fnDict[fn.__name__] = fn

    def __getitem__(self, name):
        return self.fnDict[name]


def test_remove_internal_output_trans_op_for_single_output(tag):
    fns = FnDict()

    @fns
    def before(x, y):
        res = add(x, y)
        return res

    @fns
    def after_insert_trans_op(x, y):
        output = add(x, y)
        res = trans_data(output)
        return make_tuple(res)

    @fns
    def after_remove_internal_output_trans_op(x, y):
        res = add(x, y)
        return make_tuple(res)

    return fns[tag]


def test_remove_internal_output_trans_op_for_multiple_output(tag):
    fns = FnDict()

    @fns
    def before(x):
        max_pool_res = max_pool(x)
        res = make_tuple(tuple_getitem(max_pool_res, 0), tuple_getitem(max_pool_res, 1))
        return res

    @fns
    def after_insert_trans_op(x):
        output = max_pool(x)
        trans_data0 = trans_data(tuple_getitem(output, 0))
        trans_data1 = trans_data(tuple_getitem(output, 1))
        new_make_tuple = make_tuple(trans_data0, trans_data1)
        res = make_tuple(tuple_getitem(new_make_tuple, 0), tuple_getitem(new_make_tuple, 1))
        return make_tuple(res)

    @fns
    def after_remove_internal_output_trans_op(x):
        output = max_pool(x)
        new_make_tuple = make_tuple(tuple_getitem(output, 0), tuple_getitem(output, 1))
        res = make_tuple(tuple_getitem(new_make_tuple, 0), tuple_getitem(new_make_tuple, 1))
        return make_tuple(res)

    return fns[tag]