Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
a27ce973
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a27ce973
编写于
6月 14, 2020
作者:
C
changzherui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
convert subgraph
上级
7b5b4837
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
296 addition
and
10 deletion
+296
-10
mindspore/ccsrc/transform/convert.cc
mindspore/ccsrc/transform/convert.cc
+185
-10
mindspore/ccsrc/transform/convert.h
mindspore/ccsrc/transform/convert.h
+7
-0
mindspore/ccsrc/transform/op_adapter.h
mindspore/ccsrc/transform/op_adapter.h
+22
-0
mindspore/ccsrc/transform/op_adapter_base.h
mindspore/ccsrc/transform/op_adapter_base.h
+10
-0
mindspore/ccsrc/transform/op_declare.cc
mindspore/ccsrc/transform/op_declare.cc
+23
-0
mindspore/ccsrc/transform/op_declare.h
mindspore/ccsrc/transform/op_declare.h
+8
-0
tests/ut/python/automl/case.py
tests/ut/python/automl/case.py
+41
-0
未找到文件。
mindspore/ccsrc/transform/convert.cc
浏览文件 @
a27ce973
...
...
@@ -28,6 +28,7 @@
#include "utils/config_manager.h"
#include "utils/convert_utils.h"
#include "./common.h"
#include "utils/context/ms_context.h"
namespace
mindspore
{
namespace
transform
{
...
...
@@ -205,6 +206,7 @@ const char kNameRange[] = "Range";
const
char
kNameSquareSumAll
[]
=
"SquareSumAll"
;
const
char
kNameAscendQuant
[]
=
"AscendQuant"
;
const
char
kNameAscendDequant
[]
=
"AscendDequant"
;
const
char
kNameCase
[]
=
"Case"
;
// -----------------OpAdapter initialization--------------
std
::
unordered_map
<
std
::
string
,
OpAdapterDescPtr
>
&
DfGraphConvertor
::
get_adpt_map
()
{
...
...
@@ -411,7 +413,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{
string
(
kNameRange
),
ADPT_DESC
(
RangeD
)},
{
string
(
kNameSquareSumAll
),
ADPT_DESC
(
SquareSumAll
)},
{
string
(
kNameAscendQuant
),
ADPT_DESC
(
AscendQuant
)},
{
string
(
kNameAscendDequant
),
ADPT_DESC
(
AscendDequant
)}};
{
string
(
kNameAscendDequant
),
ADPT_DESC
(
AscendDequant
)},
{
string
(
kNameCase
),
ADPT_DESC
(
Case
)}};
#ifdef ENABLE_GE
adpt_map
[
string
(
kNamePrint
)]
=
ADPT_DESC
(
Print
);
adpt_map
[
string
(
kNameApplyAdam
)]
=
ADPT_DESC
(
ApplyAdamD
);
...
...
@@ -433,13 +436,32 @@ PrimType GetCNodeFuncType(const CNodePtr cnode) {
return
kPrimTypeUnknown
;
}
bool
IsCaseNode
(
const
CNodePtr
node
)
{
if
(
!
node
->
inputs
().
empty
()
&&
node
->
input
(
0
)
->
isa
<
CNode
>
()
&&
GetCNodeFuncName
(
node
->
input
(
0
)
->
cast
<
CNodePtr
>
())
==
"switch_layer"
)
{
return
true
;
}
return
false
;
}
std
::
string
GetCNodeTargetFuncName
(
const
CNodePtr
cnode
)
{
if
(
IsCaseNode
(
cnode
))
{
return
string
(
kNameCase
);
}
auto
name
=
GetCNodeFuncName
(
cnode
);
if
(
name
==
"switch_layer"
)
{
name
=
""
;
}
return
name
;
}
OpAdapterPtr
DfGraphConvertor
::
FindAdapter
(
const
AnfNodePtr
node
,
bool
train
)
{
if
(
node
->
isa
<
CNode
>
())
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
std
::
string
name
=
kNameCustomOp
;
if
(
!
IsCustomCNode
(
cnode
))
{
name
=
GetCNodeFuncName
(
cnode
);
name
=
GetCNode
Target
FuncName
(
cnode
);
}
auto
it_adpt
=
get_adpt_map
().
find
(
name
);
...
...
@@ -957,7 +979,7 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) {
auto
c
=
anf_out
->
cast
<
CNodePtr
>
();
std
::
string
name
=
""
;
if
(
anf_out
->
isa
<
CNode
>
())
{
name
=
GetCNodeFuncName
(
c
);
name
=
GetCNode
Target
FuncName
(
c
);
}
if
(
name
==
"make_tuple"
)
{
...
...
@@ -1029,6 +1051,99 @@ void SetupDatasetIterGetNextNode(const OperatorPtr &op) {
return
;
}
void
DfGraphConvertor
::
SetSubgraph
(
AnfNodePtr
node
)
{
if
(
!
node
->
isa
<
CNode
>
())
{
return
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
if
(
!
IsCaseNode
(
cnode
))
{
return
;
}
std
::
vector
<
AnfNodePtr
>
case_inputs
;
for
(
size_t
i
=
1
;
i
<
cnode
->
inputs
().
size
();
i
++
)
{
case_inputs
.
emplace_back
(
cnode
->
input
(
i
));
}
std
::
shared_ptr
<
std
::
vector
<
DfGraph
>>
branches
=
std
::
make_shared
<
std
::
vector
<
DfGraph
>>
();
auto
bnode
=
cnode
->
input
(
0
)
->
cast
<
CNodePtr
>
()
->
input
(
2
)
->
cast
<
CNodePtr
>
();
for
(
size_t
i
=
1
;
i
<
bnode
->
inputs
().
size
();
i
++
)
{
auto
branch_node
=
bnode
->
input
(
i
)
->
cast
<
CNodePtr
>
();
for
(
size_t
j
=
2
;
j
<
branch_node
->
inputs
().
size
();
j
++
)
{
if
(
std
::
find
(
case_inputs
.
begin
(),
case_inputs
.
end
(),
branch_node
->
input
(
j
))
==
case_inputs
.
end
())
{
case_inputs
.
emplace_back
(
branch_node
->
input
(
j
));
}
}
}
for
(
size_t
i
=
1
;
i
<
bnode
->
inputs
().
size
();
i
++
)
{
ProcessSubgraph
(
bnode
->
input
(
i
),
case_inputs
);
}
for
(
size_t
i
=
1
;
i
<
bnode
->
inputs
().
size
();
i
++
)
{
branches
->
emplace_back
(
branches_map_
[
bnode
->
input
(
i
).
get
()]);
}
if
(
op_cache_
.
find
(
node
.
get
())
==
op_cache_
.
end
())
{
return
;
}
OpAdapterPtr
adpt
=
FindAdapter
(
node
,
training_
);
if
(
nullptr
==
adpt
)
{
MS_LOG
(
DEBUG
)
<<
"Not found adapter"
;
return
;
}
OperatorPtr
op
=
Convert
(
node
);
adpt
->
setSubgraph
(
op
,
0
,
branches
);
return
;
}
void
DfGraphConvertor
::
GetCaseNodeInput
(
const
CNodePtr
node
,
const
CNodePtr
input_node
)
{
std
::
vector
<
AnfNodePtr
>
case_inputs
;
for
(
size_t
i
=
1
;
i
<
node
->
inputs
().
size
();
i
++
)
{
case_inputs
.
emplace_back
(
node
->
input
(
i
));
}
std
::
shared_ptr
<
std
::
vector
<
DfGraph
>>
branches
=
std
::
make_shared
<
std
::
vector
<
DfGraph
>>
();
auto
bnode
=
input_node
->
input
(
2
)
->
cast
<
CNodePtr
>
();
for
(
size_t
i
=
1
;
i
<
bnode
->
inputs
().
size
();
i
++
)
{
auto
branch_node
=
bnode
->
input
(
i
)
->
cast
<
CNodePtr
>
();
for
(
size_t
j
=
2
;
j
<
branch_node
->
inputs
().
size
();
j
++
)
{
if
(
std
::
find
(
case_inputs
.
begin
(),
case_inputs
.
end
(),
branch_node
->
input
(
j
))
==
case_inputs
.
end
())
{
case_inputs
.
emplace_back
(
branch_node
->
input
(
j
));
}
}
}
const
size_t
case_index
=
1
;
const
size_t
make_tuple_index
=
2
;
AnfNodePtr
case_index_iter
=
input_node
->
input
(
case_index
);
AnfNodePtr
make_tuple_iter
=
input_node
->
input
(
make_tuple_index
);
auto
make_tuple_node
=
make_tuple_iter
->
cast
<
CNodePtr
>
();
std
::
shared_ptr
<
std
::
vector
<
OutHandler
>>
tuple_items
=
std
::
make_shared
<
std
::
vector
<
OutHandler
>>
();
for
(
size_t
i
=
0
;
i
<
case_inputs
.
size
();
i
++
)
{
auto
item
=
case_inputs
[
i
];
auto
op
=
Convert
(
item
);
if
(
op
!=
nullptr
)
{
tuple_items
->
emplace_back
(
OutHandler
(
op
,
""
));
}
else
if
(
out_handle_cache_
.
find
(
item
.
get
())
!=
out_handle_cache_
.
end
())
{
tuple_items
->
push_back
(
out_handle_cache_
[
item
.
get
()]);
}
else
{
MS_LOG
(
WARNING
)
<<
"This anf node is not supported as a case input: "
<<
item
->
ToString
();
continue
;
}
}
tuple_out_handle_cache_
[
make_tuple_node
.
get
()]
=
tuple_items
;
std
::
shared_ptr
<
std
::
vector
<
AnfNodePtr
>>
case_input_items
=
std
::
make_shared
<
std
::
vector
<
AnfNodePtr
>>
();
case_input_items
->
emplace_back
(
case_index_iter
);
case_input_items
->
emplace_back
(
make_tuple_iter
);
case_input_handle_cache_
[
node
.
get
()]
=
case_input_items
;
}
DfGraphConvertor
&
DfGraphConvertor
::
BuildGraph
()
{
SetupDatasetIterGetNextNode
(
dataset_iter_getnext_
);
...
...
@@ -1036,6 +1151,16 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
return
*
this
;
}
// Case node set input.
std
::
vector
<
AnfNodePtr
>
nodes
=
::
mindspore
::
TopoSort
(
anf_graph_
->
get_return
());
for
(
auto
&
it
:
nodes
)
{
if
(
it
->
isa
<
CNode
>
()
&&
IsCaseNode
(
it
->
cast
<
CNodePtr
>
()))
{
auto
node
=
it
->
cast
<
CNodePtr
>
();
auto
input_node
=
node
->
input
(
0
)
->
cast
<
CNodePtr
>
();
GetCaseNodeInput
(
node
,
input_node
);
}
}
// update tuple_out_handle_cache_
for
(
auto
it
:
tuple_out_handle_cache_
)
{
std
::
size_t
len
=
it
.
second
->
size
();
...
...
@@ -1056,10 +1181,11 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
// set up dependices
MS_LOG
(
DEBUG
)
<<
"set up dependices"
;
std
::
vector
<
AnfNodePtr
>
nodes
=
::
mindspore
::
TopoSort
(
anf_graph_
->
get_return
());
nodes
=
::
mindspore
::
TopoSort
(
anf_graph_
->
get_return
());
for
(
auto
&
it
:
nodes
)
{
SetNodeInput
(
it
);
SetOpControlInput
(
it
);
SetSubgraph
(
it
);
UpdateOpDesc
(
it
);
}
...
...
@@ -1075,6 +1201,18 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
inputs
.
push_back
(
*
dataset_iter_getnext_
);
}
else
{
auto
params
=
anf_graph_
->
parameters
();
if
(
use_inputs_
)
{
params
=
inputs_
;
auto
anf_params
=
anf_graph_
->
parameters
();
for
(
size_t
i
=
0
;
i
<
params
.
size
();
i
++
)
{
for
(
size_t
j
=
0
;
j
<
anf_params
.
size
();
j
++
)
{
if
(
params
[
i
]
->
ToString
()
==
anf_params
[
j
]
->
ToString
())
{
params
[
i
]
=
anf_params
[
j
];
}
}
}
}
int
index
=
0
;
for
(
auto
&
it
:
params
)
{
auto
name
=
std
::
static_pointer_cast
<
Parameter
>
(
it
)
->
name
();
...
...
@@ -1185,10 +1323,21 @@ const std::vector<std::string> trans_var_list = {string(kNameAssign), string(kNa
void
DfGraphConvertor
::
SetOpInput
(
const
OpAdapterPtr
&
adpt
,
const
CNodePtr
&
node
)
{
OperatorPtr
src
=
Convert
(
node
);
int
case_flag
=
0
;
auto
&
inputs
=
node
->
inputs
();
for
(
size_t
i
=
1
;
i
<
inputs
.
size
();
i
++
)
{
size_t
input_size
=
inputs
.
size
();
if
(
case_input_handle_cache_
.
find
(
node
.
get
())
!=
case_input_handle_cache_
.
end
())
{
case_flag
=
1
;
input_size
=
case_input_handle_cache_
[
node
.
get
()]
->
size
()
+
1
;
}
for
(
size_t
i
=
1
;
i
<
input_size
;
i
++
)
{
auto
pred
=
inputs
[
i
];
while
(
pred
->
isa
<
CNode
>
()
&&
GetCNodeFuncName
(
pred
->
cast
<
CNodePtr
>
())
==
"Depend"
)
{
if
(
case_flag
!=
0
)
{
pred
=
case_input_handle_cache_
[
node
.
get
()]
->
at
(
i
-
1
);
}
while
(
pred
->
isa
<
CNode
>
()
&&
GetCNodeTargetFuncName
(
pred
->
cast
<
CNodePtr
>
())
==
"Depend"
)
{
pred
=
pred
->
cast
<
CNodePtr
>
()
->
input
(
1
);
}
// skip the None input
...
...
@@ -1196,7 +1345,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
continue
;
}
// transform "Const" op to "Variable" op when the next node is "Assign" op.
std
::
string
c_name
=
GetCNodeFuncName
(
node
);
std
::
string
c_name
=
GetCNode
Target
FuncName
(
node
);
auto
pos
=
std
::
find
(
trans_var_list
.
begin
(),
trans_var_list
.
end
(),
c_name
);
if
(
!
training_
&&
pos
!=
trans_var_list
.
end
()
&&
pred
->
isa
<
Parameter
>
())
{
std
::
string
name
=
std
::
static_pointer_cast
<
Parameter
>
(
pred
)
->
name
();
...
...
@@ -1220,7 +1369,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
if
(
it
!=
out_handle_cache_
.
end
())
{
int
ret
=
adpt
->
setInput
(
src
,
SizeToInt
(
i
),
it
->
second
);
if
(
ret
==
0
)
{
if
(
pred
->
isa
<
CNode
>
()
&&
GetCNodeFuncName
(
pred
->
cast
<
CNodePtr
>
())
==
"tuple_getitem"
)
{
if
(
pred
->
isa
<
CNode
>
()
&&
GetCNode
Target
FuncName
(
pred
->
cast
<
CNodePtr
>
())
==
"tuple_getitem"
)
{
compute_sout_
<<
op_draw_name_
[
pred
->
cast
<
CNodePtr
>
()
->
input
(
1
).
get
()]
<<
" -> "
<<
op_draw_name_
[
node
.
get
()]
<<
":"
<<
i
<<
endl
;
}
else
if
(
pred
->
isa
<
Parameter
>
())
{
...
...
@@ -1278,6 +1427,23 @@ void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) {
DfGraphConvertor
::
SetOpInput
(
adpt
,
cnode
);
}
void
DfGraphConvertor
::
ProcessSubgraph
(
AnfNodePtr
node
,
const
std
::
vector
<
AnfNodePtr
>
&
inputs
)
{
if
(
!
node
->
isa
<
CNode
>
()
||
GetCNodeFuncName
(
node
->
cast
<
CNodePtr
>
())
!=
"Partial"
)
{
return
;
}
auto
graph_node
=
node
->
cast
<
CNodePtr
>
()
->
input
(
1
)
->
cast
<
ValueNodePtr
>
();
FuncGraphPtr
anf_graph
=
graph_node
->
value
()
->
cast
<
FuncGraphPtr
>
();
DfGraphConvertor
convertor
(
anf_graph
);
convertor
.
use_inputs_
=
true
;
convertor
.
inputs_
=
inputs
;
(
void
)
convertor
.
ConvertAllNode
().
BuildGraph
();
std
::
string
name
=
graph_node
->
ToString
()
+
"_ge_graph.dot"
;
if
(
MsContext
::
GetInstance
()
->
save_graphs_flag
())
{
convertor
.
DrawComputeGraph
(
name
);
}
branches_map_
[
node
.
get
()]
=
*
(
convertor
.
df_graph_
);
}
// Update GE op's shape and type info
void
DfGraphConvertor
::
UpdateOpDesc
(
const
AnfNodePtr
node
)
{
if
(
nullptr
==
node
||
!
node
->
isa
<
CNode
>
())
{
...
...
@@ -1348,6 +1514,7 @@ void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) {
}
}
MS_LOG
(
WARNING
)
<<
"ConvertMakeTuple: "
<<
node
.
get
()
<<
" "
<<
tuple_items
->
size
();
tuple_out_handle_cache_
[
node
.
get
()]
=
tuple_items
;
}
...
...
@@ -1711,6 +1878,14 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
return
false
;
}
if
(
name
==
""
&&
GetCNodeFuncName
(
node
)
==
"switch_layer"
)
{
return
false
;
}
if
(
name
==
"Partial"
)
{
return
false
;
}
// make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers
if
(
name
==
"make_tuple"
)
{
ConvertMakeTuple
(
node
);
...
...
@@ -1732,7 +1907,7 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
}
OperatorPtr
DfGraphConvertor
::
ConvertCNode
(
const
CNodePtr
node
)
{
std
::
string
name
=
GetCNodeFuncName
(
node
);
std
::
string
name
=
GetCNode
Target
FuncName
(
node
);
if
(
!
CheckCNode
(
name
,
node
))
{
return
nullptr
;
}
...
...
@@ -1879,7 +2054,7 @@ void DfGraphConvertor::DrawCNode(const CNodePtr node, const OpAdapterPtr adpt) {
}
compute_sout_
<<
"<tr><td colspan=
\"
"
<<
(
input_map
.
size
()
+
dyn_input_map
.
size
())
<<
"
\"
>
\"
"
<<
node
->
ToString
()
<<
":"
<<
GetCNodeFuncName
(
node
)
<<
"
\"
</td></tr>"
<<
endl
;
<<
":"
<<
GetCNode
Target
FuncName
(
node
)
<<
"
\"
</td></tr>"
<<
endl
;
// print attrs' values
auto
atts
=
adpt
->
GetAttrsFromDrawGraph
();
...
...
mindspore/ccsrc/transform/convert.h
浏览文件 @
a27ce973
...
...
@@ -201,6 +201,7 @@ class DfGraphConvertor {
OperatorPtr
ConvertParameter
(
AnfNodePtr
node
);
Status
TryConvertValueNodeToMultiConst
(
const
ValueNodePtr
node
);
OperatorPtr
ConvertValueNode
(
ValueNodePtr
node
);
void
GetCaseNodeInput
(
const
CNodePtr
node
,
const
CNodePtr
input_node
);
void
ConvertTupleGetItem
(
const
CNodePtr
node
);
void
GetDependOnParameterUse
(
const
CNodePtr
&
node
,
const
AnfNodePtr
&
src_node
,
const
AnfNodePtr
&
dest_node
,
const
std
::
shared_ptr
<
std
::
vector
<
OperatorPtr
>>
&
src_ops_list
,
...
...
@@ -217,6 +218,8 @@ class DfGraphConvertor {
void
SetNodeInput
(
AnfNodePtr
node
);
void
SetOpControlInput
(
const
AnfNodePtr
node
);
void
UpdateOpDesc
(
AnfNodePtr
node
);
void
SetSubgraph
(
AnfNodePtr
node
);
void
ProcessSubgraph
(
AnfNodePtr
node
,
const
std
::
vector
<
AnfNodePtr
>
&
inputs
);
void
BuildSaveCheckpointGraph
();
void
DrawCNode
(
const
CNodePtr
node
,
const
OpAdapterPtr
adpt
);
void
UpdateDataOpDesc
(
const
AnfNodePtr
&
it
,
const
OperatorPtr
&
op
)
const
;
...
...
@@ -228,22 +231,26 @@ class DfGraphConvertor {
std
::
shared_ptr
<
DfGraph
>
save_ckp_graph_
{
nullptr
};
std
::
shared_ptr
<
DfGraph
>
restore_ckp_graph_
{
nullptr
};
std
::
shared_ptr
<
DfGraph
>
broadcast_graph_
{
nullptr
};
std
::
unordered_map
<
AnfNode
*
,
DfGraph
>
branches_map_
;
std
::
unordered_map
<
AnfNode
*
,
OperatorPtr
>
op_cache_
;
std
::
unordered_map
<
AnfNode
*
,
std
::
vector
<
ControlEdge
>>
control_depend_cache_
;
/* record "tuple_getitem"<->"out_handler" mapping */
std
::
unordered_map
<
AnfNode
*
,
OutHandler
>
out_handle_cache_
;
/* record "make_tuple"<->"out_handler vector" mapping */
std
::
unordered_map
<
AnfNode
*
,
std
::
shared_ptr
<
std
::
vector
<
OutHandler
>>>
tuple_out_handle_cache_
;
std
::
unordered_map
<
AnfNode
*
,
std
::
shared_ptr
<
std
::
vector
<
AnfNodePtr
>>>
case_input_handle_cache_
;
std
::
unordered_map
<
std
::
string
,
AnfNodePtr
>
params_
;
std
::
unordered_map
<
std
::
string
,
OperatorPtr
>
vars_
;
std
::
vector
<
std
::
pair
<
ge
::
Operator
,
std
::
string
>>
graph_outputs_
;
std
::
vector
<
OperatorPtr
>
graph_const_inputs_
;
std
::
vector
<
OperatorPtr
>
init_ops_
;
std
::
vector
<
OperatorPtr
>
broadcast_ops_
;
std
::
vector
<
AnfNodePtr
>
inputs_
;
OperatorPtr
dataset_iter_getnext_
;
Status
error_
=
SUCCESS
;
bool
training_
=
false
;
bool
distribute_
=
false
;
bool
use_inputs_
=
false
;
};
}
// namespace transform
}
// namespace mindspore
...
...
mindspore/ccsrc/transform/op_adapter.h
浏览文件 @
a27ce973
...
...
@@ -164,6 +164,25 @@ class OpAdapter : public BaseOpAdapter {
const
std
::
unordered_map
<
unsigned
int
,
AttrDesc
>
&
getInputAttrMap
()
override
{
return
input_attr_map_
;
}
const
std
::
unordered_map
<
int
,
DynInputDesc
>
&
getDynInputMap
()
override
{
return
dyn_input_map_
;
}
const
std
::
unordered_map
<
int
,
OutputDesc
>
&
getOutputMap
()
override
{
return
output_map_
;
}
const
std
::
unordered_map
<
int
,
DynSubGraphDesc
>
&
getDynSubgraphMap
()
override
{
return
dyn_subgraph_map_
;
}
Status
SetOpSubgraphFunc
(
const
OperatorPtr
&
op
,
int
index
,
std
::
shared_ptr
<
std
::
vector
<
DfGraph
>>
branches
)
{
MS_EXCEPTION_IF_NULL
(
op
);
auto
it
=
dyn_subgraph_map_
.
find
(
index
);
if
(
it
!=
dyn_subgraph_map_
.
end
())
{
auto
size
=
branches
->
size
();
it
->
second
.
create_dyn_subgraph
(
op
,
static_cast
<
unsigned
int
>
(
size
));
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
it
->
second
.
set_subgraph
(
op
,
static_cast
<
unsigned
int
>
(
i
),
std
::
make_shared
<
DfGraph
>
((
*
branches
)[
i
]));
}
return
SUCCESS
;
}
return
NOT_FOUND
;
}
int
setSubgraph
(
const
OperatorPtr
&
op
,
int
index
,
std
::
shared_ptr
<
std
::
vector
<
DfGraph
>>
branches
)
override
{
return
static_cast
<
int
>
(
SetOpSubgraphFunc
(
op
,
index
,
branches
));
}
Status
SetCustomOpInput
(
const
CusOperatorPtr
&
op
,
int
index
,
const
OperatorPtr
&
input
)
{
MS_EXCEPTION_IF_NULL
(
op
);
...
...
@@ -855,6 +874,7 @@ class OpAdapter : public BaseOpAdapter {
static
const
std
::
unordered_map
<
int
,
DynInputDesc
>
dyn_input_map_
;
static
const
std
::
unordered_map
<
int
,
OutputDesc
>
output_map_
;
static
const
std
::
unordered_map
<
int
,
DynOutputDesc
>
dyn_output_map_
;
static
const
std
::
unordered_map
<
int
,
DynSubGraphDesc
>
dyn_subgraph_map_
;
static
const
std
::
unordered_map
<
std
::
string
,
AttrDesc
>
attr_map_
;
static
const
std
::
unordered_map
<
std
::
string
,
int
>
enum_map_
;
// convert input from anf graph to Attr in Operators
...
...
@@ -874,6 +894,8 @@ const std::unordered_map<int, OutputDesc> OpAdapter<T>::output_map_;
template
<
typename
T
>
const
std
::
unordered_map
<
int
,
DynOutputDesc
>
OpAdapter
<
T
>::
dyn_output_map_
;
template
<
typename
T
>
const
std
::
unordered_map
<
int
,
DynSubGraphDesc
>
OpAdapter
<
T
>::
dyn_subgraph_map_
;
template
<
typename
T
>
const
std
::
unordered_map
<
std
::
string
,
AttrDesc
>
OpAdapter
<
T
>::
attr_map_
;
template
<
typename
T
>
const
std
::
unordered_map
<
std
::
string
,
int
>
OpAdapter
<
T
>::
enum_map_
;
...
...
mindspore/ccsrc/transform/op_adapter_base.h
浏览文件 @
a27ce973
...
...
@@ -88,6 +88,8 @@ using DynInputOpFunc = std::function<void(OperatorPtr, unsigned int, OperatorPtr
using
DynInputHandleFunc
=
std
::
function
<
void
(
OperatorPtr
,
unsigned
int
,
OutHandler
)
>
;
using
UpdateOutputDescFunc
=
std
::
function
<
void
(
OperatorPtr
,
GeTensorDesc
)
>
;
using
CreateDynOutputOpFunc
=
std
::
function
<
void
(
OperatorPtr
,
unsigned
int
)
>
;
using
CreateDynSubGraphFunc
=
std
::
function
<
void
(
OperatorPtr
,
unsigned
int
)
>
;
using
DynSubGraphFunc
=
std
::
function
<
void
(
OperatorPtr
,
unsigned
int
,
DfGraphPtr
)
>
;
struct
AttrDesc
{
std
::
string
name
;
...
...
@@ -108,6 +110,12 @@ struct DynInputDesc {
DynInputHandleFunc
set_handle
;
};
struct
DynSubGraphDesc
{
std
::
string
name
;
CreateDynSubGraphFunc
create_dyn_subgraph
;
DynSubGraphFunc
set_subgraph
;
};
struct
OutputDesc
{
std
::
string
name
;
UpdateOutputDescFunc
update_out_desc
;
...
...
@@ -123,6 +131,7 @@ class BaseOpAdapter {
virtual
~
BaseOpAdapter
()
{}
virtual
OperatorPtr
generate
(
const
AnfNodePtr
&
anf
)
=
0
;
virtual
OperatorPtr
generate
(
const
std
::
string
&
type
)
{
return
std
::
make_shared
<
ge
::
Operator
>
(
type
);
}
virtual
int
setSubgraph
(
const
OperatorPtr
&
op
,
int
index
,
std
::
shared_ptr
<
std
::
vector
<
DfGraph
>>
branches
)
=
0
;
virtual
int
setInput
(
const
OperatorPtr
&
op
,
int
index
,
const
OperatorPtr
&
input
)
=
0
;
virtual
int
setInput
(
const
OperatorPtr
&
op
,
int
index
,
const
OutHandler
&
handle
)
=
0
;
virtual
int
setInput
(
const
OperatorPtr
&
op
,
int
index
,
...
...
@@ -146,6 +155,7 @@ class BaseOpAdapter {
virtual
const
std
::
unordered_map
<
unsigned
int
,
AttrDesc
>
&
getInputAttrMap
()
=
0
;
virtual
const
std
::
unordered_map
<
int
,
DynInputDesc
>
&
getDynInputMap
()
=
0
;
virtual
const
std
::
unordered_map
<
int
,
OutputDesc
>
&
getOutputMap
()
=
0
;
virtual
const
std
::
unordered_map
<
int
,
DynSubGraphDesc
>
&
getDynSubgraphMap
()
=
0
;
void
AddAttrToDrawGraph
(
const
std
::
string
&
attr_str
)
{
attrs_vec_
.
push_back
(
attr_str
);
}
const
std
::
vector
<
std
::
string
>
&
GetAttrsFromDrawGraph
()
const
{
return
attrs_vec_
;
}
void
clearAttrVect
()
{
attrs_vec_
.
clear
();
}
...
...
mindspore/ccsrc/transform/op_declare.cc
浏览文件 @
a27ce973
...
...
@@ -64,6 +64,22 @@ namespace transform {
} \
}
#define DYN_SUBGRAPH_MAP(T) \
template <> \
const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_
#define DYN_SUBGRAPH_DESC(name) \
{ \
#name, \
[](const OperatorPtr op, unsigned int num) { \
auto p = std::static_pointer_cast<OpType>(op); \
(void)p->create_dynamic_subgraph_##name(num); \
}, \
[](const OperatorPtr op, unsigned int index, const DfGraphPtr graph) { \
auto p = std::static_pointer_cast<OpType>(op); \
(void)p->set_dynamic_subgraph_builder_##name(index, [graph](){return *graph;}); \
} \
}
#define ATTR_MAP(T) \
template <> \
const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_
...
...
@@ -841,6 +857,13 @@ INPUT_ATTR_MAP(Cast) = {{2, ATTR_DESC(dst_type, AnyTraits<GEType>())}};
ATTR_MAP
(
Cast
)
=
EMPTY_ATTR_MAP
;
OUTPUT_MAP
(
Cast
)
=
{{
0
,
OUTPUT_DESC
(
y
)}};
// Case
INPUT_MAP
(
Case
)
=
{{
1
,
INPUT_DESC
(
branch_index
)}};
DYN_INPUT_MAP
(
Case
)
=
{{
2
,
DYN_INPUT_DESC
(
input
)}};
ATTR_MAP
(
Case
)
=
EMPTY_ATTR_MAP
;
DYN_OUTPUT_MAP
(
Case
)
=
{{
0
,
DYN_OUTPUT_DESC
(
output
)}};
DYN_SUBGRAPH_MAP
(
Case
)
=
{{
0
,
DYN_SUBGRAPH_DESC
(
branches
)}};
// Reciprocal
INPUT_MAP
(
Reciprocal
)
=
{{
1
,
INPUT_DESC
(
x
)}};
ATTR_MAP
(
Reciprocal
)
=
EMPTY_ATTR_MAP
;
...
...
mindspore/ccsrc/transform/op_declare.h
浏览文件 @
a27ce973
...
...
@@ -46,6 +46,10 @@ namespace transform {
template <> \
const std::unordered_map<int, DynInputDesc> OpAdapter<T>::dyn_input_map_;
#define DECLARE_OP_USE_DYN_SUBGRAPH(T) \
template <> \
const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_;
#define DECLARE_OP_USE_DYN_OUTPUT(T) \
template <> \
const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
...
...
@@ -232,6 +236,10 @@ DECLARE_OP_USE_OUTPUT(RealDiv)
DECLARE_OP_ADAPTER
(
Cast
)
DECLARE_OP_USE_INPUT_ATTR
(
Cast
)
DECLARE_OP_USE_OUTPUT
(
Cast
)
DECLARE_OP_ADAPTER
(
Case
)
DECLARE_OP_USE_DYN_INPUT
(
Case
)
DECLARE_OP_USE_DYN_SUBGRAPH
(
Case
)
DECLARE_OP_USE_DYN_OUTPUT
(
Case
)
DECLARE_OP_ADAPTER
(
Reciprocal
)
DECLARE_OP_USE_OUTPUT
(
Reciprocal
)
DECLARE_OP_ADAPTER
(
Neg
)
...
...
tests/ut/python/automl/case.py
0 → 100644
浏览文件 @
a27ce973
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test case."""
import
numpy
as
np
import
mindspore
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
,
context
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
1
,
3
,
3
)
self
.
conv2
=
nn
.
Conv2d
(
1
,
3
,
5
,
has_bias
=
True
)
self
.
layers
=
(
self
.
conv1
,
self
.
conv2
)
def
construct
(
self
,
x
,
index
):
x
=
self
.
layers
[
index
](
x
)
y
=
self
.
conv1
(
x
)
return
x
+
y
def
test_case
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
net
=
Net
()
data
=
Tensor
(
np
.
ones
((
1
,
1
,
224
,
224
)),
mindspore
.
float32
)
idx
=
Tensor
(
1
,
mindspore
.
int32
)
net
(
data
,
idx
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录