Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9940c723
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9940c723
编写于
8月 06, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 06, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3980 [AutoParallel] add GatherV2P strategy analysis for W&D
Merge pull request !3980 from Chong/wd
上级
a375c50c
c68dc39d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
120 addition
and
15 deletion
+120
-15
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
.../parallel/auto_parallel/rec_core/rec_generate_strategy.cc
+117
-15
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h
...d/parallel/auto_parallel/rec_core/rec_generate_strategy.h
+3
-0
未找到文件。
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
浏览文件 @
9940c723
...
@@ -176,21 +176,102 @@ Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
...
@@ -176,21 +176,102 @@ Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
s
[
axis
]
=
1
;
s
[
axis
]
=
1
;
strategies
.
push_back
(
s
);
strategies
.
push_back
(
s
);
auto
pos
=
ops
[
iter_ops
]
->
name
().
find
(
"Info"
);
return
strategies
;
auto
name
=
ops
[
iter_ops
]
->
name
().
substr
(
0
,
pos
);
}
if
(
name
==
"GatherV2"
)
{
return
strategies
;
Strategys
PrepareGatherV2P
(
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_ops
,
Dimensions
s
)
{
Strategys
strategies
;
auto
output_shape
=
ops
[
iter_ops
]
->
outputs_tensor_info
()[
0
].
shape
();
Dimensions
index
(
output_shape
.
size
()
-
1
,
0
);
for
(
size_t
i
=
0
;
i
<
index
.
size
();
i
++
)
{
index
[
i
]
=
i
;
}
}
std
::
sort
(
index
.
begin
(),
index
.
end
(),
[
&
output_shape
](
const
int
&
a
,
const
int
&
b
)
{
return
(
output_shape
[
a
+
1
]
>
output_shape
[
b
+
1
]);
});
std
::
transform
(
std
::
begin
(
index
),
std
::
end
(
index
),
std
::
begin
(
index
),
[](
int
x
)
{
return
x
+
1
;
});
index
.
insert
(
index
.
begin
(),
0
);
Dimensions
s_indices
;
Dimensions
strategie
(
output_shape
.
size
(),
1
);
for
(
size_t
i
=
0
;
i
<
ops
[
iter_ops
]
->
inputs_tensor_info
()[
1
].
shape
().
size
();
i
++
)
{
size_t
num_device
=
g_device_manager
->
DeviceNum
();
s_indices
.
push_back
(
1
);
size_t
cut
=
1
;
for
(
size_t
i
=
0
;
i
<
index
.
size
();
i
++
)
{
while
(
output_shape
[
index
[
i
]]
%
2
==
0
&&
output_shape
[
index
[
i
]]
>
0
&&
cut
<
num_device
)
{
output_shape
[
index
[
i
]]
/=
2
;
cut
*=
2
;
strategie
[
index
[
i
]]
*=
2
;
}
if
(
cut
==
num_device
)
{
break
;
}
}
auto
axis_input
=
GetValue
<
int
>
(
ops
[
iter_ops
]
->
input_value
().
at
(
2
));
if
(
axis_input
<
0
)
{
axis_input
+=
SizeToInt
(
ops
[
iter_ops
]
->
inputs_tensor_info
()[
0
].
shape
().
size
());
}
int32_t
axis
=
axis_input
;
if
(
axis
>=
SizeToInt
(
s
.
size
()))
{
MS_LOG
(
EXCEPTION
)
<<
"Failure: GatherV2' axis out of range."
;
}
if
(
axis
==
0
)
{
s
.
clear
();
s
.
push_back
(
1
);
for
(
size_t
i
=
1
;
i
<
ops
[
iter_ops
]
->
inputs_tensor_info
()[
0
].
shape
().
size
();
i
++
)
{
s
.
push_back
(
strategie
[
ops
[
iter_ops
]
->
inputs_tensor_info
()[
1
].
shape
().
size
()
-
1
+
i
]);
}
strategies
.
push_back
(
s
);
s
.
clear
();
for
(
size_t
i
=
0
;
i
<
ops
[
iter_ops
]
->
inputs_tensor_info
()[
1
].
shape
().
size
();
i
++
)
{
s
.
push_back
(
strategie
[
i
]);
}
strategies
.
push_back
(
s
);
}
else
if
(
axis
==
1
)
{
s
.
clear
();
s
.
push_back
(
strategie
[
0
]);
s
.
push_back
(
1
);
strategies
.
push_back
(
s
);
s
.
clear
();
for
(
size_t
i
=
0
;
i
<
ops
[
iter_ops
]
->
inputs_tensor_info
()[
1
].
shape
().
size
();
i
++
)
{
s
.
push_back
(
strategie
[
ops
[
iter_ops
]
->
inputs_tensor_info
()[
0
].
shape
().
size
()
-
1
+
i
]);
}
strategies
.
push_back
(
s
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Failure: GatherV2's axis is neither 0 nor 1."
;
}
}
strategies
.
push_back
(
s_indices
);
return
strategies
;
return
strategies
;
}
}
Dimensions
PrepareGatherV2POutputStrategy
(
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
incoming_op_index
)
{
auto
output_shape
=
ops
[
incoming_op_index
]
->
outputs_tensor_info
()[
0
].
shape
();
Dimensions
index
(
output_shape
.
size
()
-
1
,
0
);
for
(
size_t
i
=
0
;
i
<
index
.
size
();
i
++
)
{
index
[
i
]
=
i
;
}
std
::
sort
(
index
.
begin
(),
index
.
end
(),
[
&
output_shape
](
const
int
&
a
,
const
int
&
b
)
{
return
(
output_shape
[
a
+
1
]
>
output_shape
[
b
+
1
]);
});
std
::
transform
(
std
::
begin
(
index
),
std
::
end
(
index
),
std
::
begin
(
index
),
[](
int
x
)
{
return
x
+
1
;
});
index
.
insert
(
index
.
begin
(),
0
);
Dimensions
strategie
(
output_shape
.
size
(),
1
);
size_t
num_device
=
g_device_manager
->
DeviceNum
();
size_t
cut
=
1
;
for
(
size_t
i
=
0
;
i
<
index
.
size
();
i
++
)
{
while
(
output_shape
[
index
[
i
]]
%
2
==
0
&&
output_shape
[
index
[
i
]]
>
0
&&
cut
<
num_device
)
{
output_shape
[
index
[
i
]]
/=
2
;
cut
*=
2
;
strategie
[
index
[
i
]]
*=
2
;
}
if
(
cut
==
num_device
)
{
break
;
}
}
return
strategie
;
}
Strategys
PrepareL2Normalize
(
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_ops
,
Strategys
PrepareL2Normalize
(
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_ops
,
Dimensions
s
)
{
Dimensions
s
)
{
int32_t
axis
=
0
;
int32_t
axis
=
0
;
...
@@ -401,10 +482,20 @@ Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &grap
...
@@ -401,10 +482,20 @@ Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &grap
Dimensions
PrepareIncomingOperatorInputStrategy
(
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
Dimensions
PrepareIncomingOperatorInputStrategy
(
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
incoming_op_index
)
{
const
size_t
incoming_op_index
)
{
Dimensions
s
;
Dimensions
s
;
if
(
ops
[
incoming_op_index
]
->
type
()
==
RESHAPE
||
ops
[
incoming_op_index
]
->
type
()
==
GATHERV2
||
if
(
ops
[
incoming_op_index
]
->
type
()
==
RESHAPE
||
ops
[
incoming_op_index
]
->
type
()
==
TRANSPOSE
)
{
ops
[
incoming_op_index
]
->
type
()
==
TRANSPOSE
)
{
return
s
;
return
s
;
}
}
if
(
ops
[
incoming_op_index
]
->
type
()
==
GATHERV2
)
{
auto
pos
=
ops
[
incoming_op_index
]
->
name
().
find
(
"Info"
);
auto
name
=
ops
[
incoming_op_index
]
->
name
().
substr
(
0
,
pos
);
if
(
name
==
"GatherV2"
)
{
return
s
;
}
else
if
(
name
==
"GatherV2P"
)
{
return
PrepareGatherV2POutputStrategy
(
ops
,
incoming_op_index
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Failure: Unknown type of GatherV2."
<<
std
::
endl
;
}
}
auto
strategy
=
ops
[
incoming_op_index
]
->
selected_strategy
();
auto
strategy
=
ops
[
incoming_op_index
]
->
selected_strategy
();
if
(
strategy
->
GetInputNumber
()
==
0
)
{
if
(
strategy
->
GetInputNumber
()
==
0
)
{
return
s
;
return
s
;
...
@@ -495,10 +586,13 @@ Dimensions GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, con
...
@@ -495,10 +586,13 @@ Dimensions GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, con
if
(
input_value
.
back
()
->
isa
<
ValueTuple
>
())
{
if
(
input_value
.
back
()
->
isa
<
ValueTuple
>
())
{
auto
attr_axis
=
GetValue
<
std
::
vector
<
int
>>
(
input_value
.
back
());
auto
attr_axis
=
GetValue
<
std
::
vector
<
int
>>
(
input_value
.
back
());
if
(
attr_axis
.
empty
())
{
if
(
attr_axis
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"Failure: This output is a 0-D tensor."
<<
std
::
endl
;
for
(
size_t
i
=
0
;
i
<
input_dim
;
i
++
)
{
}
dim_list
.
push_back
(
SizeToInt
(
i
));
for
(
auto
&
axis
:
attr_axis
)
{
}
axis
<
0
?
dim_list
.
push_back
(
axis
+
SizeToInt
(
input_dim
))
:
dim_list
.
push_back
(
axis
);
}
else
{
for
(
auto
&
axis
:
attr_axis
)
{
axis
<
0
?
dim_list
.
push_back
(
axis
+
SizeToInt
(
input_dim
))
:
dim_list
.
push_back
(
axis
);
}
}
}
}
else
if
(
input_value
.
back
()
->
isa
<
Int32Imm
>
())
{
}
else
if
(
input_value
.
back
()
->
isa
<
Int32Imm
>
())
{
int
axis
=
GetValue
<
int
>
(
input_value
.
back
());
int
axis
=
GetValue
<
int
>
(
input_value
.
back
());
...
@@ -625,7 +719,15 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera
...
@@ -625,7 +719,15 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera
return
PrepareBiasAdd
(
s_ptr
);
return
PrepareBiasAdd
(
s_ptr
);
}
}
if
(
ops
[
iter_ops
]
->
type
()
==
GATHERV2
)
{
if
(
ops
[
iter_ops
]
->
type
()
==
GATHERV2
)
{
return
PrepareGatherV2
(
ops
,
iter_ops
,
basic_stra
);
auto
pos
=
ops
[
iter_ops
]
->
name
().
find
(
"Info"
);
auto
name
=
ops
[
iter_ops
]
->
name
().
substr
(
0
,
pos
);
if
(
name
==
"GatherV2"
)
{
return
PrepareGatherV2
(
ops
,
iter_ops
,
basic_stra
);
}
else
if
(
name
==
"GatherV2P"
)
{
return
PrepareGatherV2P
(
ops
,
iter_ops
,
basic_stra
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Failure: Unknown type of GatherV2."
<<
std
::
endl
;
}
}
}
if
(
ops
[
iter_ops
]
->
type
()
==
L2_NORMALIZE
)
{
if
(
ops
[
iter_ops
]
->
type
()
==
L2_NORMALIZE
)
{
return
PrepareL2Normalize
(
ops
,
iter_ops
,
basic_stra
);
return
PrepareL2Normalize
(
ops
,
iter_ops
,
basic_stra
);
...
...
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h
浏览文件 @
9940c723
...
@@ -37,6 +37,9 @@ Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
...
@@ -37,6 +37,9 @@ Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
Strategys
PrepareOneHot
(
const
std
::
shared_ptr
<
Graph
>
&
graph
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
Strategys
PrepareOneHot
(
const
std
::
shared_ptr
<
Graph
>
&
graph
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_graph
,
const
size_t
iter_ops
);
const
size_t
iter_graph
,
const
size_t
iter_ops
);
Strategys
PrepareGatherV2
(
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_ops
,
Dimensions
s
);
Strategys
PrepareGatherV2
(
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_ops
,
Dimensions
s
);
Strategys
PrepareGatherV2P
(
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_ops
,
Dimensions
s
);
Dimensions
PrepareGatherV2POutputStrategy
(
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
incoming_op_index
);
Strategys
PrepareL2Normalize
(
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_ops
,
Strategys
PrepareL2Normalize
(
const
std
::
vector
<
std
::
shared_ptr
<
OperatorInfo
>>
&
ops
,
const
size_t
iter_ops
,
Dimensions
s
);
Dimensions
s
);
Strategys
MakeRecSearchStrategy
(
const
std
::
shared_ptr
<
Graph
>
&
graph
,
Strategys
MakeRecSearchStrategy
(
const
std
::
shared_ptr
<
Graph
>
&
graph
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录