Commit 188d74f1
Authored on Jul 15, 2020 by yujianfeng

Remove transdata and cast for internal outputs

Parent: 11732f0e
Showing 15 changed files with 564 additions and 30 deletions (+564 −30)
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc   +3   −0
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc                 +12  −2
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc       +12  −0
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc   +5   −11
mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc  +83  −0
mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.h   +51  −0
mindspore/ccsrc/backend/session/kernel_graph.cc                           +31  −13
mindspore/ccsrc/backend/session/kernel_graph.h                            +3   −3
mindspore/ccsrc/backend/session/session_basic.cc                          +6   −1
tests/st/host_device/test_host_device_lenet.py                            +89  −0
tests/st/ops/cpu/test_sparse_apply_adam_op.py                             +4   −0
tests/st/ops/cpu/test_sparse_apply_ftrl_op.py                             +4   −0
tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py                 +4   −0
tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc      +174 −0
tests/ut/cpp/python_input/gtest_input/pre_activate/remove_internal_output_test.py +83 −0
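
A recurring pattern runs through the hunks below (a hedged reading of the diffs, not text from the commit itself): whenever a format/type pass rewrites a node that is an "internal output" of the kernel graph, i.e. a node whose value feeds the graph output directly, the pass now updates the graph's internal-output bookkeeping so the front-end output keeps pointing at the rewritten node. A minimal sketch of that pattern, using only names that appear in the diffs below:

  // Sketch only; simplified from the ascend_helper.cc / insert_cast.cc hunks below.
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
    // Keep the front-end output mapped to the node that replaces `node`.
    kernel_graph->ReplaceInternalOutput(node, new_node);
  }

The new RemoveInternalOutputTransOp and RemoveInternalOutputCast passes then drop TransData and Cast nodes that exist only to feed such outputs.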
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -96,6 +96,7 @@
#include "backend/optimizer/ascend/format_type/modify_ops_attrs.h"
#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
#include "utils/context/ms_context.h"
#include "utils/config_manager.h"
#include "debug/anf_ir_dump.h"
...
@@ -199,6 +200,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
  data_layout_pm->AddPass(std::make_shared<OptimizeDependence>());
  data_layout_pm->AddPass(std::make_shared<TransDataSplit>());
  data_layout_pm->AddPass(std::make_shared<EraseVisitAttr>());
  data_layout_pm->AddPass(std::make_shared<RemoveInternalOutputTransOp>());
  optimizer->AddPassManager(data_layout_pm);
  (void)optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
...
@@ -220,6 +222,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph)
  mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
  mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
  mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>());
  mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>());
  optimizer->AddPassManager(mixed_precision_pm);
  (void)optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
...
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
@@ -142,6 +142,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
  MS_EXCEPTION_IF_NULL(node);
  std::vector<AnfNodePtr> make_tuple_inputs;
  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
    if (output_format == kOpFormat_NC1KHKWHWC0) {
...
@@ -151,7 +152,11 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
      auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
      std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
      if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
        make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false));
        auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
        if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
          kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
        }
        make_tuple_inputs.emplace_back(trans_op);
      } else {
        // No need insert trans op.
        make_tuple_inputs.push_back(tuple_getitem);
...
@@ -249,9 +254,14 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP
  if (outputs_num == 0) {
    return node;
  }
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  // Single output
  if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) {
    return InsertTransOpForSingleOutput(func_graph, node, kernel_select);
    auto new_node = InsertTransOpForSingleOutput(func_graph, node, kernel_select);
    if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
      kernel_graph->ReplaceInternalOutput(node, new_node);
    }
    return new_node;
  }
  // Multiple output
  return InsertTransOpForMultipleOutput(func_graph, node, kernel_select);
...
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc
@@ -40,6 +40,7 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
  std::vector<AnfNodePtr> make_tuple_inputs;
  AbstractBasePtrList abstract_list;
  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) {
    AnfNodePtr replace_node = nullptr;
    const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
...
@@ -64,6 +65,9 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
      MS_EXCEPTION_IF_NULL(replace_node);
      replace_node->set_scope(cnode->scope());
      AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
      if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode)) {
        kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
      }
    } else {
      replace_node = getitem;
    }
...
@@ -87,6 +91,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
    return cnode;
  }
  MS_EXCEPTION_IF_NULL(cnode->Type());
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  // Single output
  if (!cnode->Type()->isa<Tuple>()) {
    if (!need_insert_cast[0]) {
...
@@ -109,6 +114,9 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
    MS_EXCEPTION_IF_NULL(replace_node);
    replace_node->set_scope(cnode->scope());
    AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
    if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode)) {
      kernel_graph->ReplaceInternalOutput(cnode, replace_node);
    }
  }
  return replace_node;
}
...
@@ -188,6 +196,10 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto new_node = InsertCastForInput(func_graph, cnode);
  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
    kernel_graph->ReplaceInternalOutput(node, new_node);
  }
  // process output
  return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
}
...
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc
@@ -46,14 +46,13 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
  if (node == nullptr || !AnfAlgo::IsRealKernel(node)) {
    return nullptr;
  }
  AnfNodePtr front_node;
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  MS_LOG(DEBUG) << "process op: " << node->DebugString();
  AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
    front_node = kernel_graph->GetFrontNodeByInternalOutput(node);
    kernel_graph->ReplaceInternalOutput(node, new_node);
  }
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  MS_LOG(DEBUG) << "====process op: " << node->DebugString();
  AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) {
...
@@ -61,12 +60,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
      return new_node;
    }
  }
  auto final_node = InsertTransOpForOutput(func_graph, new_node, kernel_select_);
  if (kernel_graph != nullptr && front_node != nullptr) {
    auto old_node = kernel_graph->GetInternalOutputByFrontNode(front_node);
    kernel_graph->ReplaceInternalOutput(old_node, final_node);
  }
  return final_node;
  return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
}
}  // namespace opt
}  // namespace mindspore
mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc
0 → 100644
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace opt {
namespace {
bool UsedForOutputOnly(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto &node_users = manager->node_users();
  auto iter = node_users.find(node);
  if (iter == node_users.end()) {
    return false;
  }
  const auto &node_set = iter->second;
  for (const auto &node_index : node_set) {
    if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimMakeTuple)) {
      return false;
    }
  }
  return true;
}
}  // namespace

const BaseRef RemoveInternalOutputTransOp::DefinePattern() const {
  VarPtr X = std::make_shared<Var>();
  auto prim = std::make_shared<Primitive>(kTransDataOpName);
  return VectorRef({prim, X});
}

const BaseRef RemoveInternalOutputCast::DefinePattern() const {
  VarPtr X = std::make_shared<Var>();
  return VectorRef({prim::kPrimCast, X});
}

const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                               const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  if (kernel_graph == nullptr) {
    return nullptr;
  }
  if (!kernel_graph->IsInternalOutput(node)) {
    return nullptr;
  }
  if (!UsedForOutputOnly(func_graph, node)) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  CheckCNodeInputSize(cnode, kTransOpInputNum);
  auto input_node = cnode->input(1);
  if (!AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimTupleGetItem)) {
    kernel_graph->ReplaceInternalOutput(node, input_node);
  } else {
    auto tuple_getitem = input_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(tuple_getitem);
    int idx = AnfAlgo::GetTupleGetItemOutIndex(tuple_getitem);
    AnfNodePtr real_input_node = AnfAlgo::GetTupleGetItemRealInput(tuple_getitem);
    kernel_graph->ReplaceInternalOutput(node, real_input_node, 0, idx);
  }
  return input_node;
}
}  // namespace opt
}  // namespace mindspore
mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.h
0 → 100644
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_
#include <string>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class RemoveInternalOutput : public PatternProcessPass {
 public:
  explicit RemoveInternalOutput(const std::string &name, bool multigraph = true)
      : PatternProcessPass(name, multigraph) {}
  ~RemoveInternalOutput() override = default;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};

class RemoveInternalOutputTransOp : public RemoveInternalOutput {
 public:
  explicit RemoveInternalOutputTransOp(bool multigraph = true)
      : RemoveInternalOutput("remove_internal_output_trans_op", multigraph) {}
  ~RemoveInternalOutputTransOp() override = default;
  const BaseRef DefinePattern() const override;
};

class RemoveInternalOutputCast : public RemoveInternalOutput {
 public:
  explicit RemoveInternalOutputCast(bool multigraph = true)
      : RemoveInternalOutput("remove_internal_output_cast", multigraph) {}
  ~RemoveInternalOutputCast() override = default;
  const BaseRef DefinePattern() const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_
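
A minimal sketch of how these two passes are wired in, following the ascend_backend_optimization.cc hunks earlier in this diff (in the real code RemoveInternalOutputTransOp goes into the data-layout pass manager and RemoveInternalOutputCast into the mixed-precision one; the single pass manager here is only for illustration):

  // Illustration only, mirroring the registration shown above.
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::RemoveInternalOutputTransOp>());
  pm->AddPass(std::make_shared<opt::RemoveInternalOutputCast>());
  optimizer->AddPassManager(pm);
  (void)optimizer->Optimize(kernel_graph);

The unit test added at the bottom of this commit drives RemoveInternalOutputTransOp through exactly this GraphOptimizer/PassManager setup.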
mindspore/ccsrc/backend/session/kernel_graph.cc
@@ -929,10 +929,15 @@ void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodeP
  }
  MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
  front_to_internal_outputs_map_[front_node] = node;
  internal_outputs_to_front_map_[node] = front_node;
  int output_idx = 0;
  if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
    output_idx = AnfAlgo::GetTupleGetItemOutIndex(front_node->cast<CNodePtr>());
  }
  internal_outputs_to_front_map_[node][output_idx] = front_node;
}

void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) {
void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx,
                                        int dst_output_idx) {
  if (new_node == nullptr || node == nullptr) {
    MS_LOG(INFO) << "New node or node is nullptr";
    return;
...
@@ -947,9 +952,30 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr
    return;
  }
  MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString();
  internal_outputs_to_front_map_[new_node] = iter->second;
  front_to_internal_outputs_map_[iter->second] = new_node;
  internal_outputs_to_front_map_.erase(iter);
  auto &front_nodes = iter->second;
  // Move all front nodes to new node mapping
  if (src_output_idx == -1) {
    internal_outputs_to_front_map_[new_node] = front_nodes;
    for (const auto &front_node_iter : front_nodes) {
      front_to_internal_outputs_map_[front_node_iter.second] = new_node;
    }
    internal_outputs_to_front_map_.erase(iter);
    return;
  }
  // Move specified front node to new node mapping
  int index = SizeToInt(src_output_idx);
  auto front_node_iter = front_nodes.find(index);
  if (front_node_iter == front_nodes.end()) {
    MS_LOG(INFO) << "The output " << src_output_idx << " of node " << node->DebugString() << " is not an internal node";
    return;
  }
  auto front_node = front_node_iter->second;
  internal_outputs_to_front_map_[new_node][dst_output_idx] = front_node;
  front_to_internal_outputs_map_[front_node] = new_node;
  front_nodes.erase(index);
  if (front_nodes.empty()) {
    internal_outputs_to_front_map_.erase(iter);
  }
}

AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
...
@@ -967,14 +993,6 @@ bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const {
  return false;
}
AnfNodePtr KernelGraph::GetFrontNodeByInternalOutput(const AnfNodePtr &node) const {
  auto iter = internal_outputs_to_front_map_.find(node);
  if (iter != internal_outputs_to_front_map_.end()) {
    return iter->second;
  }
  return nullptr;
}
void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) {
  if (node == nullptr) {
    return;
...
mindspore/ccsrc/backend/session/kernel_graph.h
@@ -148,10 +148,10 @@ class KernelGraph : public FuncGraph {
  const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
  void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
  void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node);
  void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node);
  void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1,
                             int dst_output_idx = -1);
  AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
  bool IsInternalOutput(const AnfNodePtr &node) const;
  AnfNodePtr GetFrontNodeByInternalOutput(const AnfNodePtr &node) const;
  void AddFinalOutputKernel(const AnfNodePtr &node);
  bool IsFinalOutputKernel(const AnfNodePtr &node) const;
  uint32_t current_epoch() const { return current_epoch_; }
...
@@ -223,7 +223,7 @@ class KernelGraph : public FuncGraph {
  CNodePtr end_goto_;
  bool null_output_;
  std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
  std::unordered_map<AnfNodePtr, AnfNodePtr> internal_outputs_to_front_map_;
  std::unordered_map<AnfNodePtr, std::unordered_map<int, AnfNodePtr>> internal_outputs_to_front_map_;
  std::set<AnfNodePtr> final_output_kernels_;
  uint32_t current_epoch_;
};
...
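
A hedged illustration of the updated ReplaceInternalOutput overload (behavior taken from the kernel_graph.cc hunk above): with the default indices of -1 the whole front-node mapping of `node` moves to `new_node`; explicit indices move the mapping of a single output of a multi-output node.

  // Move every front-node mapping from `node` to `new_node` (src/dst indices default to -1).
  kernel_graph->ReplaceInternalOutput(node, new_node);
  // Move only output `output_idx` of `node`, mapping it to output 0 of `trans_op`,
  // as the ascend_helper.cc hunk does for multi-output nodes.
  kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);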
mindspore/ccsrc/backend/session/session_basic.cc
@@ -300,7 +300,11 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
    MS_LOG(INFO) << "No corresponding internal output for output node";
    return;
  }
  auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0);
  size_t output_idx = 0;
  if (AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
    output_idx = AnfAlgo::GetTupleGetItemOutIndex(out_node->cast<CNodePtr>());
  }
  auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx);
  auto ref_real_node = real_kernel.first;
  auto ref_real_node_index = real_kernel.second;
  if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) &&
...
@@ -325,6 +329,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
    builder.SetOutputsFormat({format});
    d_kernel_info->set_select_kernel_build_info(builder.Build());
    AnfAlgo::SetOutputAddr(address, 0, parameter.get());
    AnfAlgo::SetOutputInferTypeAndShape({type}, {AnfAlgo::GetOutputInferShape(parameter, 0)}, parameter.get());
  }
}
...
tests/st/host_device/test_host_device_lenet.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class LeNet(nn.Cell):
    def __init__(self):
        super(LeNet, self).__init__()
        self.relu = P.ReLU()
        self.batch_size = 32

        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.fc1 = nn.Dense(400, 120)
        self.fc1.matmul.add_prim_attr("primitive_target", "CPU")
        self.fc1.bias_add.add_prim_attr("primitive_target", "CPU")
        self.fc2 = nn.Dense(120, 84)
        self.fc2.matmul.add_prim_attr("primitive_target", "CPU")
        self.fc2.bias_add.add_prim_attr("primitive_target", "CPU")
        self.fc3 = nn.Dense(84, 10)
        self.fc3.matmul.add_prim_attr("primitive_target", "CPU")
        self.fc3.bias_add.add_prim_attr("primitive_target", "CPU")

    def construct(self, input_x):
        output = self.conv1(input_x)
        output = self.relu(output)
        output = self.pool(output)
        output = self.conv2(output)
        output = self.relu(output)
        output = self.pool(output)
        output = self.reshape(output, (self.batch_size, -1))
        output = self.fc1(output)
        output = self.relu(output)
        output = self.fc2(output)
        output = self.relu(output)
        output = self.fc3(output)
        return output


def train(net, data, label):
    learning_rate = 0.01
    momentum = 0.9

    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
    train_network.set_train()
    res = train_network(data, label)
    print("+++++++++Loss+++++++++++++")
    print(res)
    print("+++++++++++++++++++++++++++")
    diff = res.asnumpy()[0] - 2.3025851
    assert np.all(diff < 1.e-7)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lenet():
    data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([32]).astype(np.int32))
    net = LeNet()
    train(net, data, label)
tests/st/ops/cpu/test_sparse_apply_adam_op.py
@@ -14,6 +14,7 @@
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
...
@@ -43,6 +44,9 @@ class Net(nn.Cell):
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net():
    gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    indices = Tensor([0, 1, 2], mstype.int32)
...
tests/st/ops/cpu/test_sparse_apply_ftrl_op.py
@@ -14,6 +14,7 @@
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
...
@@ -35,6 +36,9 @@ class Net(nn.Cell):
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net():
    gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    indices = Tensor([0, 1, 2], mstype.int32)
...
tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py
@@ -14,6 +14,7 @@
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
...
@@ -37,6 +38,9 @@ class Net(nn.Cell):
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net():
    gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    indices = Tensor([0, 1, 2], mstype.int32)
...
tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc
0 → 100644
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/backend_common_test.h"
#include "debug/anf_ir_dump.h"
#include "common/py_func_graph_fetcher.h"
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
#define private public
#define protected public
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
#undef private
#undef protected

namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;

class TestHWRemoveInternalOutput : public BackendCommon {
 public:
  TestHWRemoveInternalOutput() : getPyFun_("gtest_input.pre_activate.remove_internal_output_test", true) {}
  ~TestHWRemoveInternalOutput() override = default;

  AnfNodePtr GetMakeTuple(const KernelGraphPtr &kg) {
    auto ret = kg->get_return();
    MS_EXCEPTION_IF_NULL(ret);
    auto make_tuple = ret->input(1);
    return make_tuple;
  }

  KernelGraphPtr GetSingleOutputGraph(const std::string &func_name, const std::string &sub_func_name) {
    FuncGraphPtr g = getPyFun_.CallAndParseRet(func_name, sub_func_name);
    std::vector<int> shp{2, 32, 224, 224};
    auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
    AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
    auto kg = GetKernelGraph(g, args_spec_list);
    auto make_tuple = GetMakeTuple(kg);
    auto add = make_tuple->cast<CNodePtr>()->input(1);
    MS_EXCEPTION_IF_NULL(add);
    kg->AddInternalOutput(add, add);
    KernelBuildInfoBuilder builder;
    builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
    builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
    builder.SetOutputsFormat({kOpFormat_NC1HWC0});
    builder.SetOutputsDeviceType({kFloat16->type_id()});
    add->set_kernel_info(std::make_shared<device::KernelInfo>());
    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), add.get());
    return kg;
  }

  KernelGraphPtr GetMutilpleOutputGraph(const std::string &func_name, const std::string &sub_func_name) {
    FuncGraphPtr g = getPyFun_.CallAndParseRet(func_name, sub_func_name);
    std::vector<int> shp{2, 32, 224, 224};
    auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
    AbstractBasePtrList args_spec_list{x_abstract};
    auto kg = GetKernelGraph(g, args_spec_list);
    auto output_make_tuple = GetMakeTuple(kg);
    auto make_tuple = output_make_tuple->cast<CNodePtr>()->input(1);
    MS_EXCEPTION_IF_NULL(make_tuple);
    auto tuple_getitem1 = make_tuple->cast<CNodePtr>()->input(1);
    MS_EXCEPTION_IF_NULL(tuple_getitem1);
    auto tuple_getitem2 = make_tuple->cast<CNodePtr>()->input(2);
    MS_EXCEPTION_IF_NULL(tuple_getitem2);
    auto max_pool = tuple_getitem1->cast<CNodePtr>()->input(1);
    MS_EXCEPTION_IF_NULL(max_pool);
    kg->AddInternalOutput(tuple_getitem1, max_pool);
    kg->AddInternalOutput(tuple_getitem2, max_pool);
    KernelBuildInfoBuilder builder;
    builder.SetInputsFormat({kOpFormat_DEFAULT});
    builder.SetInputsDeviceType({kFloat32->type_id()});
    builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
    builder.SetOutputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
    max_pool->set_kernel_info(std::make_shared<device::KernelInfo>());
    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), max_pool.get());
    return kg;
  }

  UT::PyFuncGraphFetcher getPyFun_;
};

class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
 public:
  MockRemoveInternalOutputTransOpKernelSelect() = default;
  ~MockRemoveInternalOutputTransOpKernelSelect() override = default;
  void SelectKernel(const CNodePtr &cnode) override {
    KernelBuildInfoBuilder builder;
    builder.SetInputsFormat({kOpFormat_NC1HWC0});
    builder.SetInputsDeviceType({kFloat16->type_id()});
    builder.SetOutputsFormat({kOpFormat_DEFAULT});
    builder.SetOutputsDeviceType({kFloat32->type_id()});
    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
  }
};

TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_single_output) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->set_execution_mode(kGraphMode);
  auto kg = GetSingleOutputGraph("test_remove_internal_output_trans_op_for_single_output", "before");
  // insert trans op for output
  auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();
  auto insert_trans_op_pass = std::make_shared<opt::InsertTransOp>();
  insert_trans_op_pass->kernel_select_ = std::make_shared<MockRemoveInternalOutputTransOpKernelSelect>();
  pass_manager->AddPass(insert_trans_op_pass);
  graph_optimizer->AddPassManager(pass_manager);
  auto new_g = graph_optimizer->Optimize(kg);
  FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_single_output",
                                                   "after_insert_trans_op");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_g));
  auto make_tuple = GetMakeTuple(kg);
  auto trans_data = make_tuple->cast<CNodePtr>()->input(1);
  EXPECT_TRUE(kg->IsInternalOutput(trans_data));

  // remove trans op for internal output
  auto graph_optimizer1 = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager1 = std::make_shared<opt::PassManager>();
  auto remove_internal_output_trans_op_pass = std::make_shared<opt::RemoveInternalOutputTransOp>();
  pass_manager1->AddPass(remove_internal_output_trans_op_pass);
  graph_optimizer1->AddPassManager(pass_manager1);
  auto new_g1 = graph_optimizer1->Optimize(new_g);
  FuncGraphPtr g_after1 = getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_single_output",
                                                    "after_remove_internal_output_trans_op");
  EXPECT_TRUE(CheckEqualGraph(g_after1, new_g1));
}

TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_multiple_output) {
  auto kg = GetMutilpleOutputGraph("test_remove_internal_output_trans_op_for_multiple_output", "before");
  // insert trans op for output
  auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();
  auto insert_trans_op_pass = std::make_shared<opt::InsertTransOp>();
  insert_trans_op_pass->kernel_select_ = std::make_shared<MockRemoveInternalOutputTransOpKernelSelect>();
  pass_manager->AddPass(insert_trans_op_pass);
  graph_optimizer->AddPassManager(pass_manager);
  auto new_g = graph_optimizer->Optimize(kg);
  FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_multiple_output",
                                                   "after_insert_trans_op");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_g));
  auto output_make_tuple = GetMakeTuple(kg);
  auto make_tuple = output_make_tuple->cast<CNodePtr>()->input(1);
  auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1);
  auto make_tuple1 = tuple_getitem->cast<CNodePtr>()->input(1);
  auto trans_data1 = make_tuple1->cast<CNodePtr>()->input(1);
  auto trans_data2 = make_tuple1->cast<CNodePtr>()->input(2);
  EXPECT_TRUE(kg->IsInternalOutput(trans_data1));
  EXPECT_TRUE(kg->IsInternalOutput(trans_data2));

  // remove trans op for internal output
  auto graph_optimizer1 = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager1 = std::make_shared<opt::PassManager>();
  auto remove_internal_output_trans_op_pass = std::make_shared<opt::RemoveInternalOutputTransOp>();
  pass_manager1->AddPass(remove_internal_output_trans_op_pass);
  graph_optimizer1->AddPassManager(pass_manager1);
  auto new_g1 = graph_optimizer1->Optimize(new_g);
  FuncGraphPtr g_after1 = getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_multiple_output",
                                                    "after_remove_internal_output_trans_op");
  EXPECT_TRUE(CheckEqualGraph(g_after1, new_g1));
}
}  // namespace opt
}  // namespace mindspore
tests/ut/cpp/python_input/gtest_input/pre_activate/remove_internal_output_test.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

tuple_getitem = Primitive('tuple_getitem')
add = P.TensorAdd()
max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
make_tuple = Primitive('make_tuple')
trans_data = Primitive("TransData")


class FnDict:
    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        self.fnDict[fn.__name__] = fn

    def __getitem__(self, name):
        return self.fnDict[name]


def test_remove_internal_output_trans_op_for_single_output(tag):
    fns = FnDict()

    @fns
    def before(x, y):
        res = add(x, y)
        return res

    @fns
    def after_insert_trans_op(x, y):
        output = add(x, y)
        res = trans_data(output)
        return make_tuple(res)

    @fns
    def after_remove_internal_output_trans_op(x, y):
        res = add(x, y)
        return make_tuple(res)

    return fns[tag]


def test_remove_internal_output_trans_op_for_multiple_output(tag):
    fns = FnDict()

    @fns
    def before(x):
        max_pool_res = max_pool(x)
        res = make_tuple(tuple_getitem(max_pool_res, 0), tuple_getitem(max_pool_res, 1))
        return res

    @fns
    def after_insert_trans_op(x):
        output = max_pool(x)
        trans_data0 = trans_data(tuple_getitem(output, 0))
        trans_data1 = trans_data(tuple_getitem(output, 1))
        new_make_tuple = make_tuple(trans_data0, trans_data1)
        res = make_tuple(tuple_getitem(new_make_tuple, 0), tuple_getitem(new_make_tuple, 1))
        return make_tuple(res)

    @fns
    def after_remove_internal_output_trans_op(x):
        output = max_pool(x)
        new_make_tuple = make_tuple(tuple_getitem(output, 0), tuple_getitem(output, 1))
        res = make_tuple(tuple_getitem(new_make_tuple, 0), tuple_getitem(new_make_tuple, 1))
        return make_tuple(res)

    return fns[tag]