magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 4a0e6ff7
Authored Aug 12, 2020 by yangzhenzhang

update field split

Parent: a7556d87
Showing 7 changed files with 339 additions and 73 deletions (+339 -73)
mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc    +7   -6
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc       +175 -51
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h        +3   -1
mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h      +5   -0
mindspore/parallel/_tensor.py                                        +5   -2
tests/ut/python/parallel/test_get_parameter_layout.py                +2   -2
tests/ut/python/parallel/test_manual_gatherv2.py                     +142 -11
mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
@@ -44,14 +44,15 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
     auto device_arrangement = tensor_layout->device_arrangement().array();
     auto tensor_map = tensor_layout->tensor_map().array();
     auto slice_shape = tensor_layout->slice_shape().array();
-    int32_t _field_size = tensor_layout->get_field_size();
-    Shape field_size;
-    if (_field_size != 0) {
-      field_size.push_back(_field_size);
+    Shape field_size = {tensor_layout->get_field_size()};
+    Shape uniform_split;
+    if (tensor_layout->uniform_split()) {
+      uniform_split.push_back(1);
     } else {
-      field_size = {0};
+      uniform_split.push_back(0);
     }
-    std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size};
+    std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size, uniform_split};
     dict[py::str(name)] = layout;
     MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
   }
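With this change, the layout list that GetParameterLayout exports for each parameter carries five entries instead of four. A minimal sketch of reading it from Python (illustrative only, not part of this commit; the values mirror test_get_parameter_layout.py below):

# [device_arrangement, tensor_map, slice_shape, field_size, uniform_split]
layout = [[2, 4], [1, -1], [16, 32], [0], [1]]

dev_mat, tensor_map, slice_shape, field_size, uniform_split = layout
# uniform_split[0] is 1 for an evenly split parameter, 0 for a manual field split
print("uniformly split:", bool(uniform_split[0]))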
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
@@ -27,6 +27,92 @@
 namespace mindspore {
 namespace parallel {
+Status GatherV2PInfo::GetManualSplitWithoutOffsetAttr() {
+  auto manual_split_without_offset_iter = attrs_.find("manual_split");
+  if (manual_split_without_offset_iter != attrs_.end()) {
+    manual_split_ = true;
+    MS_EXCEPTION_IF_NULL(manual_split_without_offset_iter->second);
+    if (manual_split_without_offset_iter->second->cast<ValueTuplePtr>() == nullptr) {
+      MS_LOG(ERROR) << name_ << ": Manual split without offset strategy's format is wrong! Need ValueSequeue";
+      return FAILED;
+    }
+    std::vector<ValuePtr> value_vector = manual_split_without_offset_iter->second->cast<ValueTuplePtr>()->value();
+    MS_LOG(INFO) << name_ << ": manual split with offset is " << manual_split_without_offset_iter->second->ToString();
+
+    int64_t offset = 0;
+    for (auto &ele : value_vector) {
+      index_offsets_.push_back(offset);
+      if (!ele->isa<Int32Imm>()) {
+        MS_LOG(ERROR) << name_ << ": The element of manual split must be int";
+        return FAILED;
+      }
+      int64_t param_split_shape = static_cast<int64_t>(GetValue<int>(ele));
+      if (param_split_shape <= 0) {
+        MS_LOG(ERROR) << name_ << ": The value of manual split must be positive, but got " << param_split_shape;
+        return FAILED;
+      }
+      param_split_shapes_.push_back(param_split_shape);
+      offset += param_split_shape;
+    }
+    if (param_split_shapes_.empty()) {
+      MS_LOG(ERROR) << name_ << ": Failed to extract param split's split info";
+      return FAILED;
+    }
+  }
+
+  return SUCCESS;
+}
+
+Status GatherV2PInfo::GetManualSplitAttr() {
+  auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
+  if (manual_split_with_offset_iter != attrs_.end()) {
+    manual_split_ = true;
+    auto var = manual_split_with_offset_iter->second->cast<ValueTuplePtr>();
+    if (var == nullptr) {
+      MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
+      return FAILED;
+    }
+
+    MS_LOG(INFO) << name_ << ": manual split with offset strategy " << var->ToString();
+    for (auto &ele : var->value()) {
+      if (!ele->isa<ValueSequeue>()) {
+        MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
+        return FAILED;
+      }
+      std::vector<ValuePtr> value_vector = ele->cast<ValueTuplePtr>()->value();
+      if (value_vector.size() != 2) {
+        MS_LOG(ERROR) << name_ << ": Size of manual split with offset's element must be 2";
+        return FAILED;
+      }
+      int64_t param_split_row = static_cast<int64_t>(GetValue<int>(value_vector[0]));
+      int64_t offset = static_cast<int64_t>(GetValue<int>(value_vector[1]));
+      if ((param_split_row <= 0) || (offset < 0)) {
+        MS_LOG(ERROR) << name_ << ": The value of param split shape must be positive, and the offset must larger or equal to 0";
+        return FAILED;
+      }
+      param_split_shapes_.push_back(param_split_row);
+      index_offsets_.push_back(offset);
+    }
+
+    if (param_split_shapes_.empty()) {
+      MS_LOG(ERROR) << name_ << ": Failed to extract param split with offset's split info";
+      return FAILED;
+    }
+    if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
+      MS_LOG(ERROR) << name_ << ": Index offset must not less than 0";
+      return FAILED;
+    }
+    return SUCCESS;
+  }
+
+  if (GetManualSplitWithoutOffsetAttr() != SUCCESS) {
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
 Status GatherV2PInfo::GetAttrs() {
   // get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
   if (target_ != CPU) {
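The two attributes parsed above accept different shapes of input: "manual_split" is a flat tuple of positive row counts whose offsets are accumulated implicitly, while "manual_split_with_offset" is a tuple of (row_count, offset) pairs. A pure-Python model of that mapping (an illustrative sketch, not MindSpore code; the function names are hypothetical):

def parse_manual_split(split_tuple):
    # Row counts only; each segment's index offset is the running sum.
    shapes, offsets, offset = [], [], 0
    for rows in split_tuple:
        assert rows > 0, "The value of manual split must be positive"
        offsets.append(offset)
        shapes.append(rows)
        offset += rows
    return shapes, offsets

def parse_manual_split_with_offset(split_tuple):
    # Explicit (row_count, offset) pairs.
    shapes, offsets = [], []
    for rows, offset in split_tuple:
        assert rows > 0 and offset >= 0
        shapes.append(rows)
        offsets.append(offset)
    return shapes, offsets

# Both calls yield param_split_shapes [4, 4] and index_offsets [0, 4]:
print(parse_manual_split((4, 4)))
print(parse_manual_split_with_offset(((4, 0), (4, 4))))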
@@ -53,58 +139,76 @@ Status GatherV2PInfo::GetAttrs() {
     if (target_iter->second->isa<StringImm>()) {
       target_ = target_iter->second->cast<StringImmPtr>()->value();
     } else {
-      MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
+      MS_LOG(ERROR) << name_ << ": The value of target is not a string.";
     }
   }
-  auto manual_split_iter = attrs_.find("manual_split");
-  if (manual_split_iter != attrs_.end()) {
-    param_split_shapes_.clear();
-    manual_split_ = true;
-    auto var = manual_split_iter->second->cast<ValueTuplePtr>();
-    MS_LOG(DEBUG) << "Extract manual split strategy " << manual_split_iter->second->ToString();
-    if (var->size() > 0) {
-      std::vector<ValuePtr> elements = var->value();
-      for (auto &ele : elements) {
-        if (ele->isa<ValueSequeue>()) {
-          auto value_tuple = ele->cast<ValueTuplePtr>();
-          std::vector<ValuePtr> value_vector = value_tuple->value();
-          if (value_vector.size() != 2) {
-            MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2.";
-            return FAILED;
-          }
-          param_split_shapes_.push_back(static_cast<int64_t>(GetValue<int>(value_vector[0])));
-          index_offsets_.push_back(static_cast<int64_t>(GetValue<int>(value_vector[1])));
-        } else {
-          MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue";
-          return FAILED;
-        }
-      }
-      if (param_split_shapes_.empty()) {
-        MS_LOG(ERROR) << "Failed to extract param split strategy.";
-        return FAILED;
-      }
-    }
-  }
+  if (GetManualSplitAttr() != SUCCESS) {
+    return FAILED;
+  }
+
+  if (manual_split_ && (axis_ != 0)) {
+    MS_LOG(ERROR) << name_ << ": The axis or offset must be 0 if manual split, bug got " << axis_;
+    return FAILED;
+  }

   return SUCCESS;
 }

-Status GatherV2PInfo::CheckManualSplit() {
-  auto param_shape = inputs_shape_.at(0);
-  int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
-                                            [](int64_t s, int64_t shape) { return s + shape; });
-  if (split_shape_sum < param_shape.at(0)) {
-    MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape.";
-    return FAILED;
-  }
-
-  if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
-    MS_LOG(ERROR) << "Failure: Index offset must not less than 0.";
+Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
+  if (strategy.size() != 2) {
+    MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
+    return FAILED;
+  }
+  Dimensions param_strategy = strategy[0];
+  Dimensions indices_strategy = strategy[1];
+  if (param_strategy.size() != 2 || indices_strategy.size() != 2) {
+    MS_LOG(ERROR) << name_ << ": The size of param strategy or indices strategy must be 2";
+    return FAILED;
+  }
+
+  if (indices_strategy[0] != 1) {
+    MS_LOG(ERROR) << name_ << ": The indices_strategy[0] must be 1, bug got " << indices_strategy[0];
+    return FAILED;
+  }
+
+  if (param_strategy[0] != indices_strategy[1]) {
+    MS_LOG(ERROR) << name_ << ": The param_strategy[0] must be equal to indices_strategy[1]";
+    return FAILED;
+  }
+
+  if (indices_strategy[1] != SizeToInt(param_split_shapes_.size())) {
+    MS_LOG(ERROR) << name_ << ": The indices_strategy[1] must be equal to manual split size";
+    return FAILED;
+  }
+
+  int64_t min_param_slice_row = inputs_shape_[1][1] / indices_strategy[1];
+  bool invalid = std::any_of(param_split_shapes_.begin(), param_split_shapes_.end(),
+                             [&min_param_slice_row](int64_t v) { return v < min_param_slice_row; });
+  if (invalid) {
+    MS_LOG(ERROR) << name_ << ": The split value must be larger than or equal to indices slice's column num";
+    return FAILED;
+  }
+
+  if (inputs_shape_[0][0] < inputs_shape_[1][1]) {
+    MS_LOG(ERROR) << name_ << ": The param's row smaller than indices' column";
+    return FAILED;
+  }
+
+  // Don't support repeated calc
+  CheckGlobalDeviceManager();
+  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
+  auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
+  if (IntToSize(product_p) < dev_num) {
+    MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
+    return FAILED;
+  }
+
+  int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
+                                            [](int64_t s, int64_t shape) { return s + shape; });
+  if (split_shape_sum != inputs_shape_[0][0]) {
+    MS_LOG(ERROR) << name_ << ": Sum of splited shapes must be equal to param_shape[0]";
     return FAILED;
   }

   return SUCCESS;
 }
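To make the constraints in the new CheckManualSplit concrete, here is the arithmetic for the configuration used by test_normal_split below (a plain-Python sketch; all values come from that test, none are part of this commit):

param_shape, indices_shape = (8, 8), (8, 8)
param_strategy, indices_strategy = (2, 1), (1, 2)
split_shapes, dev_num = [4, 4], 2

assert indices_strategy[0] == 1
assert param_strategy[0] == indices_strategy[1]
assert indices_strategy[1] == len(split_shapes)
min_param_slice_row = indices_shape[1] // indices_strategy[1]  # 8 // 2 = 4
assert all(v >= min_param_slice_row for v in split_shapes)
assert param_shape[0] >= indices_shape[1]
assert param_strategy[0] * param_strategy[1] >= dev_num  # no repeated calculation
assert sum(split_shapes) == param_shape[0]  # 4 + 4 == 8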
@@ -147,7 +251,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   }

   if (manual_split_) {
-    if (CheckManualSplit() != SUCCESS) {
+    if (CheckManualSplit(strategy->GetInputDim()) != SUCCESS) {
       return FAILED;
     }
     // when using manual_split, no need to check belowings.
@@ -343,14 +447,15 @@ Status GatherV2PInfo::InferTensorInfo() {
       SUCCESS)) {
     return FAILED;
   }

+  if (manual_split_) {
+    input_tensor_layout.set_uniform_split(false);
+  }
+
   // infer tensor info
   TensorInfo input_tensor_info(input_tensor_layout);
   TensorInfo input_index_info(input_index_layout);
   TensorInfo output_tensor_info(output_tensor_layout);

-  Shape slice_shape = input_tensor_info.slice_shape();
-  MS_LOG(DEBUG) << "The fake slice shape is: " << ShapeToString(slice_shape);
   inputs_tensor_info_.push_back(input_tensor_info);
   inputs_tensor_info_.push_back(input_index_info);
   outputs_tensor_info_.push_back(output_tensor_info);
@@ -392,9 +497,17 @@ Status GatherV2PInfo::InferBias() {
 Status GatherV2PInfo::InferOffset() {
   CheckGlobalDeviceManager();
   size_t rank = g_device_manager->global_rank();
-  if (rank < index_offsets_.size()) {
-    index_offset_ = index_offsets_.at(rank);
-    MS_LOG(DEBUG) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
+
+  MS_EXCEPTION_IF_NULL(strategy_);
+  auto param_strategy = strategy_->GetInputDim()[0];
+  if (param_strategy.size() != 2) {
+    MS_LOG(ERROR) << "The size of param strategy must be 2";
+    return FAILED;
+  }
+
+  size_t index = rank / param_strategy[1];
+  if (index < index_offsets_.size()) {
+    index_offset_ = index_offsets_[index];
+    MS_LOG(INFO) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
     return SUCCESS;
   }
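The reworked InferOffset no longer indexes index_offsets_ by the raw device rank; it first maps the rank to a split segment through the second dimension of the parameter strategy. The arithmetic, using the test_normal_split3 configuration below (32 devices, param strategy (4, 8), global_rank 17; a sketch, not the C++ itself):

index_offsets = [0, 10, 30, 60]     # accumulated from split_tuple (10, 20, 30, 4)
param_strategy, rank = (4, 8), 17
index = rank // param_strategy[1]   # 17 // 8 = 2
assert index < len(index_offsets)
index_offset = index_offsets[index]  # 30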
@@ -524,8 +637,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
 ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
   if (manual_split_ && target_ != CPU) {
     if (ComputeReplaceGraph(cnode) != SUCCESS) {
-      MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
-      return nullptr;
+      MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
     }
     return replace_graph_;
   }
@@ -536,8 +648,7 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
     return nullptr;
   }
   if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
-    return nullptr;
+    MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
   }
   return replace_graph_;
 }
@@ -614,6 +725,13 @@ Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
 }

 Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
+  if (GetAttrs() != SUCCESS) {
+    return FAILED;
+  }
+  if (manual_split_) {
+    MS_LOG(ERROR) << name_ << ": Manual split does not support to search strategy";
+    return FAILED;
+  }
   is_auto_parallel_ = true;
   Shape input0_split(inputs_shape_[0].size(), 1);
   Shape input1_split(inputs_shape_[1].size(), 1);
@@ -621,14 +739,14 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
   std::vector<StrategyPtr> sp_vector;
   if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed.";
+    MS_LOG(ERROR) << name_ << ": Generate strategies for independent inputs() failed.";
     return FAILED;
   }

   size_t success = 0;
   for (auto &sp : sp_vector) {
     if (SetCostUnderStrategy(sp) == SUCCESS) {
       success++;
-      MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy";
+      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy";
       PrintStrategy(sp);
     }
   }
@@ -636,6 +754,12 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
 }

 std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
+  if (GetAttrs() != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
+  }
+  if (manual_split_) {
+    MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to generate batch strategy";
+  }
   CheckGlobalDeviceManager();
   size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions param_strategy(inputs_shape_[0].size(), 1);
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h
@@ -59,7 +59,9 @@ class GatherV2PInfo : public OperatorInfo {
   Status GetAttrs() override;

   Status ComputeReplaceGraph(const CNodePtr &cnode);
-  Status CheckManualSplit();
+  Status CheckManualSplit(const Strategys &strategy);
+  Status GetManualSplitAttr();
+  Status GetManualSplitWithoutOffsetAttr();
   Status ComputeReplaceOp();
   Status InferBias();
   Status InferOffset();
mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h
@@ -48,6 +48,10 @@ class TensorLayout {
   void set_field_size(int32_t field_size) { field_size_ = field_size; }

+  bool uniform_split() const { return uniform_split_; }
+
+  void set_uniform_split(bool flag) { uniform_split_ = flag; }
+
   Arrangement device_arrangement() const { return device_arrangement_; }

   Map tensor_map() const { return tensor_map_; }
@@ -104,6 +108,7 @@ class TensorLayout {
   Arrangement tensor_shape_;
   bool skip_redistribution_ = false;
   int32_t field_size_ = 0;
+  bool uniform_split_ = true;
 };
 }  // namespace parallel
 }  // namespace mindspore
mindspore/parallel/_tensor.py
@@ -229,10 +229,13 @@ def _load_tensor_by_layout(tensor, layout):
     """
     if not isinstance(layout, list):
         raise TypeError("The layout should be list! layout is {}".format(layout))
-    if len(layout) < 3:
-        raise ValueError("The length of layout must be larger than 3! layout is {}".format(layout))
+    if len(layout) < 5:
+        raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
     dev_mat = layout[0]
     tensor_map = layout[1]
+    uniform_split = layout[4]
+    if uniform_split[0] == 0:
+        raise RuntimeError("The load tensor only support uniform split now")
     if tensor.size() == 1:
         return tensor
     return _load_tensor(tensor, dev_mat, tensor_map)
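In short, _load_tensor_by_layout now requires the five-entry layout and rejects non-uniform splits at load time. A brief sketch of the guard as seen by a caller (illustrative values):

layout = [[2, 4], [1, -1], [16, 32], [0], [1]]  # fifth entry: uniform_split flag
assert len(layout) >= 5
if layout[4][0] == 0:
    # manual (non-uniform) field splits cannot be loaded this way yet
    raise RuntimeError("The load tensor only support uniform split now")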
tests/ut/python/parallel/test_get_parameter_layout.py
@@ -49,8 +49,8 @@ def test_get_parameter_layout():
     net.set_auto_parallel()
     exe = me._executor
     exe.compile(net, x, phase='train', auto_parallel_mode=True)
-    x_layout = [[2, 4], [1, -1], [16, 32], [0]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
-    weight_layout = [[2, 4], [0, -1], [16, 32], [0]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
+    x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
+    weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
     expect_dict = {'x': x_layout, 'w1': weight_layout}
     # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
     assert net.parameter_layout_dict == expect_dict
tests/ut/python/parallel/test_manual_gatherv2.py
@@ -14,6 +14,7 @@
 # ============================================================================
 import numpy as np
+import pytest
 import mindspore as ms
 from mindspore import context, Tensor, Parameter
 from mindspore.common.api import _executor
@@ -22,40 +23,170 @@ from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer


 class Net(Cell):
-    def __init__(self, strategy1=None, strategy2=None, strategy3=None):
+    def __init__(self, strategy1=None, strategy2=None, strategy3=None, axis=0, init_flag=True, split_tuple=(4, 4),
+                 split_string="manual_split", param_shape=(8, 8)):
         super().__init__()
         self.gatherv2 = P.GatherV2().set_strategy(strategy1)
-        self.gatherv2.add_prim_attr("manual_split", ((1, 0), (7, 1)))
+        self.gatherv2.add_prim_attr(split_string, split_tuple)
         self.mul = P.Mul().set_strategy(strategy2)
         self.reshape = P.Reshape()
         self.matmul = P.MatMul().set_strategy(strategy3)
         self.matmul.add_prim_attr("forward_reduce_scatter", True)
-        self.param = Parameter(initializer("ones", (8, 64), ms.float32), name="gatherv2_param")
-        self.mul_weight = Parameter(initializer("ones", (2, 4, 64), ms.float32), name="mul_weight")
-        self.matmul_weight = Parameter(initializer("ones", (256, 16), ms.float32), name="matmul_weight")
+        if init_flag:
+            self.param = Parameter(initializer("ones", param_shape, ms.float32), name="gatherv2_param")
+        else:
+            self.param = Parameter(Tensor(np.ones(param_shape), dtype=ms.float32), name="gatherv2_param")
+        self.mul_weight = Parameter(initializer("ones", (8, 8, 8), ms.float32), name="mul_weight")
+        self.matmul_weight = Parameter(initializer("ones", (64, 16), ms.float32), name="matmul_weight")
+        self.axis = axis

     def construct(self, x, b):
-        out = self.gatherv2(self.param, x, 0)
+        out = self.gatherv2(self.param, x, self.axis)
         out = self.mul(out, self.mul_weight)
-        out = self.reshape(out, (2, 256))
+        out = self.reshape(out, (8, 64))
         out = self.matmul(out, self.matmul_weight)
         return out


-_x = Tensor(np.ones([2, 4]), dtype=ms.int32)
+_x = Tensor(np.ones([8, 8]), dtype=ms.int32)
 _b = Tensor(np.ones([64, 8]), dtype=ms.float32)


 def compile_net(net):
     context.set_context(save_graphs=True)
     optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
     train_net = TrainOneStepCell(net, optimizer)
     train_net.set_auto_parallel()
-    _executor.compile(train_net, _x, _b)
+    _executor.compile(train_net, _x, _b, auto_parallel_mode=True)
     context.reset_auto_parallel_context()


-def test_neg_data_parallel():
-    context.set_context(save_graphs=True)
+def test_normal_split():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
     strategy1 = ((2, 1), (1, 2))
     strategy2 = ((1, 2, 1), (1, 2, 1))
     strategy3 = ((1, 2), (2, 1))
     net = Net(strategy1, strategy2, strategy3)
     compile_net(net)
+
+
+def test_normal_split2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
+    strategy1 = ((4, 1), (1, 4))
+    strategy2 = ((1, 4, 1), (1, 4, 1))
+    strategy3 = ((1, 4), (4, 1))
+    net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
+    compile_net(net)
+
+
+def test_normal_split3():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17)
+    strategy1 = ((4, 8), (1, 4))
+    strategy2 = ((1, 4, 8), (1, 4, 8))
+    strategy3 = ((1, 32), (32, 1))
+    net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
+    compile_net(net)
+
+
+def test_normal_split_with_offset():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, split_string="manual_split_with_offset", split_tuple=((4, 0), (4, 4)))
+    compile_net(net)
+
+
+def test_auto_parallel_error():
+    context.set_context(save_graphs=True)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0)
+    net = Net()
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_axis_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, axis=1)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((4, 1), (8, 1))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((4, 1), (1, 8))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error3():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error4():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 8), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error5():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
+    strategy1 = ((4, 1), (1, 4))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_split_tuple_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, split_tuple=((5, 0), (5, 5)))
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_parameter_use_tensor_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, init_flag=False)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
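Distilled from the tests above, enabling a manual field split on GatherV2 takes one extra primitive attribute. A usage sketch against the 0.x-era API used in this file (shapes and strategies are taken from test_normal_split and test_normal_split_with_offset, not prescribed by this commit):

from mindspore import context
from mindspore.ops import operations as P

context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
gatherv2 = P.GatherV2().set_strategy(((2, 1), (1, 2)))
# Row counts only; offsets are accumulated implicitly:
gatherv2.add_prim_attr("manual_split", (4, 4))
# Or, with explicit (row_count, offset) pairs per segment:
# gatherv2.add_prim_attr("manual_split_with_offset", ((4, 0), (4, 4)))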