Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
faa1084b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
faa1084b
编写于
7月 07, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 07, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2865 asymmetric param row split support for GatherV2
Merge pull request !2865 from yihuaijie/dev
上级
6798c548
cae254f4
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
186 addition
and
0 deletion
+186
-0
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
+118
-0
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
+7
-0
tests/ut/python/parallel/test_manual_gatherv2.py
tests/ut/python/parallel/test_manual_gatherv2.py
+61
-0
未找到文件。
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
浏览文件 @
faa1084b
...
...
@@ -20,6 +20,7 @@
#include <numeric>
#include <functional>
#include <utility>
#include <algorithm>
#include "parallel/device_matrix.h"
#include "parallel/graph_util/generate_graph.h"
...
...
@@ -62,6 +63,55 @@ Status GatherV2PInfo::GetAttrs() {
return
FAILED
;
}
auto
manual_split_iter
=
attrs_
.
find
(
"manual_split"
);
if
(
manual_split_iter
!=
attrs_
.
end
())
{
param_split_shapes_
.
clear
();
manual_split_
=
true
;
auto
var
=
manual_split_iter
->
second
->
cast
<
ValueTuplePtr
>
();
MS_LOG
(
DEBUG
)
<<
"Extract manual split strategy "
<<
manual_split_iter
->
second
->
ToString
();
if
(
var
->
size
()
>
0
)
{
std
::
vector
<
ValuePtr
>
elements
=
var
->
value
();
for
(
auto
&
ele
:
elements
)
{
if
(
ele
->
isa
<
ValueSequeue
>
())
{
auto
value_tuple
=
ele
->
cast
<
ValueTuplePtr
>
();
std
::
vector
<
ValuePtr
>
value_vector
=
value_tuple
->
value
();
if
(
value_vector
.
size
()
!=
2
)
{
MS_LOG
(
ERROR
)
<<
"Failure: Size of manual_split element must be 2."
;
return
FAILED
;
}
param_split_shapes_
.
push_back
(
static_cast
<
int32_t
>
(
GetValue
<
int
>
(
value_vector
[
0
])));
index_offsets_
.
push_back
(
static_cast
<
int32_t
>
(
GetValue
<
int
>
(
value_vector
[
1
])));
}
else
{
MS_LOG
(
ERROR
)
<<
"Failure: Manual split strategy's format is wrong! Need ValueSequeue"
;
return
FAILED
;
}
}
if
(
param_split_shapes_
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"Failed to extract param split strategy."
;
return
FAILED
;
}
}
}
return
SUCCESS
;
}
Status
GatherV2PInfo
::
CheckManualSplit
()
{
auto
param_shape
=
inputs_shape_
.
at
(
0
);
int32_t
split_shape_sum
=
std
::
accumulate
(
param_split_shapes_
.
begin
(),
param_split_shapes_
.
end
(),
0
,
[](
int32_t
s
,
int32_t
shape
)
{
return
s
+
shape
;
});
if
(
split_shape_sum
<
param_shape
.
at
(
0
))
{
MS_LOG
(
ERROR
)
<<
"Failure: Sum of splited shapes should not be smaller than param_shape."
;
return
FAILED
;
}
if
(
std
::
any_of
(
index_offsets_
.
begin
(),
index_offsets_
.
end
(),
[](
const
int32_t
&
offset
)
{
return
offset
<
0
;
}))
{
MS_LOG
(
ERROR
)
<<
"Failure: Index offset must not less than 0."
;
return
FAILED
;
}
return
SUCCESS
;
}
...
...
@@ -103,6 +153,14 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
return
FAILED
;
}
if
(
manual_split_
)
{
if
(
CheckManualSplit
()
!=
SUCCESS
)
{
return
FAILED
;
}
// when using manual_split, no need to check belowings.
return
SUCCESS
;
}
// axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
if
(
axis_
!=
0
&&
param_shape
.
at
(
0
)
%
(
param_strategy
.
at
(
0
)
*
param_strategy
.
at
(
IntToSize
(
axis_
)))
!=
0
)
{
MS_LOG
(
DEBUG
)
<<
name_
<<
": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis))."
;
...
...
@@ -130,6 +188,11 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
}
Status
GatherV2PInfo
::
InferMirrorOps
()
{
// There is no mirror operators for manual split
if
(
manual_split_
)
{
return
SUCCESS
;
}
mirror_ops_
.
clear
();
Shape
input_a_tensor_map
=
inputs_tensor_map_
.
at
(
0
);
std
::
vector
<
Group
>
input_a_group
;
...
...
@@ -160,6 +223,13 @@ Status GatherV2PInfo::InferDevMatrixShape() {
// infer input dev_matrix_shape
auto
param_strategy
=
strategy_
->
GetInputDim
().
at
(
0
);
auto
index_strategy
=
strategy_
->
GetInputDim
().
at
(
1
);
if
(
manual_split_
)
{
dev_matrix_shape_
=
param_strategy
;
out_dev_matrix_shape_
=
dev_matrix_shape_
;
return
SUCCESS
;
}
dev_matrix_shape_
=
param_strategy
;
// param_strategy(axis)!=1,
...
...
@@ -195,6 +265,12 @@ Status GatherV2PInfo::InferDevMatrixShape() {
}
Status
GatherV2PInfo
::
InferTensorMap
()
{
if
(
manual_split_
)
{
inputs_tensor_map_
.
push_back
({
1
,
0
});
inputs_tensor_map_
.
push_back
({
-
1
,
1
});
outputs_tensor_map_
.
push_back
({
-
1
,
1
,
0
});
return
SUCCESS
;
}
// infer input tensor map
// param_strategy(axis) != 1
size_t
param_size
=
inputs_shape_
.
at
(
0
).
size
();
...
...
@@ -261,8 +337,13 @@ Status GatherV2PInfo::InferTensorInfo() {
Shape
input_shape
=
inputs_shape_
.
at
(
0
);
Shape
input_index_shape
=
inputs_shape_
.
at
(
1
);
Shape
output_shape
=
outputs_shape_
.
at
(
0
);
int32_t
rank
=
g_device_manager
->
global_rank
();
// infer tensor layout
TensorLayout
input_tensor_layout
,
input_index_layout
,
output_tensor_layout
;
if
(
manual_split_
)
{
input_shape
[
0
]
=
param_split_shapes_
[
rank
/
dev_matrix_shape_
[
1
]];
input_shape
[
0
]
=
input_shape
[
0
]
*
dev_matrix_shape_
[
0
];
}
if
((
input_tensor_layout
.
InitFromVector
(
dev_matrix_shape_
,
inputs_tensor_map_
.
at
(
0
),
input_shape
)
!=
SUCCESS
)
||
(
input_index_layout
.
InitFromVector
(
dev_matrix_shape_
,
inputs_tensor_map_
.
at
(
1
),
input_index_shape
)
!=
SUCCESS
)
||
(
output_tensor_layout
.
InitFromVector
(
out_dev_matrix_shape_
,
outputs_tensor_map_
.
at
(
0
),
output_shape
)
!=
...
...
@@ -274,6 +355,9 @@ Status GatherV2PInfo::InferTensorInfo() {
TensorInfo
input_index_info
(
input_index_layout
);
TensorInfo
output_tensor_info
(
output_tensor_layout
);
Shape
slice_shape
=
input_tensor_info
.
slice_shape
();
MS_LOG
(
DEBUG
)
<<
"The fake slice shape is: "
<<
ShapeToString
(
slice_shape
);
inputs_tensor_info_
.
push_back
(
input_tensor_info
);
inputs_tensor_info_
.
push_back
(
input_index_info
);
outputs_tensor_info_
.
push_back
(
output_tensor_info
);
...
...
@@ -312,6 +396,19 @@ Status GatherV2PInfo::InferBias() {
return
FAILED
;
}
Status
GatherV2PInfo
::
InferOffset
()
{
CheckGlobalDeviceManager
();
size_t
rank
=
g_device_manager
->
global_rank
();
if
(
rank
<
index_offsets_
.
size
())
{
index_offset_
=
index_offsets_
.
at
(
rank
);
MS_LOG
(
DEBUG
)
<<
name_
<<
": Device rank "
<<
rank
<<
", Index Offset: "
<<
index_offset_
;
return
SUCCESS
;
}
MS_LOG
(
ERROR
)
<<
name_
<<
": Get index offset failed, index offset size is"
<<
index_offsets_
.
size
();
return
FAILED
;
}
Status
GatherV2PInfo
::
InferGroup
()
{
auto
param_strategy
=
strategy_
->
GetInputDim
().
at
(
0
);
size_t
dim
=
IntToSize
(
axis_
);
...
...
@@ -410,6 +507,19 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
MS_LOG
(
ERROR
)
<<
"GenerateGraph Init failed"
;
return
FAILED
;
}
if
(
manual_split_
)
{
if
(
InferOffset
()
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Infer Bias failed."
;
return
FAILED
;
}
auto
sub
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
SUB
),
gen_g
.
virtual_input_node
(),
CreateInt32Tensor
(
index_offset_
)});
auto
gather_v2
=
gen_g
.
PushBack
({
gen_g
.
NewOpInst
(
replace_op_name_
),
gen_g
.
virtual_input_node
(),
sub
,
CreatInt32Imm
(
axis_
)});
std
::
vector
<
std
::
pair
<
AnfNodePtr
,
int
>>
input_nodes
=
{
std
::
make_pair
(
sub
,
2
),
std
::
make_pair
(
gather_v2
,
1
)};
replace_graph_
=
std
::
make_shared
<
std
::
pair
<
std
::
vector
<
std
::
pair
<
AnfNodePtr
,
int
>>
,
AnfNodePtr
>>
(
std
::
make_pair
(
input_nodes
,
gather_v2
));
return
SUCCESS
;
}
if
(
InferBias
()
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": Infer Bias failed."
;
return
FAILED
;
...
...
@@ -444,6 +554,14 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
}
ReplaceGraphPtr
GatherV2PInfo
::
replace_graph
(
const
CNodePtr
&
cnode
)
{
if
(
manual_split_
)
{
if
(
ComputeReplaceGraph
(
cnode
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
": ComputeReplaceGraph failed."
;
return
nullptr
;
}
return
replace_graph_
;
}
auto
param_strategy
=
strategy_
->
GetInputDim
().
at
(
0
);
// target_ == CPU, no need to raplace graph
if
(
target_
==
CPU
)
{
...
...
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
浏览文件 @
faa1084b
...
...
@@ -36,6 +36,7 @@ class GatherV2PInfo : public OperatorInfo {
:
OperatorInfo
(
name
,
inputs_shape
,
outputs_shape
,
attrs
,
std
::
make_shared
<
GatherV2PCost
>
()),
axis_
(
0
),
bias_
(
0
),
index_offset_
(
0
),
slice_size_
(
0
)
{}
~
GatherV2PInfo
()
override
=
default
;
Status
Init
(
const
StrategyPtr
&
strategy
)
override
;
...
...
@@ -57,20 +58,26 @@ class GatherV2PInfo : public OperatorInfo {
private:
Status
ComputeReplaceGraph
(
const
CNodePtr
&
cnode
);
Status
CheckManualSplit
();
Status
ComputeReplaceOp
();
Status
InferBias
();
Status
InferOffset
();
Status
InferGroup
();
int32_t
axis_
;
std
::
string
target_
;
std
::
string
replace_op_name_
=
GATHERV2
;
int32_t
bias_
;
int32_t
index_offset_
;
int32_t
slice_size_
;
Shape
out_dev_matrix_shape_
;
Group
group_
;
bool
reduce_scatter_flag_
=
false
;
int32_t
split_num_
=
1
;
bool
host_reduce_scatter_
=
false
;
bool
manual_split_
=
false
;
std
::
vector
<
int32_t
>
param_split_shapes_
;
std
::
vector
<
int32_t
>
index_offsets_
;
};
class
SparseGatherV2Info
:
public
GatherV2PInfo
{
...
...
tests/ut/python/parallel/test_manual_gatherv2.py
0 → 100644
浏览文件 @
faa1084b
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
mindspore
as
ms
from
mindspore
import
context
,
Tensor
,
Parameter
from
mindspore.common.api
import
_executor
from
mindspore.nn
import
Cell
,
TrainOneStepCell
,
Momentum
from
mindspore.ops
import
operations
as
P
from
mindspore.common.initializer
import
initializer
class
Net
(
Cell
):
def
__init__
(
self
,
strategy1
=
None
,
strategy2
=
None
,
strategy3
=
None
):
super
().
__init__
()
self
.
gatherv2
=
P
.
GatherV2
().
set_strategy
(
strategy1
)
self
.
gatherv2
.
add_prim_attr
(
"manual_split"
,
((
1
,
0
),
(
7
,
1
)))
self
.
mul
=
P
.
Mul
().
set_strategy
(
strategy2
)
self
.
reshape
=
P
.
Reshape
()
self
.
matmul
=
P
.
MatMul
().
set_strategy
(
strategy3
)
self
.
matmul
.
add_prim_attr
(
"forward_reduce_scatter"
,
True
)
self
.
param
=
Parameter
(
initializer
(
"ones"
,
(
8
,
64
),
ms
.
float32
),
name
=
"gatherv2_param"
)
self
.
mul_weight
=
Parameter
(
initializer
(
"ones"
,
(
2
,
4
,
64
),
ms
.
float32
),
name
=
"mul_weight"
)
self
.
matmul_weight
=
Parameter
(
initializer
(
"ones"
,
(
256
,
16
),
ms
.
float32
),
name
=
"matmul_weight"
)
def
construct
(
self
,
x
,
b
):
out
=
self
.
gatherv2
(
self
.
param
,
x
,
0
)
out
=
self
.
mul
(
out
,
self
.
mul_weight
)
out
=
self
.
reshape
(
out
,
(
2
,
256
))
out
=
self
.
matmul
(
out
,
self
.
matmul_weight
)
return
out
_x
=
Tensor
(
np
.
ones
([
2
,
4
]),
dtype
=
ms
.
int32
)
_b
=
Tensor
(
np
.
ones
([
64
,
8
]),
dtype
=
ms
.
float32
)
def
compile_net
(
net
):
optimizer
=
Momentum
(
net
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
train_net
=
TrainOneStepCell
(
net
,
optimizer
)
train_net
.
set_auto_parallel
()
_executor
.
compile
(
train_net
,
_x
,
_b
)
context
.
reset_auto_parallel_context
()
def
test_neg_data_parallel
():
context
.
set_context
(
save_graphs
=
True
)
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
,
device_num
=
2
,
global_rank
=
0
)
strategy1
=
((
2
,
1
),
(
1
,
2
))
strategy2
=
((
1
,
2
,
1
),
(
1
,
2
,
1
))
strategy3
=
((
1
,
2
),
(
2
,
1
))
net
=
Net
(
strategy1
,
strategy2
,
strategy3
)
compile_net
(
net
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录