Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
390a86ef
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
390a86ef
编写于
5月 22, 2020
作者:
L
lichenever
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix gatherv2
上级
9f079d44
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
126 addition
and
46 deletion
+126
-46
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
+116
-36
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
+1
-1
tests/ut/python/parallel/test_gather_v2.py
tests/ut/python/parallel/test_gather_v2.py
+7
-7
tests/ut/python/parallel/test_gather_v2_primitive.py
tests/ut/python/parallel/test_gather_v2_primitive.py
+2
-2
未找到文件。
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
浏览文件 @
390a86ef
...
@@ -48,7 +48,7 @@ Status GatherV2PInfo::GetAttrs() {
...
@@ -48,7 +48,7 @@ Status GatherV2PInfo::GetAttrs() {
}
}
Status
GatherV2PInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
Status
GatherV2PInfo
::
CheckStrategy
(
const
StrategyPtr
&
strategy
)
{
if
(
CheckStrategyValue
(
strategy
,
{
inputs_shape_
.
at
(
0
)}
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
CheckStrategyValue
(
strategy
,
inputs_shape_
,
is_auto_parallel_
)
!=
SUCCESS
)
{
if
(
is_auto_parallel_
)
{
if
(
is_auto_parallel_
)
{
MS_LOG
(
DEBUG
)
<<
name_
<<
": Invalid strategy."
;
MS_LOG
(
DEBUG
)
<<
name_
<<
": Invalid strategy."
;
}
else
{
}
else
{
...
@@ -84,12 +84,19 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
...
@@ -84,12 +84,19 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
return
FAILED
;
return
FAILED
;
}
}
// Don't support repeated calc
// param_strategy(axis) != 1, index can't be splited
auto
params_strategy
=
strategy
->
GetInputDim
().
at
(
0
);
auto
index_strategy
=
strategy
->
GetInputDim
().
at
(
1
);
auto
product_i
=
std
::
accumulate
(
index_strategy
.
begin
(),
index_strategy
.
end
(),
1
,
std
::
multiplies
<
int
>
());
if
((
param_strategy
.
at
(
IntToSize
(
axis_
))
!=
1
)
&&
(
product_i
!=
1
))
{
MS_LOG
(
ERROR
)
<<
name_
<<
": param is splited at dim (axis)"
<<
axis_
<<
" ,index can't be splited."
;
return
FAILED
;
}
// param_strategy(axis) != 1, Don't support repeated calc
CheckGlobalDeviceManager
();
CheckGlobalDeviceManager
();
size_t
dev_num
=
g_device_manager
->
GetDeviceListByStageId
(
0
).
size
();
size_t
dev_num
=
g_device_manager
->
GetDeviceListByStageId
(
0
).
size
();
auto
product
=
std
::
accumulate
(
params_strategy
.
begin
(),
params
_strategy
.
end
(),
1
,
std
::
multiplies
<
int
>
());
auto
product
_p
=
std
::
accumulate
(
param_strategy
.
begin
(),
param
_strategy
.
end
(),
1
,
std
::
multiplies
<
int
>
());
if
(
dev_num
!=
IntToSize
(
product
)
)
{
if
(
IntToSize
(
product_p
)
!=
dev_num
&&
param_strategy
.
at
(
IntToSize
(
axis_
))
!=
1
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy. Don't support repeated calc."
;
MS_LOG
(
ERROR
)
<<
name_
<<
": Invalid strategy. Don't support repeated calc."
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -97,26 +104,66 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
...
@@ -97,26 +104,66 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
GatherV2PInfo
::
InferMirrorOps
()
{
mirror_ops_
.
clear
();
Shape
input_a_tensor_map
=
inputs_tensor_map_
.
at
(
0
);
std
::
vector
<
Group
>
input_a_group
;
if
(
CreateGroupByTensorMap
(
input_a_tensor_map
,
&
input_a_group
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Create group for input a failed."
;
return
FAILED
;
}
OperatorVector
op_for_input_a
,
op_for_input_b
,
op_for_axis
;
if
(
input_a_group
.
empty
())
{
MS_LOG
(
INFO
)
<<
name_
<<
" : The mirror group is empty."
;
return
SUCCESS
;
}
else
{
op_for_input_a
=
CreateMirrorOps
(
input_a_group
[
0
].
name
(),
input_a_group
[
0
].
GetDevNum
());
MS_LOG
(
INFO
)
<<
name_
<<
" : Create the mirror ops for input a success, group is "
<<
input_a_group
[
0
].
name
();
}
mirror_ops_
.
push_back
(
op_for_input_a
);
mirror_ops_
.
push_back
(
op_for_input_b
);
mirror_ops_
.
push_back
(
op_for_axis
);
return
SUCCESS
;
}
Status
GatherV2PInfo
::
InferDevMatrixShape
()
{
Status
GatherV2PInfo
::
InferDevMatrixShape
()
{
dev_matrix_shape_
.
clear
();
dev_matrix_shape_
.
clear
();
out_dev_matrix_shape_
.
clear
();
out_dev_matrix_shape_
.
clear
();
// infer input dev_matrix_shape
// infer input dev_matrix_shape
auto
params_strategy
=
strategy_
->
GetInputDim
().
at
(
0
);
auto
param_strategy
=
strategy_
->
GetInputDim
().
at
(
0
);
dev_matrix_shape_
=
params_strategy
;
auto
index_strategy
=
strategy_
->
GetInputDim
().
at
(
1
);
dev_matrix_shape_
=
param_strategy
;
// param_strategy(axis)!=1,
if
(
param_strategy
.
at
(
IntToSize
(
axis_
))
!=
1
)
{
std
::
reverse
(
dev_matrix_shape_
.
begin
(),
dev_matrix_shape_
.
end
());
}
else
{
dev_matrix_shape_
.
insert
(
dev_matrix_shape_
.
end
(),
index_strategy
.
begin
(),
index_strategy
.
end
());
}
// infer out dev_matrix_shape
// infer out dev_matrix_shape
// axis!=0, split axis
// axis!=0, split axis
if
(
axis_
!=
0
&&
param
s
_strategy
.
at
(
IntToSize
(
axis_
))
!=
1
)
{
if
(
axis_
!=
0
&&
param_strategy
.
at
(
IntToSize
(
axis_
))
!=
1
)
{
out_dev_matrix_shape_
.
push_back
(
param
s_strategy
.
at
(
0
)
*
params
_strategy
.
at
(
IntToSize
(
axis_
)));
out_dev_matrix_shape_
.
push_back
(
param
_strategy
.
at
(
0
)
*
param
_strategy
.
at
(
IntToSize
(
axis_
)));
for
(
size_t
i
=
1
;
i
<
param
s
_strategy
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
param_strategy
.
size
();
++
i
)
{
if
(
i
==
IntToSize
(
axis_
))
{
if
(
i
==
IntToSize
(
axis_
))
{
out_dev_matrix_shape_
.
push_back
(
1
);
out_dev_matrix_shape_
.
push_back
(
1
);
}
else
{
}
else
{
out_dev_matrix_shape_
.
push_back
(
param
s
_strategy
.
at
(
i
));
out_dev_matrix_shape_
.
push_back
(
param_strategy
.
at
(
i
));
}
}
}
}
}
else
{
}
else
{
out_dev_matrix_shape_
=
params_strategy
;
out_dev_matrix_shape_
=
dev_matrix_shape_
;
}
auto
product_out
=
std
::
accumulate
(
out_dev_matrix_shape_
.
begin
(),
out_dev_matrix_shape_
.
end
(),
1
,
std
::
multiplies
<
int
>
());
CheckGlobalDeviceManager
();
size_t
dev_num
=
g_device_manager
->
GetDeviceListByStageId
(
0
).
size
();
if
(
product_out
==
1
)
{
out_dev_matrix_shape_
.
insert
(
out_dev_matrix_shape_
.
begin
(),
dev_num
);
}
}
return
SUCCESS
;
return
SUCCESS
;
...
@@ -124,28 +171,56 @@ Status GatherV2PInfo::InferDevMatrixShape() {
...
@@ -124,28 +171,56 @@ Status GatherV2PInfo::InferDevMatrixShape() {
Status
GatherV2PInfo
::
InferTensorMap
()
{
Status
GatherV2PInfo
::
InferTensorMap
()
{
// infer input tensor map
// infer input tensor map
// param_strategy(axis) != 1
size_t
param_size
=
inputs_shape_
.
at
(
0
).
size
();
size_t
param_size
=
inputs_shape_
.
at
(
0
).
size
();
size_t
index_size
=
inputs_shape_
.
at
(
1
).
size
();
size_t
index_size
=
inputs_shape_
.
at
(
1
).
size
();
std
::
vector
<
int32_t
>
tensor_map_index
(
index_size
,
-
1
);
size_t
total_size
=
dev_matrix_shape_
.
size
();
std
::
vector
<
int32_t
>
tensor_map_index
;
std
::
vector
<
int32_t
>
tensor_map_params
;
std
::
vector
<
int32_t
>
tensor_map_params
;
for
(
size_t
i
=
0
;
i
<
param_size
;
++
i
)
{
auto
param_strategy
=
strategy_
->
GetInputDim
().
at
(
0
);
tensor_map_params
.
push_back
(
SizeToInt
(
param_size
-
i
-
1
));
if
(
param_strategy
.
at
(
IntToSize
(
axis_
))
!=
1
)
{
tensor_map_index
.
insert
(
tensor_map_index
.
begin
(),
index_size
,
-
1
);
for
(
size_t
i
=
0
;
i
<
param_size
;
++
i
)
{
tensor_map_params
.
push_back
(
SizeToInt
(
i
));
}
}
else
{
// param_strategy(axis) == 1
for
(
size_t
i
=
0
;
i
<
param_size
;
++
i
)
{
tensor_map_params
.
push_back
(
SizeToInt
(
total_size
-
i
-
1
));
}
for
(
size_t
i
=
0
;
i
<
index_size
;
++
i
)
{
tensor_map_index
.
push_back
(
SizeToInt
(
index_size
-
i
-
1
));
}
}
}
// infer output tensor map
// infer output tensor map
std
::
vector
<
int32_t
>
tensor_map_out
;
std
::
vector
<
int32_t
>
tensor_map_out
;
if
(
axis_
==
0
)
{
if
(
param_strategy
.
at
(
IntToSize
(
axis_
))
==
1
)
{
tensor_map_out
.
push_back
(
SizeToInt
(
param_size
-
1
));
// param_strategy(axis) == 1
tensor_map_out
.
insert
(
tensor_map_out
.
end
(),
index_size
-
1
,
-
1
);
for
(
size_t
i
=
1
;
i
<
param_size
;
++
i
)
{
tensor_map_out
.
push_back
(
SizeToInt
(
param_size
-
i
-
1
));
}
}
else
{
for
(
size_t
i
=
0
;
i
<
param_size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
param_size
;
++
i
)
{
if
(
i
==
IntToSize
(
axis_
))
{
if
(
i
==
IntToSize
(
axis_
))
{
tensor_map_out
.
insert
(
tensor_map_out
.
end
(),
index_size
,
-
1
);
for
(
size_t
j
=
0
;
j
<
index_size
;
++
j
)
{
tensor_map_out
.
push_back
(
SizeToInt
(
index_size
-
j
-
1
));
}
}
else
{
}
else
{
tensor_map_out
.
push_back
(
SizeToInt
(
param_size
-
i
-
1
));
tensor_map_out
.
push_back
(
SizeToInt
(
total_size
-
i
-
1
));
}
}
}
else
{
// param_strategy(axis) != 1
if
(
axis_
==
0
)
{
tensor_map_out
.
insert
(
tensor_map_out
.
end
(),
0
);
tensor_map_out
.
insert
(
tensor_map_out
.
end
(),
index_size
-
1
,
-
1
);
for
(
size_t
i
=
1
;
i
<
param_size
;
++
i
)
{
tensor_map_out
.
push_back
(
i
);
}
}
else
{
for
(
size_t
i
=
0
;
i
<
param_size
;
++
i
)
{
if
(
i
==
IntToSize
(
axis_
))
{
tensor_map_out
.
insert
(
tensor_map_out
.
end
(),
index_size
,
-
1
);
}
else
{
tensor_map_out
.
push_back
(
SizeToInt
(
param_size
-
i
-
1
));
}
}
}
}
}
}
}
...
@@ -209,7 +284,12 @@ Status GatherV2PInfo::InferBias() {
...
@@ -209,7 +284,12 @@ Status GatherV2PInfo::InferBias() {
Status
GatherV2PInfo
::
InferGroup
()
{
Status
GatherV2PInfo
::
InferGroup
()
{
std
::
vector
<
Group
>
group_list
;
std
::
vector
<
Group
>
group_list
;
if
(
CreateGroupByDim
(
IntToSize
(
axis_
),
&
group_list
)
!=
SUCCESS
)
{
auto
param_strategy
=
strategy_
->
GetInputDim
().
at
(
0
);
size_t
dim
=
IntToSize
(
axis_
);
if
(
param_strategy
.
at
(
IntToSize
(
axis_
))
!=
1
&&
inputs_shape_
.
at
(
0
).
size
()
==
2
)
{
dim
=
(
axis_
+
1
)
%
2
;
}
if
(
CreateGroupByDim
(
dim
,
&
group_list
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Create group failed."
;
MS_LOG
(
ERROR
)
<<
name_
<<
": Create group failed."
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -231,7 +311,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
...
@@ -231,7 +311,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
auto
sub
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
SUB
),
gen_g
.
virtual_input_node
(),
CreateInt32Tensor
(
bias_
)});
auto
sub
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
SUB
),
gen_g
.
virtual_input_node
(),
CreateInt32Tensor
(
bias_
)});
auto
relu
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
RELU
),
sub
});
auto
relu
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
RELU
),
sub
});
auto
minimum
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
MINIMUM
),
relu
,
CreateInt32Tensor
(
slice_size_
-
1
)});
auto
minimum
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
MINIMUM
),
relu
,
CreateInt32Tensor
(
slice_size_
-
1
)});
auto
equal
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
EQUAL
),
gen_g
.
virtual_input_node
()
,
minimum
});
auto
equal
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
EQUAL
),
sub
,
minimum
});
auto
gather_v2
=
auto
gather_v2
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
GATHERV2
),
gen_g
.
virtual_input_node
(),
minimum
,
CreatInt32Imm
(
axis_
)});
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
GATHERV2
),
gen_g
.
virtual_input_node
(),
minimum
,
CreatInt32Imm
(
axis_
)});
auto
dtype
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
DTYPE
),
gather_v2
});
auto
dtype
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
DTYPE
),
gather_v2
});
...
@@ -250,8 +330,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
...
@@ -250,8 +330,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
Attr
attr_group
=
std
::
make_pair
(
GROUP
,
MakeValue
(
group_
.
name
()));
Attr
attr_group
=
std
::
make_pair
(
GROUP
,
MakeValue
(
group_
.
name
()));
OperatorAttrs
attrs
=
{
attr_op
,
attr_group
};
OperatorAttrs
attrs
=
{
attr_op
,
attr_group
};
auto
reduce_scatter
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
REDUCE_SCATTER
,
attrs
),
mul
});
auto
reduce_scatter
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
REDUCE_SCATTER
,
attrs
),
mul
});
std
::
vector
<
std
::
pair
<
AnfNodePtr
,
int
>>
input_nodes
=
{
std
::
make_pair
(
sub
,
2
),
std
::
make_pair
(
gather_v2
,
1
),
std
::
vector
<
std
::
pair
<
AnfNodePtr
,
int
>>
input_nodes
=
{
std
::
make_pair
(
sub
,
2
),
std
::
make_pair
(
gather_v2
,
1
)};
std
::
make_pair
(
equal
,
2
)};
replace_graph_
=
std
::
make_shared
<
std
::
pair
<
std
::
vector
<
std
::
pair
<
AnfNodePtr
,
int
>>
,
AnfNodePtr
>>
(
replace_graph_
=
std
::
make_shared
<
std
::
pair
<
std
::
vector
<
std
::
pair
<
AnfNodePtr
,
int
>>
,
AnfNodePtr
>>
(
std
::
make_pair
(
input_nodes
,
reduce_scatter
));
std
::
make_pair
(
input_nodes
,
reduce_scatter
));
...
@@ -309,11 +388,11 @@ Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
...
@@ -309,11 +388,11 @@ Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
Status
GatherV2PInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
Status
GatherV2PInfo
::
GenerateStrategies
(
int32_t
stage_id
)
{
is_auto_parallel_
=
true
;
is_auto_parallel_
=
true
;
Shape
input0_split
(
inputs_shape_
[
0
].
size
(),
1
);
Shape
input0_split
(
inputs_shape_
[
0
].
size
(),
1
);
Shapes
splittable_inputs
=
{
input0_split
};
Shape
input1_split
(
inputs_shape_
[
1
].
size
(),
1
);
Shapes
splittable_inputs
=
{
input0_split
,
input1_split
};
std
::
vector
<
StrategyPtr
>
sp_vector
;
std
::
vector
<
StrategyPtr
>
sp_vector
;
if
(
GenerateStrategiesForIndependentInputs
(
stage_id
,
{
inputs_shape_
.
at
(
0
)},
splittable_inputs
,
&
sp_vector
)
!=
if
(
GenerateStrategiesForIndependentInputs
(
stage_id
,
inputs_shape_
,
splittable_inputs
,
&
sp_vector
)
!=
SUCCESS
)
{
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Generate strategies for independent inputs() failed."
;
MS_LOG
(
ERROR
)
<<
name_
<<
" : Generate strategies for independent inputs() failed."
;
return
FAILED
;
return
FAILED
;
}
}
...
@@ -331,12 +410,13 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
...
@@ -331,12 +410,13 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
std
::
shared_ptr
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
GatherV2PInfo
::
GenerateBatchStrategies
()
{
std
::
shared_ptr
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
GatherV2PInfo
::
GenerateBatchStrategies
()
{
CheckGlobalDeviceManager
();
CheckGlobalDeviceManager
();
size_t
dev_num
=
g_device_manager
->
GetDeviceListByStageId
(
0
).
size
();
size_t
dev_num
=
g_device_manager
->
GetDeviceListByStageId
(
0
).
size
();
Dimensions
strategy
;
Dimensions
param_strategy
(
inputs_shape_
[
0
].
size
(),
1
);
strategy
.
push_back
(
SizeToInt
(
dev_num
));
Dimensions
index_strategy
;
for
(
size_t
i
=
1
;
i
<
inputs_shape_
[
0
].
size
();
i
++
)
{
index_strategy
.
push_back
(
SizeToInt
(
dev_num
));
strategy
.
push_back
(
1
);
for
(
size_t
i
=
1
;
i
<
inputs_shape_
[
1
].
size
();
i
++
)
{
index_strategy
.
push_back
(
1
);
}
}
std
::
vector
<
Dimensions
>
strategy_v
=
{
strategy
};
std
::
vector
<
Dimensions
>
strategy_v
=
{
param_strategy
,
index_
strategy
};
return
std
::
make_shared
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
(
strategy_v
);
return
std
::
make_shared
<
std
::
vector
<
std
::
vector
<
int32_t
>>>
(
strategy_v
);
}
}
}
// namespace parallel
}
// namespace parallel
...
...
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
浏览文件 @
390a86ef
...
@@ -48,7 +48,7 @@ class GatherV2PInfo : public OperatorInfo {
...
@@ -48,7 +48,7 @@ class GatherV2PInfo : public OperatorInfo {
protected:
protected:
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
InferMirrorOps
()
override
{
return
SUCCESS
;
}
Status
InferMirrorOps
()
override
;
Status
InferForwardCommunication
()
override
{
return
SUCCESS
;
}
Status
InferForwardCommunication
()
override
{
return
SUCCESS
;
}
Status
InferTensorInfo
()
override
;
Status
InferTensorInfo
()
override
;
Status
InferDevMatrixShape
()
override
;
Status
InferDevMatrixShape
()
override
;
...
...
tests/ut/python/parallel/test_gather_v2.py
浏览文件 @
390a86ef
...
@@ -61,7 +61,7 @@ class Net(nn.Cell):
...
@@ -61,7 +61,7 @@ class Net(nn.Cell):
def
test_gatherv2_semi_auto0
():
def
test_gatherv2_semi_auto0
():
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
strategy1
=
((
1
,
8
),)
strategy1
=
((
1
,
8
),
(
1
,
1
)
)
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
net
=
GradWrap
(
NetWithLoss
(
Net
(
0
,
strategy1
,
strategy2
)))
net
=
GradWrap
(
NetWithLoss
(
Net
(
0
,
strategy1
,
strategy2
)))
net
.
set_auto_parallel
()
net
.
set_auto_parallel
()
...
@@ -73,7 +73,7 @@ def test_gatherv2_semi_auto0():
...
@@ -73,7 +73,7 @@ def test_gatherv2_semi_auto0():
def
test_gatherv2_semi_auto1
():
def
test_gatherv2_semi_auto1
():
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
strategy1
=
((
8
,
1
),)
strategy1
=
((
8
,
1
),
(
1
,
1
)
)
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
net
=
GradWrap
(
NetWithLoss
(
Net
(
0
,
strategy1
,
strategy2
)))
net
=
GradWrap
(
NetWithLoss
(
Net
(
0
,
strategy1
,
strategy2
)))
net
.
set_auto_parallel
()
net
.
set_auto_parallel
()
...
@@ -85,7 +85,7 @@ def test_gatherv2_semi_auto1():
...
@@ -85,7 +85,7 @@ def test_gatherv2_semi_auto1():
def
test_gatherv2_semi_auto2
():
def
test_gatherv2_semi_auto2
():
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
strategy1
=
((
2
,
4
),)
strategy1
=
((
2
,
4
),
(
1
,
1
)
)
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
net
=
GradWrap
(
NetWithLoss
(
Net
(
0
,
strategy1
,
strategy2
)))
net
=
GradWrap
(
NetWithLoss
(
Net
(
0
,
strategy1
,
strategy2
)))
net
.
set_auto_parallel
()
net
.
set_auto_parallel
()
...
@@ -97,7 +97,7 @@ def test_gatherv2_semi_auto2():
...
@@ -97,7 +97,7 @@ def test_gatherv2_semi_auto2():
def
test_gatherv2_semi_auto3
():
def
test_gatherv2_semi_auto3
():
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
strategy1
=
((
1
,
8
),)
strategy1
=
((
1
,
8
),
(
1
,
1
)
)
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
net
=
GradWrap
(
NetWithLoss
(
Net
(
1
,
strategy1
,
strategy2
)))
net
=
GradWrap
(
NetWithLoss
(
Net
(
1
,
strategy1
,
strategy2
)))
net
.
set_auto_parallel
()
net
.
set_auto_parallel
()
...
@@ -109,7 +109,7 @@ def test_gatherv2_semi_auto3():
...
@@ -109,7 +109,7 @@ def test_gatherv2_semi_auto3():
def
test_gatherv2_semi_auto4
():
def
test_gatherv2_semi_auto4
():
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
strategy1
=
((
8
,
1
),)
strategy1
=
((
8
,
1
),
(
1
,
1
)
)
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
net
=
GradWrap
(
NetWithLoss
(
Net
(
1
,
strategy1
,
strategy2
)))
net
=
GradWrap
(
NetWithLoss
(
Net
(
1
,
strategy1
,
strategy2
)))
net
.
set_auto_parallel
()
net
.
set_auto_parallel
()
...
@@ -121,7 +121,7 @@ def test_gatherv2_semi_auto4():
...
@@ -121,7 +121,7 @@ def test_gatherv2_semi_auto4():
def
test_gatherv2_semi_auto5
():
def
test_gatherv2_semi_auto5
():
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
strategy1
=
((
2
,
4
),)
strategy1
=
((
2
,
4
),
(
1
,
1
)
)
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
strategy2
=
((
4
,
2
,
1
),
(
4
,
2
,
1
))
net
=
GradWrap
(
NetWithLoss
(
Net
(
1
,
strategy1
,
strategy2
)))
net
=
GradWrap
(
NetWithLoss
(
Net
(
1
,
strategy1
,
strategy2
)))
net
.
set_auto_parallel
()
net
.
set_auto_parallel
()
...
@@ -155,7 +155,7 @@ def test_gatherv2_semi_auto7():
...
@@ -155,7 +155,7 @@ def test_gatherv2_semi_auto7():
def
test_gatherv2_semi_auto8
():
def
test_gatherv2_semi_auto8
():
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
,
parallel_mode
=
"semi_auto_parallel"
)
strategy1
=
((
8
,),)
strategy1
=
((
8
,),
(
1
,
1
)
)
strategy2
=
((
4
,
2
),
(
4
,
2
))
strategy2
=
((
4
,
2
),
(
4
,
2
))
net
=
GradWrap
(
NetWithLoss
(
Net
(
0
,
strategy1
,
strategy2
)))
net
=
GradWrap
(
NetWithLoss
(
Net
(
0
,
strategy1
,
strategy2
)))
net
.
set_auto_parallel
()
net
.
set_auto_parallel
()
...
...
tests/ut/python/parallel/test_gather_v2_primitive.py
浏览文件 @
390a86ef
...
@@ -221,14 +221,14 @@ def test_axis1_auto_batch_parallel():
...
@@ -221,14 +221,14 @@ def test_axis1_auto_batch_parallel():
def
test_axis1_batch_parallel
():
def
test_axis1_batch_parallel
():
gather_v2_strategy
=
((
device_number
,
1
),)
gather_v2_strategy
=
((
device_number
,
1
),
(
1
,
)
)
criterion
=
GatherV2Axis1
(
1
,
strategy
=
gather_v2_strategy
,
index_size
=
512
)
criterion
=
GatherV2Axis1
(
1
,
strategy
=
gather_v2_strategy
,
index_size
=
512
)
rank
=
2
rank
=
2
net_trains
(
gather_v2_strategy
,
criterion
,
rank
)
net_trains
(
gather_v2_strategy
,
criterion
,
rank
)
def
test_axis1_strategy1
():
def
test_axis1_strategy1
():
gather_v2_strategy
=
((
16
,
2
),)
gather_v2_strategy
=
((
16
,
2
),
(
1
,
)
)
rank
=
17
rank
=
17
criterion
=
GatherV2Axis1
(
1
,
strategy
=
gather_v2_strategy
,
index_size
=
512
)
criterion
=
GatherV2Axis1
(
1
,
strategy
=
gather_v2_strategy
,
index_size
=
512
)
net_trains
(
gather_v2_strategy
,
criterion
,
rank
)
net_trains
(
gather_v2_strategy
,
criterion
,
rank
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录