Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
048b88c4
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
048b88c4
编写于
8月 31, 2020
作者:
Y
yangzhenzhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update check strategy value
上级
7371cedd
变更
24
隐藏空白更改
内联
并排
Showing
24 changed file
with
64 addition
and
252 deletion
+64
-252
mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc
...spore/ccsrc/frontend/parallel/ops_info/activation_info.cc
+5
-26
mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc
...spore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc
+2
-9
mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc
...e/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc
+2
-6
mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc
mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc
+2
-9
mindspore/ccsrc/frontend/parallel/ops_info/concat_info.cc
mindspore/ccsrc/frontend/parallel/ops_info/concat_info.cc
+2
-9
mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc
.../ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc
+2
-11
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc
+2
-8
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
...pore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
+2
-17
mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc
mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc
+1
-7
mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc
...ore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc
+1
-2
mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc
...spore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc
+2
-8
mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc
mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc
+2
-7
mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
+1
-1
mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
+2
-22
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
+19
-17
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
+1
-0
mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc
mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc
+2
-8
mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc
...re/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc
+2
-16
mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
+2
-14
mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc
...re/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc
+2
-7
mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc
mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc
+2
-14
mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc
...ore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc
+2
-13
mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc
mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc
+2
-14
mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc
.../ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc
+2
-7
未找到文件。
mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc
浏览文件 @
048b88c4
...
@@ -30,26 +30,12 @@
...
@@ -30,26 +30,12 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
parallel
{
namespace
parallel
{
Status
Activation
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
Activation
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
return
SetCostUnderStrategyBase
(
strategy
);
}
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
Activation
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
return
FAILED
;
}
return
SUCCESS
;
Status
Activation
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
return
CheckStrategyValue
(
strategy
,
inputs_shape_
);
}
}
Status
DropoutInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
DropoutInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -153,7 +139,7 @@ Status DropoutInfo::GenerateStrategies(int32_t stage_id) {
...
@@ -153,7 +139,7 @@ Status DropoutInfo::GenerateStrategies(int32_t stage_id) {
}
}
Status
Softmax
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
Softmax
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -229,14 +215,7 @@ Status Softmax::GetAttrs() {
...
@@ -229,14 +215,7 @@ Status Softmax::GetAttrs() {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
Softmax
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
Softmax
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
return
SetCostUnderStrategyBase
(
strategy
);
}
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
Softmax
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
Softmax
::
GenerateStrategies
(
int32_t
stage_id
)
{
if
(
GetAttrs
()
!=
SUCCESS
)
{
if
(
GetAttrs
()
!=
SUCCESS
)
{
...
...
mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc
浏览文件 @
048b88c4
...
@@ -73,7 +73,7 @@ Strategys ExpendStrategy(const StrategyPtr &strategy) {
...
@@ -73,7 +73,7 @@ Strategys ExpendStrategy(const StrategyPtr &strategy) {
}
}
Status
ArithmeticBase
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
ArithmeticBase
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -290,14 +290,7 @@ Status ArithmeticBase::InferTensorInfo() {
...
@@ -290,14 +290,7 @@ Status ArithmeticBase::InferTensorInfo() {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
ArithmeticBase
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
ArithmeticBase
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
return
SetCostUnderStrategyBase
(
strategy
);
}
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
ArithmeticBase
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
ArithmeticBase
::
GenerateStrategies
(
int32_t
stage_id
)
{
Shape
input0_split
(
inputs_shape_
[
0
].
size
(),
1
);
Shape
input0_split
(
inputs_shape_
[
0
].
size
(),
1
);
...
...
mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc
浏览文件 @
048b88c4
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
parallel
{
namespace
parallel
{
Status
BatchParallelInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
BatchParallelInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -172,11 +172,7 @@ Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) {
...
@@ -172,11 +172,7 @@ Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) {
}
}
Status
BatchParallelInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
BatchParallelInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
return
SetCostUnderStrategyBase
(
strategy
);
MS_LOG
(
ERROR
)
<<
name_
<<
" : Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
}
Status
BatchParallelInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
BatchParallelInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
...
...
mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc
浏览文件 @
048b88c4
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
parallel
{
namespace
parallel
{
Status
BiasAddInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
BiasAddInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -176,14 +176,7 @@ Status BiasAddInfo::InferTensorInfo() {
...
@@ -176,14 +176,7 @@ Status BiasAddInfo::InferTensorInfo() {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
BiasAddInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
BiasAddInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
return
SetCostUnderStrategyBase
(
strategy
);
}
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
BiasAddInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
BiasAddInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
Shape
input0_split
(
inputs_shape_
[
0
].
size
(),
1
);
Shape
input0_split
(
inputs_shape_
[
0
].
size
(),
1
);
...
...
mindspore/ccsrc/frontend/parallel/ops_info/concat_info.cc
浏览文件 @
048b88c4
...
@@ -60,7 +60,7 @@ Status ConcatInfo::GetAttrs() {
...
@@ -60,7 +60,7 @@ Status ConcatInfo::GetAttrs() {
Status
ConcatInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
ConcatInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
MS_EXCEPTION_IF_NULL
(
strategy
);
MS_EXCEPTION_IF_NULL
(
strategy
);
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy"
;
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy"
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -197,14 +197,7 @@ void ConcatInfo::ReComputeBatchSplitFlagList() {
...
@@ -197,14 +197,7 @@ void ConcatInfo::ReComputeBatchSplitFlagList() {
}
}
}
}
Status
ConcatInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
ConcatInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
return
SetCostUnderStrategyBase
(
strategy
);
}
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
ConcatInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
ConcatInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
if
(
InferAttrs
()
!=
SUCCESS
)
{
if
(
InferAttrs
()
!=
SUCCESS
)
{
...
...
mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc
浏览文件 @
048b88c4
...
@@ -50,11 +50,7 @@ Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) {
...
@@ -50,11 +50,7 @@ Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) {
// only check the input[0]
// only check the input[0]
Shapes
input_shape
=
{
inputs_shape_
[
0
]};
Shapes
input_shape
=
{
inputs_shape_
[
0
]};
if
(
CheckStrategyValue
(
strategy
,
input_shape
,
is_auto_parallel_
)
!=
SUCCESS
)
{
return
CheckStrategyValue
(
strategy
,
input_shape
);
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy"
;
return
FAILED
;
}
return
SUCCESS
;
}
}
Status
DropoutDoMaskInfo
::
InferDevMatrixShape
()
{
Status
DropoutDoMaskInfo
::
InferDevMatrixShape
()
{
...
@@ -125,12 +121,7 @@ Status DropoutDoMaskInfo::InferTensorInfo() {
...
@@ -125,12 +121,7 @@ Status DropoutDoMaskInfo::InferTensorInfo() {
}
}
Status
DropoutDoMaskInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
DropoutDoMaskInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
return
SetCostUnderStrategyBase
(
strategy
);
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed"
;
return
FAILED
;
}
return
SUCCESS
;
}
}
Status
DropoutDoMaskInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
DropoutDoMaskInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
...
...
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc
浏览文件 @
048b88c4
...
@@ -82,7 +82,7 @@ Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) {
...
@@ -82,7 +82,7 @@ Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) {
return
FAILED
;
return
FAILED
;
}
}
// Only strategy of the first input should be set.
// Only strategy of the first input should be set.
if
(
CheckStrategyValue
(
strategy
,
{
inputs_shape_
.
at
(
0
)}
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
{
inputs_shape_
.
at
(
0
)})
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy."
;
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy."
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -301,13 +301,7 @@ Status GatherV2Info::GenerateStrategies(int32_t stage_id) {
...
@@ -301,13 +301,7 @@ Status GatherV2Info::GenerateStrategies(int32_t stage_id) {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
GatherV2Info
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
GatherV2Info
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
return
SetCostUnderStrategyBase
(
strategy
);
}
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
std
::
shared_ptr
<
Strategys
>
GatherV2Info
::
GenerateBatchStrategies
()
{
std
::
shared_ptr
<
Strategys
>
GatherV2Info
::
GenerateBatchStrategies
()
{
if
(
inputs_shape_
.
size
()
!=
GATHER_V2_INPUTS_SIZE
)
{
if
(
inputs_shape_
.
size
()
!=
GATHER_V2_INPUTS_SIZE
)
{
...
...
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
浏览文件 @
048b88c4
...
@@ -213,12 +213,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
...
@@ -213,12 +213,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
}
}
Status
GatherV2PInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
GatherV2PInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
)
!=
SUCCESS
)
{
if
(
is_auto_parallel_
)
{
MS_LOG
(
DEBUG
)
<<
name_
<<
": Invalid strategy."
;
}
else
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy."
;
}
return
FAILED
;
return
FAILED
;
}
}
...
@@ -716,17 +711,7 @@ Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) {
...
@@ -716,17 +711,7 @@ Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
GatherV2PInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
GatherV2PInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
return
SetCostUnderStrategyBase
(
strategy
);
}
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
if
(
is_auto_parallel_
)
{
MS_LOG
(
DEBUG
)
<<
name_
<<
": Set cost under strategy failed."
;
}
else
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
}
return
FAILED
;
}
return
SUCCESS
;
}
Status
GatherV2PInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
GatherV2PInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
if
(
GetAttrs
()
!=
SUCCESS
)
{
if
(
GetAttrs
()
!=
SUCCESS
)
{
...
...
mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc
浏览文件 @
048b88c4
...
@@ -240,13 +240,7 @@ Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) {
...
@@ -240,13 +240,7 @@ Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
GetNextInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
GetNextInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
return
SetCostUnderStrategyBase
(
strategy
);
}
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
GetNextInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
GetNextInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
Strategys
stra
;
Strategys
stra
;
...
...
mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc
浏览文件 @
048b88c4
...
@@ -27,8 +27,7 @@
...
@@ -27,8 +27,7 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
parallel
{
namespace
parallel
{
Status
L2NormalizeInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
L2NormalizeInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
)
!=
SUCCESS
)
{
MS_LOG
(
INFO
)
<<
name_
<<
" : Init success."
;
return
FAILED
;
return
FAILED
;
}
}
...
...
mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc
浏览文件 @
048b88c4
...
@@ -55,7 +55,7 @@ Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) {
...
@@ -55,7 +55,7 @@ Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) {
return
FAILED
;
return
FAILED
;
}
}
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy value"
;
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy value"
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -207,13 +207,7 @@ Status LayerNormInfo::InferAsLossDivisor() {
...
@@ -207,13 +207,7 @@ Status LayerNormInfo::InferAsLossDivisor() {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
LayerNormInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
LayerNormInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
return
SetCostUnderStrategyBase
(
strategy
);
}
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Set cost failed"
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
LayerNormInfo
::
GenerateGammaAndBetaStrategies
(
const
std
::
vector
<
StrategyPtr
>
&
sp_vector
)
{
Status
LayerNormInfo
::
GenerateGammaAndBetaStrategies
(
const
std
::
vector
<
StrategyPtr
>
&
sp_vector
)
{
if
((
gamma_shape_
.
size
()
>
input_shape_
.
size
())
||
(
beta_shape_
.
size
()
>
input_shape_
.
size
()))
{
if
((
gamma_shape_
.
size
()
>
input_shape_
.
size
())
||
(
beta_shape_
.
size
()
>
input_shape_
.
size
()))
{
...
...
mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc
浏览文件 @
048b88c4
...
@@ -28,7 +28,7 @@
...
@@ -28,7 +28,7 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
parallel
{
namespace
parallel
{
Status
SoftmaxCrossEntropyWithLogitsInfo
::
CheckStrategy
(
const
mindspore
::
parallel
::
StrategyPtr
&
strategy
)
{
Status
SoftmaxCrossEntropyWithLogitsInfo
::
CheckStrategy
(
const
mindspore
::
parallel
::
StrategyPtr
&
strategy
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -200,12 +200,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) {
...
@@ -200,12 +200,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) {
}
}
Status
SoftmaxCrossEntropyWithLogitsInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
SoftmaxCrossEntropyWithLogitsInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
PrintStrategy
(
strategy
);
return
SetCostUnderStrategyBase
(
strategy
);
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
}
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
浏览文件 @
048b88c4
...
@@ -150,7 +150,7 @@ Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions
...
@@ -150,7 +150,7 @@ Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions
}
}
Status
MatMul
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
MatMul
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
MS_LOG
(
ERROR
)
<<
name_
<<
" : Invalid strategy."
;
return
FAILED
;
return
FAILED
;
}
}
...
...
mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
浏览文件 @
048b88c4
...
@@ -55,21 +55,7 @@ Status OneHotInfo::GetAttrs() {
...
@@ -55,21 +55,7 @@ Status OneHotInfo::GetAttrs() {
}
}
Status
OneHotInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
OneHotInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
inputs_shape_
.
size
()
!=
3
)
{
return
CheckStrategyValue
(
strategy
,
{
outputs_shape_
.
at
(
0
),
inputs_shape_
.
at
(
1
),
inputs_shape_
.
at
(
2
)});
MS_LOG
(
ERROR
)
<<
name_
<<
": inputs_shape_ size must be 3, but is "
<<
inputs_shape_
.
size
();
return
FAILED
;
}
if
(
outputs_shape_
.
size
()
!=
1
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": outputs_shape_ size must be 1, but is "
<<
outputs_shape_
.
size
();
return
FAILED
;
}
if
(
CheckStrategyValue
(
strategy
,
{
outputs_shape_
.
at
(
0
),
inputs_shape_
.
at
(
1
),
inputs_shape_
.
at
(
2
)},
is_auto_parallel_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy."
;
return
FAILED
;
}
return
SUCCESS
;
}
}
Status
OneHotInfo
::
InferDevMatrixShape
()
{
Status
OneHotInfo
::
InferDevMatrixShape
()
{
...
@@ -278,13 +264,7 @@ Status OneHotInfo::GenerateStrategies(int32_t stage_id) {
...
@@ -278,13 +264,7 @@ Status OneHotInfo::GenerateStrategies(int32_t stage_id) {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
OneHotInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
OneHotInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
return
SetCostUnderStrategyBase
(
strategy
);
}
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
std
::
shared_ptr
<
Strategys
>
OneHotInfo
::
GenerateBatchStrategies
()
{
std
::
shared_ptr
<
Strategys
>
OneHotInfo
::
GenerateBatchStrategies
()
{
CheckGlobalDeviceManager
();
CheckGlobalDeviceManager
();
...
...
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
浏览文件 @
048b88c4
...
@@ -33,19 +33,21 @@
...
@@ -33,19 +33,21 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
parallel
{
namespace
parallel
{
Status
CheckStrategyValue
(
const
StrategyPtr
&
strategy
,
const
Shapes
&
inputs_shape
,
bool
is_auto_parallel
)
{
Status
OperatorInfo
::
CheckStrategyValue
(
const
StrategyPtr
&
strategy
,
const
Shapes
&
inputs_shape
)
{
if
(
strategy
==
nullptr
)
{
if
(
strategy
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"
The strategy is null."
;
MS_LOG
(
ERROR
)
<<
name_
<<
":
The strategy is null."
;
return
FAILED
;
return
FAILED
;
}
}
size_t
strategy_size
=
strategy
->
GetInputNumber
();
size_t
strategy_size
=
strategy
->
GetInputNumber
();
size_t
inputs_shape_size
=
inputs_shape
.
size
();
size_t
inputs_shape_size
=
inputs_shape
.
size
();
if
(
strategy_size
!=
inputs_shape_size
)
{
if
(
strategy_size
!=
inputs_shape_size
)
{
if
(
is_auto_parallel
)
{
if
(
is_auto_parallel_
)
{
MS_LOG
(
DEBUG
)
<<
"Strategy size: "
<<
strategy_size
<<
" is not equal to inputs size: "
<<
inputs_shape_size
;
MS_LOG
(
DEBUG
)
<<
name_
<<
": Strategy size: "
<<
strategy_size
<<
" is not equal to inputs size: "
<<
inputs_shape_size
;
}
else
{
}
else
{
MS_LOG
(
ERROR
)
<<
"Strategy size: "
<<
strategy_size
<<
" is not equal to inputs size: "
<<
inputs_shape_size
;
MS_LOG
(
ERROR
)
<<
name_
<<
": Strategy size: "
<<
strategy_size
<<
" is not equal to inputs size: "
<<
inputs_shape_size
;
}
}
return
FAILED
;
return
FAILED
;
}
}
...
@@ -57,11 +59,11 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap
...
@@ -57,11 +59,11 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap
size_t
strategy_len
=
sub_strategy
.
size
();
size_t
strategy_len
=
sub_strategy
.
size
();
size_t
inputs_len
=
sub_input_shape
.
size
();
size_t
inputs_len
=
sub_input_shape
.
size
();
if
(
strategy_len
!=
inputs_len
)
{
if
(
strategy_len
!=
inputs_len
)
{
if
(
is_auto_parallel
)
{
if
(
is_auto_parallel
_
)
{
MS_LOG
(
DEBUG
)
<<
"
Strategy len: "
<<
strategy_len
<<
" is not equal to inputs len: "
<<
inputs_len
MS_LOG
(
DEBUG
)
<<
name_
<<
":
Strategy len: "
<<
strategy_len
<<
" is not equal to inputs len: "
<<
inputs_len
<<
", index: "
<<
i
;
<<
", index: "
<<
i
;
}
else
{
}
else
{
MS_LOG
(
ERROR
)
<<
"
Strategy len: "
<<
strategy_len
<<
" is not equal to inputs len: "
<<
inputs_len
MS_LOG
(
ERROR
)
<<
name_
<<
":
Strategy len: "
<<
strategy_len
<<
" is not equal to inputs len: "
<<
inputs_len
<<
", index: "
<<
i
;
<<
", index: "
<<
i
;
}
}
return
FAILED
;
return
FAILED
;
...
@@ -70,29 +72,29 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap
...
@@ -70,29 +72,29 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap
for
(
size_t
j
=
0
;
j
<
strategy_len
;
++
j
)
{
for
(
size_t
j
=
0
;
j
<
strategy_len
;
++
j
)
{
int64_t
strategy_value
=
sub_strategy
.
at
(
j
);
int64_t
strategy_value
=
sub_strategy
.
at
(
j
);
if
(
strategy_value
<
MIN_SLICE_NUM
)
{
if
(
strategy_value
<
MIN_SLICE_NUM
)
{
if
(
is_auto_parallel
)
{
if
(
is_auto_parallel
_
)
{
MS_LOG
(
DEBUG
)
<<
"
Invalid strategy value: "
<<
strategy_value
;
MS_LOG
(
DEBUG
)
<<
name_
<<
":
Invalid strategy value: "
<<
strategy_value
;
}
else
{
}
else
{
MS_LOG
(
ERROR
)
<<
"
Invalid strategy value: "
<<
strategy_value
;
MS_LOG
(
ERROR
)
<<
name_
<<
":
Invalid strategy value: "
<<
strategy_value
;
}
}
return
FAILED
;
return
FAILED
;
}
}
if
((
IntToUint
(
strategy_value
)
&
IntToUint
(
strategy_value
-
1
))
!=
0
)
{
if
((
IntToUint
(
strategy_value
)
&
IntToUint
(
strategy_value
-
1
))
!=
0
)
{
if
(
is_auto_parallel
)
{
if
(
is_auto_parallel
_
)
{
MS_LOG
(
DEBUG
)
<<
"
Invalid Strategy value it is not the power of 2, "
<<
strategy_value
;
MS_LOG
(
DEBUG
)
<<
name_
<<
":
Invalid Strategy value it is not the power of 2, "
<<
strategy_value
;
}
else
{
}
else
{
MS_LOG
(
ERROR
)
<<
"
Invalid Strategy value it is not the power of 2, "
<<
strategy_value
;
MS_LOG
(
ERROR
)
<<
name_
<<
":
Invalid Strategy value it is not the power of 2, "
<<
strategy_value
;
}
}
return
FAILED
;
return
FAILED
;
}
}
int64_t
shape_value
=
sub_input_shape
.
at
(
j
);
int64_t
shape_value
=
sub_input_shape
.
at
(
j
);
if
((
shape_value
%
strategy_value
)
!=
0
)
{
if
((
shape_value
%
strategy_value
)
!=
0
)
{
if
(
is_auto_parallel
)
{
if
(
is_auto_parallel
_
)
{
MS_LOG
(
DEBUG
)
<<
"
Shape "
<<
shape_value
<<
" cannot be divisible by strategy "
<<
strategy_value
;
MS_LOG
(
DEBUG
)
<<
name_
<<
":
Shape "
<<
shape_value
<<
" cannot be divisible by strategy "
<<
strategy_value
;
}
else
{
}
else
{
MS_LOG
(
ERROR
)
<<
"
Shape "
<<
shape_value
<<
" cannot be divisible by strategy "
<<
strategy_value
;
MS_LOG
(
ERROR
)
<<
name_
<<
":
Shape "
<<
shape_value
<<
" cannot be divisible by strategy "
<<
strategy_value
;
}
}
return
FAILED
;
return
FAILED
;
}
}
...
...
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
浏览文件 @
048b88c4
...
@@ -176,6 +176,7 @@ class OperatorInfo {
...
@@ -176,6 +176,7 @@ class OperatorInfo {
virtual
Status
GetAttrs
()
=
0
;
virtual
Status
GetAttrs
()
=
0
;
virtual
Status
InferTensorInfo
()
=
0
;
virtual
Status
InferTensorInfo
()
=
0
;
virtual
Status
InferDevMatrixShape
()
=
0
;
virtual
Status
InferDevMatrixShape
()
=
0
;
Status
CheckStrategyValue
(
const
StrategyPtr
&
strategy
,
const
Shapes
&
inputs_shape
);
void
SetDeviceListByStrategy
();
void
SetDeviceListByStrategy
();
void
SetRepeatedCalcDevMatrix
();
void
SetRepeatedCalcDevMatrix
();
Status
CreateGroupByTensorMap
(
const
Shape
&
tensor_map
,
std
::
vector
<
Group
>
*
group
);
Status
CreateGroupByTensorMap
(
const
Shape
&
tensor_map
,
std
::
vector
<
Group
>
*
group
);
...
...
mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc
浏览文件 @
048b88c4
...
@@ -34,7 +34,7 @@ namespace parallel {
...
@@ -34,7 +34,7 @@ namespace parallel {
* the strategy of w should equal to the channel dimension of strategy of A, or equal to 1
* the strategy of w should equal to the channel dimension of strategy of A, or equal to 1
*/
*/
Status
PReLUInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
PReLUInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy."
;
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy."
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -220,12 +220,6 @@ Status PReLUInfo::GenerateStrategies(int32_t stage_id) {
...
@@ -220,12 +220,6 @@ Status PReLUInfo::GenerateStrategies(int32_t stage_id) {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
PReLUInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
PReLUInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
return
SetCostUnderStrategyBase
(
strategy
);
}
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc
浏览文件 @
048b88c4
...
@@ -29,14 +29,7 @@
...
@@ -29,14 +29,7 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
parallel
{
namespace
parallel
{
Status
ReduceMethod
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
ReduceMethod
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
return
CheckStrategyValue
(
strategy
,
inputs_shape_
);
}
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy."
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
ReduceMethod
::
InferDevMatrixShape
()
{
Status
ReduceMethod
::
InferDevMatrixShape
()
{
Strategys
stra
=
strategy_
->
GetInputDim
();
Strategys
stra
=
strategy_
->
GetInputDim
();
...
@@ -354,14 +347,7 @@ Status ReduceMethod::InferTensorInfo() {
...
@@ -354,14 +347,7 @@ Status ReduceMethod::InferTensorInfo() {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
ReduceMethod
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
ReduceMethod
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
return
SetCostUnderStrategyBase
(
strategy
);
}
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
ReduceMethod
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
ReduceMethod
::
GenerateStrategies
(
int32_t
stage_id
)
{
if
((
inputs_shape_
.
size
()
!=
1
)
||
(
outputs_shape_
.
size
()
!=
1
))
{
if
((
inputs_shape_
.
size
()
!=
1
)
||
(
outputs_shape_
.
size
()
!=
1
))
{
...
...
mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
浏览文件 @
048b88c4
...
@@ -29,14 +29,7 @@
...
@@ -29,14 +29,7 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
parallel
{
namespace
parallel
{
Status
ReshapeInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
ReshapeInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
return
CheckStrategyValue
(
strategy
,
inputs_shape_
);
}
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy."
;
return
FAILED
;
}
return
SUCCESS
;
}
/*
/*
* support parallel degree smaller than device number, set the duplicate device dimension to the first dimension of
* support parallel degree smaller than device number, set the duplicate device dimension to the first dimension of
...
@@ -394,12 +387,7 @@ Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) {
...
@@ -394,12 +387,7 @@ Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) {
}
}
Status
ReshapeInfo
::
SetCostUnderStrategy
(
const
mindspore
::
parallel
::
StrategyPtr
&
strategy
)
{
Status
ReshapeInfo
::
SetCostUnderStrategy
(
const
mindspore
::
parallel
::
StrategyPtr
&
strategy
)
{
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
return
SetCostUnderStrategyBase
(
strategy
);
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
}
void
ReshapeInfo
::
SetCostForReshapeWithParameter
()
{
void
ReshapeInfo
::
SetCostForReshapeWithParameter
()
{
...
...
mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc
浏览文件 @
048b88c4
...
@@ -98,7 +98,7 @@ Status StridedSliceInfo::GetAttrs() {
...
@@ -98,7 +98,7 @@ Status StridedSliceInfo::GetAttrs() {
Status
StridedSliceInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
StridedSliceInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
MS_EXCEPTION_IF_NULL
(
strategy
);
MS_EXCEPTION_IF_NULL
(
strategy
);
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy"
;
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy"
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -232,12 +232,7 @@ std::shared_ptr<Strategys> StridedSliceInfo::GenerateBatchStrategies() {
...
@@ -232,12 +232,7 @@ std::shared_ptr<Strategys> StridedSliceInfo::GenerateBatchStrategies() {
}
}
Status
StridedSliceInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
StridedSliceInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
return
SetCostUnderStrategyBase
(
strategy
);
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
}
Status
StridedSliceInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
StridedSliceInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
...
...
mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc
浏览文件 @
048b88c4
...
@@ -67,12 +67,7 @@ Status TileInfo::GetAttrs() {
...
@@ -67,12 +67,7 @@ Status TileInfo::GetAttrs() {
Status
TileInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
TileInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Shapes
multiples
=
{
full_multiples_
};
Shapes
multiples
=
{
full_multiples_
};
if
(
CheckStrategyValue
(
strategy
,
multiples
,
is_auto_parallel_
)
!=
SUCCESS
)
{
return
CheckStrategyValue
(
strategy
,
multiples
);
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy."
;
return
FAILED
;
}
return
SUCCESS
;
}
}
Status
TileInfo
::
InferDevMatrixShape
()
{
Status
TileInfo
::
InferDevMatrixShape
()
{
...
@@ -197,14 +192,7 @@ std::shared_ptr<Strategys> TileInfo::GenerateBatchStrategies() {
...
@@ -197,14 +192,7 @@ std::shared_ptr<Strategys> TileInfo::GenerateBatchStrategies() {
return
GenerateBatchStrategiesBySplitFlag
(
multiples_shape
,
split_flag_list_
);
return
GenerateBatchStrategiesBySplitFlag
(
multiples_shape
,
split_flag_list_
);
}
}
Status
TileInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
TileInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
return
SetCostUnderStrategyBase
(
strategy
);
}
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
TileInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
TileInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
if
(
InferAttrs
()
!=
SUCCESS
)
{
if
(
InferAttrs
()
!=
SUCCESS
)
{
...
...
mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc
浏览文件 @
048b88c4
...
@@ -25,11 +25,7 @@
...
@@ -25,11 +25,7 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
parallel
{
namespace
parallel
{
Status
TmpIdentityInfo
::
CheckStrategy
(
const
mindspore
::
parallel
::
StrategyPtr
&
strategy
)
{
Status
TmpIdentityInfo
::
CheckStrategy
(
const
mindspore
::
parallel
::
StrategyPtr
&
strategy
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
return
CheckStrategyValue
(
strategy
,
inputs_shape_
);
MS_LOG
(
ERROR
)
<<
name_
<<
": invalid strategy."
;
return
FAILED
;
}
return
SUCCESS
;
}
}
Status
TmpIdentityInfo
::
InferDevMatrixShape
()
{
Status
TmpIdentityInfo
::
InferDevMatrixShape
()
{
...
@@ -98,14 +94,7 @@ Status TmpIdentityInfo::InitForCostModel(const StrategyPtr &strategy) {
...
@@ -98,14 +94,7 @@ Status TmpIdentityInfo::InitForCostModel(const StrategyPtr &strategy) {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
TmpIdentityInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
TmpIdentityInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
return
SetCostUnderStrategyBase
(
strategy
);
}
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
TmpIdentityInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
TmpIdentityInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
if
((
inputs_shape_
.
size
()
!=
1
)
||
(
outputs_shape_
.
size
()
!=
1
))
{
if
((
inputs_shape_
.
size
()
!=
1
)
||
(
outputs_shape_
.
size
()
!=
1
))
{
...
...
mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc
浏览文件 @
048b88c4
...
@@ -27,14 +27,7 @@
...
@@ -27,14 +27,7 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
parallel
{
namespace
parallel
{
Status
TransposeInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
TransposeInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
return
CheckStrategyValue
(
strategy
,
inputs_shape_
);
}
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy."
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
TransposeInfo
::
InferDevMatrixShape
()
{
Status
TransposeInfo
::
InferDevMatrixShape
()
{
Strategys
stra
=
strategy_
->
GetInputDim
();
Strategys
stra
=
strategy_
->
GetInputDim
();
...
@@ -195,12 +188,7 @@ Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) {
...
@@ -195,12 +188,7 @@ Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) {
}
}
Status
TransposeInfo
::
SetCostUnderStrategy
(
const
mindspore
::
parallel
::
StrategyPtr
&
strategy
)
{
Status
TransposeInfo
::
SetCostUnderStrategy
(
const
mindspore
::
parallel
::
StrategyPtr
&
strategy
)
{
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
return
SetCostUnderStrategyBase
(
strategy
);
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
}
Status
TransposeInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
TransposeInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
...
...
mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc
浏览文件 @
048b88c4
...
@@ -29,7 +29,7 @@
...
@@ -29,7 +29,7 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
parallel
{
namespace
parallel
{
Status
VirtualDatasetInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
VirtualDatasetInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy."
;
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy."
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -143,12 +143,7 @@ void VirtualDatasetInfo::ReComputeBatchSplitFlagList() {
...
@@ -143,12 +143,7 @@ void VirtualDatasetInfo::ReComputeBatchSplitFlagList() {
}
}
Status
VirtualDatasetInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
VirtualDatasetInfo
::
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
SetCostUnderStrategyBase
(
strategy
)
!=
SUCCESS
)
{
return
SetCostUnderStrategyBase
(
strategy
);
MS_LOG
(
ERROR
)
<<
name_
<<
": Set cost under strategy failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
}
Status
VirtualDatasetInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
VirtualDatasetInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录