magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 7029bc5d
fix onehot axis

Authored on June 23, 2020 by hongxing
Parent: 8e20d4d8
Showing 5 changed files with 157 additions and 60 deletions (+157 -60)
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc  +119 -30
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h  +17 -13
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h  +2 -1
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc  +11 -11
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h  +8 -5

mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
@@ -28,10 +28,10 @@
 namespace mindspore {
 namespace parallel {
-void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
+void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                       const std::vector<std::vector<std::string>> &input_tensor_names,
-                      const std::shared_ptr<std::vector<size_t>> index_list) {
+                      const std::shared_ptr<std::vector<size_t>> &index_list) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(eli_list);
   MS_EXCEPTION_IF_NULL(index_list);
@@ -140,10 +140,24 @@ std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &gr
                                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_graph, const size_t iter_ops) {
   std::vector<std::vector<int32_t>> strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
-  strategies[0][0] = strategies[0][1];
-  strategies[0][1] = 1;
-  graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w;
-  graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
+  int32_t axis = -1;
+  auto iter = ops[iter_ops]->attrs().find(AXIS);
+  if (iter != ops[iter_ops]->attrs().end()) {
+    MS_EXCEPTION_IF_NULL(iter->second);
+    if (iter->second->isa<Int32Imm>()) {
+      axis = iter->second->cast<Int32ImmPtr>()->value();
+    } else {
+      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int.";
+    }
+  }
+
+  if (axis == -1) {
+    strategies[0][0] = strategies[0][1];
+    strategies[0][1] = 1;
+    graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w;
+    graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
+  }
+
   std::vector<int32_t> s_empty = {};
   strategies.push_back(s_empty);
   strategies.push_back(s_empty);
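The hunk above is the core of this commit: the old code always moved OneHot's split from the last dimension onto the first, while the new code reads the operator's AXIS attribute first and applies that swap only when axis == -1, i.e. when the class dimension is appended last; for any other axis the search-derived strategy is left untouched. A minimal standalone sketch of the new control flow, with plain vectors and hypothetical values in place of MindSpore's Graph/OperatorInfo types:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Simplified stand-in for the first input's strategy: one partition factor
    // per dimension, e.g. {8, 1} means dim 0 is split across 8 devices.
    using Strategy = std::vector<int32_t>;

    // Mirrors the new PrepareOneHot logic: the swap that used to run
    // unconditionally now happens only for axis == -1.
    Strategy PrepareOneHotSketch(Strategy s, int32_t axis) {
      if (axis == -1) {
        s[0] = s[1];  // move the split onto the batch dimension
        s[1] = 1;     // leave the one-hot (class) dimension whole
      }
      return s;
    }

    int main() {
      Strategy base = {1, 8};                      // hypothetical: last dim split 8-way
      Strategy a = PrepareOneHotSketch(base, -1);  // -> {8, 1}
      Strategy b = PrepareOneHotSketch(base, 0);   // -> {1, 8}, kept as-is
      std::cout << a[0] << "," << a[1] << " vs " << b[0] << "," << b[1] << "\n";
      return 0;
    }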
@@ -221,7 +235,7 @@ std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Gr
     } else if (output_size == 0) {
       s = {};
     } else {
-      MS_LOG(ERROR) << "Tensor's output size is unexcepted.";
+      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's output size is unexcepted.";
     }
     strategies.push_back(s);
   }
@@ -241,7 +255,7 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr
   StrategyPtr origin_strategy = ops[iter_ops]->strategy();
   std::vector<std::vector<int32_t>> strategies;
   size_t max_device_num = g_device_manager->DeviceNum();
-  size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
+  size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0];
   for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
     if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
       MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
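The one-line fix above takes the batch size from the operator's first input tensor rather than its first output. In a data-parallel strategy only that batch dimension is split, and max_device_num and target_tensor_batch exist to bound the split; the sketch below shows that derivation on plain shapes (the min-style clamping is an assumption about the surrounding code, not quoted from this diff):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Sketch of a data-parallel strategy: split only the batch (first)
    // dimension, by no more than the device count or the batch size itself.
    // The exact clamping in MakeDataParallelStrategy is assumed, not quoted.
    std::vector<int32_t> DataParallelStrategy(const std::vector<int64_t> &input_shape, size_t device_num) {
      std::vector<int32_t> s;
      if (input_shape.empty()) {
        return s;  // scalar input: empty strategy, mirroring the new input_size == 0 branch below
      }
      size_t batch = static_cast<size_t>(input_shape[0]);
      s.push_back(static_cast<int32_t>(std::min(device_num, batch)));
      for (size_t i = 1; i < input_shape.size(); i++) {
        s.push_back(1);  // leave non-batch dimensions whole
      }
      return s;
    }

    int main() {
      auto s = DataParallelStrategy({32, 128, 128}, 8);  // -> {8, 1, 1}
      for (auto v : s) {
        std::cout << v << " ";
      }
      std::cout << "\n";
      return 0;
    }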
@@ -256,8 +270,10 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr
       } else {
         s.push_back(1);
       }
+    } else if (input_size == 0) {
+      s = {};
     } else {
-      MS_LOG(ERROR) << "Tensor's shape is unknown.";
+      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
     }
   }
   strategies.push_back(s);
@@ -304,13 +320,13 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &
   }
 }
 
-void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph,
+void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
                                          const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                         const std::shared_ptr<std::vector<size_t>> index_list) {
+                                         const std::shared_ptr<std::vector<size_t>> &index_list) {
   for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) {
     std::vector<std::vector<int32_t>> strategies;
     size_t iter_graph = index_list->at(iter_ops);
-    if (iter_graph != SIZE_MAX) {
+    if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) {
       strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops);
     }
     StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
@@ -335,7 +351,7 @@ size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &
   return incoming_op_index;
 }
 
-std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph,
+std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_ops, const size_t iter_graph) {
   std::vector<int32_t> s;
@@ -354,8 +370,10 @@ std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Gr
       s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c);
       s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h);
       s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w);
+    } else if (input_stra_dim == 0) {
+      s = {};
     } else {
-      MS_LOG(ERROR) << "Tensor's shape is unknown.";
+      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
     }
     break;
   }
@@ -365,7 +383,8 @@ std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Gr
 std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                           const size_t incoming_op_index) {
   std::vector<int32_t> s;
-  if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2) {
+  if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2 ||
+      ops[incoming_op_index]->type() == TRANSPOSE) {
     return s;
   }
   auto strategy = ops[incoming_op_index]->selected_strategy();
@@ -433,13 +452,23 @@ std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shar
   return s_Squeeze;
 }
 
+bool GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
+  bool keepdims = false;
+  auto keep_dims_iter = ops[iter_ops]->attrs().find(KEEP_DIMS);
+  if (keep_dims_iter == ops[iter_ops]->attrs().end()) {
+    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr keep_dims.";
+  }
+  MS_EXCEPTION_IF_NULL(keep_dims_iter->second);
+  if (!keep_dims_iter->second->isa<BoolImm>()) {
+    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Keep_dims is not a bool.";
+  }
+  keepdims = keep_dims_iter->second->cast<BoolImmPtr>()->value();
+  return keepdims;
+}
+
 std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
   std::vector<int32_t> dim_list;
-  bool keep_dims;
-  if (!ops[iter_ops]->attrs().find(KEEP_DIMS)->second->isa<BoolImm>()) {
-    MS_LOG(EXCEPTION) << "Failure: Parameter keep_dims is not a boolean value." << std::endl;
-  }
-  keep_dims = ops[iter_ops]->attrs().find(KEEP_DIMS)->second->cast<BoolImmPtr>()->value();
+  bool keep_dims = GetKeepDims(ops, iter_ops);
   if (keep_dims != false) {
     return dim_list;
   }
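The hunk above folds the duplicated keep_dims lookup into a single GetKeepDims helper that checks presence, null-ness, and type before extracting the value. The shape of that pattern, sketched standalone with std::variant standing in for MindSpore's ValuePtr/isa<>/cast<> machinery (all names here are illustrative):

    #include <iostream>
    #include <map>
    #include <stdexcept>
    #include <string>
    #include <variant>

    // The shape of the new GetKeepDims helper: one place that looks up an
    // attribute, validates presence and type, and returns the typed value.
    using Attr = std::variant<bool, int>;

    bool GetKeepDims(const std::map<std::string, Attr> &attrs) {
      auto it = attrs.find("keep_dims");
      if (it == attrs.end()) {
        throw std::runtime_error("Don't have attr keep_dims.");
      }
      if (!std::holds_alternative<bool>(it->second)) {
        throw std::runtime_error("Keep_dims is not a bool.");
      }
      return std::get<bool>(it->second);
    }

    int main() {
      std::map<std::string, Attr> attrs{{"keep_dims", Attr{false}}, {"axis", Attr{-1}}};
      std::cout << std::boolalpha << GetKeepDims(attrs) << "\n";  // prints "false"
      return 0;
    }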
@@ -485,6 +514,62 @@ std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::share
   return s_Reduce;
 }
 
+std::vector<int32_t> GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
+  std::vector<int32_t> dim_list;
+  auto iter = ops[iter_ops]->attrs().find(AXIS);
+  if (iter == ops[iter_ops]->attrs().end()) {
+    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr axis.";
+  }
+  auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
+  MS_EXCEPTION_IF_NULL(iter->second);
+  if (iter->second->isa<ValueTuple>()) {
+    auto attr_axis = GetValue<std::vector<int>>(iter->second);
+    if (attr_axis.empty()) {
+      for (size_t i = 0; i < input_dim; ++i) {
+        dim_list.push_back(SizeToInt(i));
+      }
+    } else {
+      for (auto &axis : attr_axis) {
+        axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis);
+      }
+    }
+  } else if (iter->second->isa<Int32Imm>()) {
+    int axis = GetValue<int>(iter->second);
+    axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis);
+  } else {
+    MS_LOG(EXCEPTION) << "Axis type is invalid.";
+  }
+  return dim_list;
+}
+
+std::vector<int32_t> ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                 const size_t incoming_op_index, std::vector<int32_t> s) {
+  bool keepdims = GetKeepDims(ops, incoming_op_index);
+  if (keepdims) {
+    return s;
+  }
+
+  std::vector<int32_t> s_Arg;
+  std::vector<int32_t> axis_list;
+  for (size_t i = 0; i < s.size(); i++) {
+    axis_list.push_back(i);
+  }
+
+  auto dim_list = GetDimListFromAttrs(ops, incoming_op_index);
+  for (auto axis : dim_list) {
+    auto it = find(axis_list.begin(), axis_list.end(), axis);
+    if (it == axis_list.end()) {
+      MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl;
+    }
+    axis_list.erase(it);
+  }
+
+  for (size_t i = 0; i < (size_t)axis_list.size(); i++) {
+    s_Arg.push_back(s[axis_list[i]]);
+  }
+  return s_Arg;
+}
+
 std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const size_t iter_ops, const size_t incoming_op_index) {
   std::vector<int32_t> s;
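The two helpers added above handle an incoming ArgMaxWithValue/ArgMinWithValue: GetDimListFromAttrs collects the reduced axes (normalizing a negative axis to axis + rank), and ModifyStrategyIfArgIncoming then drops those dimensions from the inherited strategy when keep_dims is false, since the Arg op's output no longer has them. The core index bookkeeping as a self-contained sketch on plain vectors (hypothetical inputs):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    // Drop the entries of `s` whose dimension index appears in `dim_list`,
    // mirroring ModifyStrategyIfArgIncoming for keep_dims == false. Negative
    // axes are assumed already normalized (axis += rank), as in GetDimListFromAttrs.
    std::vector<int32_t> DropReducedDims(const std::vector<int32_t> &s,
                                         const std::vector<int32_t> &dim_list) {
      std::vector<int32_t> axis_list;
      for (size_t i = 0; i < s.size(); i++) {
        axis_list.push_back(static_cast<int32_t>(i));
      }
      for (auto axis : dim_list) {
        auto it = std::find(axis_list.begin(), axis_list.end(), axis);
        if (it == axis_list.end()) {
          throw std::runtime_error("Cannot find dimension index in axis list.");
        }
        axis_list.erase(it);
      }
      std::vector<int32_t> result;
      for (auto kept : axis_list) {
        result.push_back(s[kept]);
      }
      return result;
    }

    int main() {
      // Hypothetical rank-3 strategy {4, 2, 1}; the incoming op reduced dim 1.
      auto out = DropReducedDims({4, 2, 1}, {1});
      for (auto v : out) {
        std::cout << v << " ";  // prints "4 1"
      }
      std::cout << "\n";
      return 0;
    }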
@@ -497,6 +582,9 @@ std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::sh
         ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) {
       s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s);
     }
+    if (ops[incoming_op_index]->type() == ARGMAXWITHVALUE || ops[incoming_op_index]->type() == ARGMINWITHVALUE) {
+      s = ModifyStrategyIfArgIncoming(ops, incoming_op_index, s);
+    }
   }
   return s;
 }
@@ -551,11 +639,11 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
   return stra;
 }
 
-void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> graph,
+void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,
                                                const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                const std::vector<std::vector<std::string>> &input_tensor_names,
-                                               const std::shared_ptr<std::vector<size_t>> index_list,
-                                               const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
+                                               const std::shared_ptr<std::vector<size_t>> &index_list,
+                                               const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
   if (no_stra_op_list->size() == 0) {
     return;
   }
@@ -624,7 +712,8 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::sh
   std::vector<int32_t> s;
   if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN ||
       ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN ||
-      ops[iter_ops]->type() == RESHAPE || ops[iter_ops]->type() == GATHERV2) {
+      ops[iter_ops]->type() == RESHAPE || ops[iter_ops]->type() == GATHERV2 || ops[iter_ops]->type() == TRANSPOSE ||
+      ops[iter_ops]->type() == ARGMAXWITHVALUE || ops[iter_ops]->type() == ARGMINWITHVALUE) {
     return s;
   }
@@ -656,7 +745,7 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::sh
 void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const std::vector<std::vector<std::string>> &input_tensor_names,
-                                                const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
+                                                const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
   if (no_stra_op_list->size() == 0) {
     return;
   }
@@ -686,16 +775,16 @@ void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_pt
   }
 }
 
-void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> graph,
+void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph,
                                        const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                        const std::vector<std::vector<std::string>> &input_tensor_names,
-                                       const std::shared_ptr<std::vector<size_t>> index_list,
-                                       const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
+                                       const std::shared_ptr<std::vector<size_t>> &index_list,
+                                       const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
   if (no_stra_op_list->size() == 0) {
     return;
   }
 
-  size_t no_stra_op_list_size;
+  size_t no_stra_op_list_size = no_stra_op_list->size();
   do {
     no_stra_op_list_size = no_stra_op_list->size();
     GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list);
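Besides taking the shared pointers by const reference, the hunk above initializes no_stra_op_list_size at its declaration rather than leaving it uninitialized before the do/while. The loop is a fixed-point iteration: run the propagation passes repeatedly until a full round no longer shrinks the list of operators that still lack a strategy. A reduced sketch of that control flow (the pass bodies are placeholders, not MindSpore code):

    #include <cstddef>
    #include <list>

    // Hypothetical pass: tries to infer a strategy for each remaining operator
    // from already-resolved neighbors, erasing the ones it resolves.
    void PropagateOnce(std::list<int> *unresolved) {
      unresolved->remove_if([](int op) { return op % 2 == 0; });  // stand-in rule
    }

    // Mirrors the do/while in GenerateRemainingOperatorStrategy: iterate to a
    // fixed point, stopping once a full round makes no progress.
    void ResolveRemaining(std::list<int> *unresolved) {
      if (unresolved->empty()) return;
      size_t size = unresolved->size();  // initialized at declaration, as in the fix
      do {
        size = unresolved->size();
        PropagateOnce(unresolved);  // stand-in for the forward pass
        PropagateOnce(unresolved);  // stand-in for the backward pass
      } while (!unresolved->empty() && unresolved->size() < size);
    }

    int main() {
      std::list<int> ops = {1, 2, 3, 4};
      ResolveRemaining(&ops);  // {1, 3} remain: no rule resolves odd ids
      return 0;
    }
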
mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
@@ -27,10 +27,10 @@
 namespace mindspore {
 namespace parallel {
-void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
+void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                       const std::vector<std::vector<std::string>> &input_tensor_names,
-                      const std::shared_ptr<std::vector<size_t>> index_list);
+                      const std::shared_ptr<std::vector<size_t>> &index_list);
 std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
                                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_graph, const size_t iter_ops);
@@ -50,12 +50,12 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr
 std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph,
                                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                   const size_t iter_graph, const size_t iter_ops);
-void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph,
+void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
                                          const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                         const std::shared_ptr<std::vector<size_t>> index_list);
+                                         const std::shared_ptr<std::vector<size_t>> &index_list);
 size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
                                    const size_t iter_ops);
-std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph,
+std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &graph,
                                                         const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                         const size_t iter_ops, const size_t iter_graph);
 std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
@@ -63,19 +63,23 @@ std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std:
 std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops);
 std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                      const size_t incoming_op_index, std::vector<int32_t> s);
+bool GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
 std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
 std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                     const size_t incoming_op_index, std::vector<int32_t> s);
+std::vector<int32_t> GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
+std::vector<int32_t> ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                 const size_t incoming_op_index, std::vector<int32_t> s);
 std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const size_t iter_ops, const size_t incoming_op_index);
 std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                                  const size_t iter_ops,
                                                                  std::vector<int32_t> basic_stra);
-void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph,
+void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,
                                                const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                const std::vector<std::vector<std::string>> &input_tensor_names,
-                                               const std::shared_ptr<std::vector<size_t>> index_list,
-                                               const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
+                                               const std::shared_ptr<std::vector<size_t>> &index_list,
+                                               const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
 std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                      const size_t iter_ops, std::vector<int32_t> s);
 std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
@@ -83,12 +87,12 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::sh
                                                        const size_t iter_ops);
 void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const std::vector<std::vector<std::string>> &input_tensor_names,
-                                                const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
-void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> graph,
+                                                const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
+void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph,
                                        const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                        const std::vector<std::vector<std::string>> &input_tensor_names,
-                                       const std::shared_ptr<std::vector<size_t>> index_list,
-                                       const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
+                                       const std::shared_ptr<std::vector<size_t>> &index_list,
+                                       const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_

mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
@@ -50,7 +50,8 @@ enum OperatorType {
   kRecCast,
   kRecReduce,
   kRecPReLU,
-  kRecGatherV2
+  kRecGatherV2,
+  kRecArgWithValue
 };
 
 enum InfoType { kApplication, kConstant };

mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
@@ -163,8 +163,8 @@ size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &i
   return SIZE_MAX;
 }
 
-void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph,
-                   const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list) {
+void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> &graph,
+                   const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list) {
   std::vector<size_t> eli;
   eli.push_back(node_index);
   for (size_t i = 0; i < (size_t)graph->nodes[node_index].node_out.size(); i++) {
@@ -211,18 +211,18 @@ void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph,
   }
 }
 
-std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
-                                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
-                                      const std::shared_ptr<std::vector<size_t>> index_list) {
+std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
+                                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
+                                      const std::shared_ptr<std::vector<size_t>> &index_list) {
   MS_EXCEPTION_IF_NULL(graph);
-  const std::set<OperatorType> type_list = {
-    OperatorType::kRecReLU,      OperatorType::kRecLog,     OperatorType::kRecExp,    OperatorType::kRecAdd,
-    OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub,    OperatorType::kRecMul,
-    OperatorType::kRecDiv,       OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast,
-    OperatorType::kRecReshape,   OperatorType::kRecGatherV2};
+  const std::set<OperatorType> elementwise_type = {
+    OperatorType::kRecReLU,      OperatorType::kRecLog,      OperatorType::kRecExp,    OperatorType::kRecAdd,
+    OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd,  OperatorType::kRecSub,    OperatorType::kRecMul,
+    OperatorType::kRecDiv,       OperatorType::kRecSqueeze,  OperatorType::kRecReduce, OperatorType::kRecCast,
+    OperatorType::kRecReshape,   OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue};
   for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
     auto type = graph->nodes[node_index].apply.op_type;
-    if (type_list.find(type) != type_list.end()) {
+    if (elementwise_type.find(type) != elementwise_type.end()) {
       Eliminate_Aux(node_index, graph, eli_list);
     }
   }

mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
@@ -47,6 +47,8 @@ const std::map<std::string, OperatorType> DictOpType{
   {REDUCE_MIN, OperatorType::kRecReduce},
   {REDUCE_MEAN, OperatorType::kRecReduce},
   {GATHERV2, OperatorType::kRecGatherV2},
+  {ARGMAXWITHVALUE, OperatorType::kRecArgWithValue},
+  {ARGMINWITHVALUE, OperatorType::kRecArgWithValue},
 
   {RELU, OperatorType::kRecReLU},
   {"ReLU6", OperatorType::kRecReLU},
@@ -59,6 +61,7 @@ const std::map<std::string, OperatorType> DictOpType{
   {PRELU, OperatorType::kRecPReLU},
 
+  {TRANSPOSE, OperatorType::kRecElmWiseOp},
   {L2_NORMALIZE, OperatorType::kRecElmWiseOp},
   {TENSOR_ADD, OperatorType::kRecElmWiseOp},
   {SUB, OperatorType::kRecElmWiseOp},
@@ -123,12 +126,12 @@ void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, s
 size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_names,
                                   const std::string &input_name);
-void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph,
-                   const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list);
+void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> &graph,
+                   const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list);
-std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
-                                      const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list,
-                                      const std::shared_ptr<std::vector<size_t>> index_list);
+std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
+                                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
+                                      const std::shared_ptr<std::vector<size_t>> &index_list);
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_