Commit f4bb43bb
Author: yangzhenzhang
Date: Aug 06, 2020

add concat op

Parent: a3959071

Showing 11 changed files with 585 additions and 21 deletions (+585 -21)
mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h   +2    -0
mindspore/ccsrc/frontend/parallel/dynamic_creator.h                    +1    -0
mindspore/ccsrc/frontend/parallel/node_check.cc                        +0    -1
mindspore/ccsrc/frontend/parallel/ops_info/concat_info.cc              +268  -0
mindspore/ccsrc/frontend/parallel/ops_info/concat_info.h               +62   -0
mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h       +1    -0
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc                +9    -2
mindspore/ccsrc/frontend/parallel/step_parallel.cc                     +101  -15
mindspore/ccsrc/frontend/parallel/step_parallel.h                      +2    -0
mindspore/ops/_grad/grad_comm_ops.py                                   +11   -3
tests/ut/python/parallel/test_concat.py                                +128  -0
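For orientation: the commit teaches the parallel frontend to shard the Concat primitive. A minimal usage sketch in semi-auto-parallel mode, modeled on the new unit test tests/ut/python/parallel/test_concat.py (the ConcatNet class and shapes here are illustrative, not part of the commit):

import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell
from mindspore.ops import operations as P


class ConcatNet(Cell):
    def __init__(self, w1, w2, strategy=None):
        super().__init__()
        # One strategy tuple per input tensor; the concat axis (0) must stay unsplit.
        self.concat = P.Concat(axis=0).set_strategy(strategy)
        self.mul = P.Mul()
        self.w1 = Parameter(w1, "w1")
        self.w2 = Parameter(w2, "w2")

    def construct(self, x):
        out = self.concat((self.w1, self.w2))
        return self.mul(x, out)


context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
net = ConcatNet(Tensor(np.ones([96, 64, 32]), dtype=ms.float32),
                Tensor(np.ones([32, 64, 32]), dtype=ms.float32),
                strategy=((1, 4, 2), (1, 4, 2)))  # dims 1 and 2 sharded as 4 x 2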
mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h
@@ -199,6 +199,8 @@ class SoftmaxCost : public OperatorCost {
 using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;
 using TileCost = SoftmaxCost;
 using TileCostPtr = std::shared_ptr<TileCost>;
+using ConcatCost = TileCost;
+using ConcatCostPtr = std::shared_ptr<ConcatCost>;

 class TmpIdentityCost : public OperatorCost {
  public:
mindspore/ccsrc/frontend/parallel/dynamic_creator.h
@@ -136,6 +136,7 @@ REGISTER(EmbeddingLookupInfo);
 REGISTER(TileInfo);
 REGISTER(StridedSliceInfo);
 REGISTER(DropoutInfo);
+REGISTER(ConcatInfo);
 }  // namespace parallel
 }  // namespace mindspore
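REGISTER(ConcatInfo) wires the new operator into the dynamic-creator table, so the frontend can build a ConcatInfo from the primitive's name (see OperatorInstanceByName in step_parallel.cc below). A rough Python sketch of that registry pattern; all names here are hypothetical stand-ins for the C++ machinery:

# Hypothetical sketch of the name -> constructor registry that REGISTER(...) maintains.
_creators = {}

def register(name, ctor):
    _creators[name] = ctor

def create_operator_info(prim_name, *args):
    ctor = _creators.get(prim_name)
    return ctor(*args) if ctor else None  # caller falls back to batch parallel on None

class ConcatInfo:
    def __init__(self, name, inputs_shape, outputs_shape, attrs):
        self.name = name

register("Concat", ConcatInfo)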
mindspore/ccsrc/frontend/parallel/node_check.cc
@@ -24,7 +24,6 @@
 namespace mindspore {
 namespace parallel {
 const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
-                                          MAKE_TUPLE,
                                           J,
                                           LIST_GETITEM,
                                           ARRAY_GETITEM,
mindspore/ccsrc/frontend/parallel/ops_info/concat_info.cc
0 → 100644
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/concat_info.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/resource.h"

namespace mindspore {
namespace parallel {
Status ConcatInfo::GetAttrs() {
  int axis = 0;
  auto axis_iter = attrs_.find(AXIS);
  if (axis_iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(axis_iter->second);
    if (axis_iter->second->isa<Int32Imm>()) {
      axis = axis_iter->second->cast<Int32ImmPtr>()->value();
    } else {
      MS_LOG(ERROR) << name_ << ": The value of axis is not int";
      return FAILED;
    }
  } else {
    MS_LOG(ERROR) << name_ << ": Can not find the axis attr";
    return FAILED;
  }

  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }
  int dim = SizeToInt(inputs_shape_[0].size());

  if (axis < 0) {
    axis = axis + dim;
  }

  axis_ = SizeToInt(axis);
  return SUCCESS;
}

Status ConcatInfo::CheckStrategy(const StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy";
    return FAILED;
  }

  std::vector<Dimensions> stra = strategy->GetInputDim();
  if (stra.empty()) {
    MS_LOG(ERROR) << name_ << ": The strategy is empty";
    return FAILED;
  }

  if (stra.size() != inputs_shape_.size()) {
    MS_LOG(ERROR) << name_ << ": The size of strategy must be equal to the size of inputs shape";
    return FAILED;
  }

  for (size_t i = 0; i < stra.size(); ++i) {
    auto strategy_ele = stra[i];
    auto input_shape_ele = inputs_shape_[i];
    if (strategy_ele.size() != input_shape_ele.size()) {
      MS_LOG(ERROR) << name_ << ": The size of strategy element must be equal to the size of input shape";
      return FAILED;
    }

    if (axis_ >= strategy_ele.size()) {
      MS_LOG(ERROR) << name_ << ": The axis is out of range, the axis is " << axis_;
      return FAILED;
    }

    if (strategy_ele[axis_] != 1) {
      MS_LOG(ERROR) << name_ << ": The axis can not be split";
      return FAILED;
    }

    for (size_t j = 0; j < strategy_ele.size(); ++j) {
      if (strategy_ele[j] != stra[0][j]) {
        MS_LOG(ERROR) << name_ << ": The strategy of each input tensor must be equal";
        return FAILED;
      }
    }
  }

  return SUCCESS;
}

Status ConcatInfo::InferDevMatrixShape() {
  MS_EXCEPTION_IF_NULL(strategy_);
  std::vector<Dimensions> stra = strategy_->GetInputDim();
  if (stra.empty()) {
    MS_LOG(ERROR) << name_ << "The strategy is empty";
    return FAILED;
  }

  dev_matrix_shape_ = stra[0];
  return SUCCESS;
}

Status ConcatInfo::InferTensorMap() {
  TensorMap tensor_map;
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << "The inputs shape is empty";
    return FAILED;
  }

  // cannot use dev_matrix_shape_ replace inputs_shape_[0], because it may not be fully split in all devices.
  int32_t size = SizeToInt(inputs_shape_[0].size());
  for (int i = 0; i < size; ++i) {
    tensor_map.push_back(size - i - 1);
  }

  for (size_t i = 0; i < inputs_shape_.size(); ++i) {
    inputs_tensor_map_.push_back(tensor_map);
  }

  outputs_tensor_map_.push_back(tensor_map);
  return SUCCESS;
}

Status ConcatInfo::InferMirrorOps() {
  mirror_ops_.clear();
  if (inputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
    return FAILED;
  }

  Shape input_tensor_map = inputs_tensor_map_[0];
  std::vector<Group> group;
  if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create group for input failed.";
    return FAILED;
  }

  if (group.empty()) {
    MS_LOG(INFO) << name_ << ": The mirror group is empty.";
    return SUCCESS;
  }

  OperatorVector input_op;
  input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
  for (size_t i = 0; i < inputs_shape_.size(); ++i) {
    mirror_ops_.push_back(input_op);
  }

  return SUCCESS;
}

Status ConcatInfo::InferTensorInfo() {
  if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": Invalid args";
    return FAILED;
  }

  TensorLayout input_layout, output_layout;
  for (size_t i = 0; i < inputs_shape_.size(); ++i) {
    // infer tensor layout
    if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
      return FAILED;
    }
    TensorInfo input_tensor_info(input_layout);
    inputs_tensor_info_.push_back(input_tensor_info);
  }

  if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
    return FAILED;
  }
  TensorInfo output_tensor_info(output_layout);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

void ConcatInfo::ReComputeBatchSplitFlagList() {
  for (size_t i = 0; i < inputs_shape_.size(); i++) {
    split_flag_list_[i] = true;
  }
}

Status ConcatInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
    return FAILED;
  }

  return SUCCESS;
}

Status ConcatInfo::GenerateStrategies(int32_t stage_id) {
  if (InferAttrs() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer attrs failed";
    return FAILED;
  }
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }

  Shape input_split;
  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
    if (i == axis_) {
      input_split.push_back(0);
    } else {
      input_split.push_back(1);
    }
  }
  Shapes splittable_inputs;
  for (size_t i = 0; i < inputs_shape_.size(); ++i) {
    splittable_inputs.push_back(input_split);
  }

  std::vector<StrategyPtr> sp_vector;
  is_auto_parallel_ = true;
  if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    return FAILED;
  }

  size_t success = 0;
  for (auto &sp : sp_vector) {
    PrintStrategy(sp);
    if (SetCostUnderStrategy(sp) == SUCCESS) {
      success++;
      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
      PrintStrategy(sp);
    }
  }
  return SUCCESS;
}

Status ConcatInfo::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": Init success.";
  return SUCCESS;
}

Status ConcatInfo::InitForCostModel(const StrategyPtr &strategy) {
  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << ": Init for cost model success.";
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore
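CheckStrategy above enforces three rules: one strategy tuple per input, identical strategies across inputs, and no split along the concat axis. A small standalone sketch of the same checks, in plain Python rather than the actual C++ API:

def check_concat_strategy(strategies, input_shapes, axis):
    """Mirrors ConcatInfo::CheckStrategy: returns None on success, an error string otherwise."""
    if len(strategies) != len(input_shapes):
        return "The size of strategy must be equal to the size of inputs shape"
    for stra, shape in zip(strategies, input_shapes):
        if len(stra) != len(shape):
            return "The size of strategy element must be equal to the size of input shape"
        if stra[axis] != 1:
            return "The axis can not be split"
        if stra != strategies[0]:
            return "The strategy of each input tensor must be equal"
    return None

# Mirrors test_concat_parameter: two weights concatenated on axis 0, 8 devices as 4 x 2.
assert check_concat_strategy(((1, 4, 2), (1, 4, 2)), ([96, 64, 32], [32, 64, 32]), axis=0) is None
assert check_concat_strategy(((2, 2, 2), (2, 2, 2)), ([96, 64, 32], [32, 64, 32]), axis=0) \
    == "The axis can not be split"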
mindspore/ccsrc/frontend/parallel/ops_info/concat_info.h
0 → 100644
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONCAT_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONCAT_INFO_H_

#include <string>
#include <memory>
#include <unordered_map>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
class ConcatInfo : public OperatorInfo {
 public:
  ConcatInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>(false)) {}
  ~ConcatInfo() override = default;

  Status Init(const StrategyPtr &strategy) override;
  Status InitForCostModel(const StrategyPtr &strategy) override;
  Status GenerateStrategies(int32_t) override;
  Status SetCostUnderStrategy(const StrategyPtr &) override;
  void ReComputeBatchSplitFlagList() override;

 protected:
  Status GetAttrs() override;
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferMirrorOps() override;
  Status InferForwardCommunication() override { return SUCCESS; }
  Status InferTensorInfo() override;
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;

 private:
  size_t axis_ = 0;
};

using ConcatInfoPtr = std::shared_ptr<ConcatInfo>;
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONCAT_INFO_H_
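For these rank-3 inputs, InferDevMatrixShape takes the shared input strategy as the device matrix, and InferTensorMap numbers tensor dimensions from the back, so every input and the output share one map. A brief sketch of the two inferences (plain Python, illustrative only):

def infer_dev_matrix_and_tensor_map(strategies, input_rank):
    # InferDevMatrixShape: the device matrix is the strategy of input 0
    # (all input strategies are equal after CheckStrategy).
    dev_matrix = list(strategies[0])
    # InferTensorMap: tensor dimension i maps to device-matrix axis (rank - i - 1).
    tensor_map = [input_rank - i - 1 for i in range(input_rank)]
    return dev_matrix, tensor_map

print(infer_dev_matrix_and_tensor_map(((1, 4, 2), (1, 4, 2)), 3))
# ([1, 4, 2], [2, 1, 0])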
mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
@@ -39,5 +39,6 @@
 #include "frontend/parallel/ops_info/gather_v2_p_info.h"
 #include "frontend/parallel/ops_info/tile_info.h"
 #include "frontend/parallel/ops_info/strided_slice_info.h"
+#include "frontend/parallel/ops_info/concat_info.h"

 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
@@ -118,6 +118,9 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
 std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
   std::vector<bool> is_parameter;
   std::vector<AnfNodePtr> node_inputs{node->inputs()};
+  if ((node_inputs.size() == 2) && AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) {
+    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
+  }
   for (size_t i = 1; i < node_inputs.size(); ++i) {
     auto input = node_inputs[i];

@@ -192,6 +195,10 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
   std::vector<size_t> inputs_type_len;
   std::vector<AnfNodePtr> node_inputs{node->inputs()};
+  if ((node_inputs.size() == 2) && AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) {
+    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
+  }
+
   // extract input element length
   for (auto &input : node_inputs) {
     if (IsValueNode<RefKey>(input)) {

@@ -255,7 +262,7 @@ bool IsSplittableOperator(const std::string &op_name) {
     FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
     REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
     MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
-    LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
+    LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
     STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
     SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
   // clang-format on

@@ -275,7 +282,7 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
     return false;
   }
   bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
-  if (bool_result) {
+  if (bool_result && (prim->name() != MAKE_TUPLE)) {
     MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name();
   } else if (prim->name() == CAST) {
     if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
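Both extractor changes unwrap a lone make_tuple argument, so Concat's tuple of tensors is treated as separate operator inputs when collecting parameter flags and type lengths. A schematic sketch of the unwrapping; the node attributes here are hypothetical:

def flatten_make_tuple_inputs(node_inputs, is_make_tuple):
    # node_inputs[0] is the primitive value node; when the only argument is a
    # make_tuple (as for Concat), its elements are inspected instead.
    if len(node_inputs) == 2 and is_make_tuple(node_inputs[1]):
        return node_inputs[1].inputs
    return node_inputs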
mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -267,6 +267,33 @@ TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &
   return tensorinfo_in.tensor_layout();
 }

+bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  auto cnode = anf_node->cast<CNodePtr>();
+  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+    return false;
+  }
+
+  auto value_node = cnode->input(0)->cast<ValueNodePtr>();
+  auto prim = GetValueNode<PrimitivePtr>(value_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  if (prim->name() == prim_name) {
+    return true;
+  }
+
+  return false;
+}
+
+std::string GetPrimName(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!IsValueNode<Primitive>(node->input(0))) {
+    MS_LOG(EXCEPTION) << "The node is not a primitive";
+  }
+  auto value_node = node->input(0)->cast<ValueNodePtr>();
+  auto prim = GetValueNode<PrimitivePtr>(value_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  return prim->name();
+}
+
 OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   if (!IsParallelCareNode(node)) {

@@ -274,7 +301,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
   }
   OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
   if (distribute_operator == nullptr) {
-    MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr";
+    MS_LOG(EXCEPTION) << "Distribute operator is nullptr, the prim is " << GetPrimName(node);
   }
   return distribute_operator;
 }

@@ -423,6 +450,11 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
   MS_EXCEPTION_IF_NULL(manager);
   AnfNodeIndexSet node_set = manager->node_users()[node];
   CNodePtr insert_node_new;
+
+  if (AnfNodeIsPrimitive(node, MAKE_TUPLE)) {
+    MS_LOG(INFO) << "No need to insert redistribution op betweend make_tuple node and the next node";
+    return;
+  }
+
   if (IsValueNode<Primitive>(node->input(0))) {
     auto current_value = node->input(0)->cast<ValueNodePtr>();
     MS_EXCEPTION_IF_NULL(current_value);

@@ -875,9 +907,15 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(func_graph);
   FuncGraphManagerPtr manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
+
+  if ((node->inputs().size() == 2) && AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE)) {
+    MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
+    return;
+  }
+
   if (mirror_ops.size() != node_size - 1) {
-    MS_LOG(EXCEPTION) << "Failure:Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size()
-                      << ", node_size is " << node_size;
+    MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size()
+                      << ", node_size is " << node_size - 1;
   }
   for (size_t index = 1; index < node_size; ++index) {
     OperatorVector backward_op = mirror_ops[index - 1];

@@ -993,7 +1031,7 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs
                                  const std::vector<Shapes> &shape_list) {
   MS_EXCEPTION_IF_NULL(prim);
   OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list);
-  if (operator_ == nullptr) {
+  if ((operator_ == nullptr) && (prim->name() != MAKE_TUPLE)) {
     MS_LOG(INFO) << "Creat " << prim->name() << " failed, use batch parallel";
     operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
     MS_EXCEPTION_IF_NULL(operator_);

@@ -1177,7 +1215,12 @@ std::vector<Shapes> ExtractShape(const CNodePtr &node) {
       continue;
     }
    if (input_shapes.size() != 1) {
-      MS_LOG(EXCEPTION) << "ExtractShape:Get input shape failed";
+      if (inputs_size == 2) {
+        // like concat
+        shape_inputs = input_shapes;
+        break;
+      } else {
+        MS_LOG(EXCEPTION) << "ExtractShape: Get input shape failed";
+      }
     }
     shape_inputs.push_back(input_shapes[0]);
   }

@@ -1269,8 +1312,8 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
   }
   TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)];
   Shape slice_shape = tensorinfo_in.slice_shape();
-  MS_LOG(DEBUG) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
-                << MakeValue(slice_shape)->ToString();
+  MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
+               << MakeValue(slice_shape)->ToString() << ", op name is " << distribute_operator->name();
   std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
   MS_EXCEPTION_IF_NULL(parallel_shape);
   // Don't modify it in-place as the pointer of this AbstractValue may used as cache key in StaticAnalysis.

@@ -1450,6 +1493,9 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
     SetVirtualDatasetStrategy(cnode);
     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+    if (prim->name() == MAKE_TUPLE) {
+      continue;
+    }
     auto attrs = prim->attrs();
     MS_LOG(INFO) << "extract information: node: " << node->ToString() << " prim " << prim->name();
     if (IsParallelCareNode(cnode)) {

@@ -2045,13 +2091,13 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
     MS_EXCEPTION_IF_NULL(node);
     if (node->isa<CNode>()) {
       auto cnode = node->cast<CNodePtr>();
-      if (!IsValueNode<Primitive>(cnode->input(0))) {
-        continue;
-      }
-      OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
-      if (distribute_operator == nullptr) {
-        continue;
-      }
+      // the make_tuple is parallel care node, but it may have not operator info
+      if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
+        continue;
+      }
+
+      OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
+      MS_EXCEPTION_IF_NULL(distribute_operator);
       // insert forward ops
       InsertForwardOps(distribute_operator, cnode);

@@ -2074,13 +2120,12 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
     MS_EXCEPTION_IF_NULL(node);
     if (node->isa<CNode>()) {
       auto cnode = node->cast<CNodePtr>();
-      if (!IsValueNode<Primitive>(cnode->input(0))) {
-        continue;
-      }
-      OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
-      if (distribute_operator == nullptr) {
-        continue;
-      }
+      if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
+        continue;
+      }
+      OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
+      MS_EXCEPTION_IF_NULL(distribute_operator);
       // StepReplace
       StepReplace(distribute_operator, cnode);
     }

@@ -2330,6 +2375,44 @@ Status ParallelInit() {
   return SUCCESS;
 }

+void HandleForwardMakeTuple(const std::vector<AnfNodePtr> &all_nodes) {
+  for (auto &node : all_nodes) {
+    if (!AnfNodeIsPrimitive(node, MAKE_TUPLE)) {
+      continue;
+    }
+
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    if (!cnode->in_forward_flag()) {
+      continue;
+    }
+
+    FuncGraphManagerPtr manager = cnode->func_graph()->manager();
+    MS_EXCEPTION_IF_NULL(manager);
+    auto make_tuple_user = manager->node_users()[cnode];
+    if (make_tuple_user.size() != 1) {
+      MS_LOG(EXCEPTION) << "Now the make_tuple's user must be 1, but got " << make_tuple_user.size();
+    }
+    CNodePtr make_tuple_next_cnode = make_tuple_user.pop().first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(make_tuple_next_cnode);
+
+    std::string make_tuple_user_prim_name = GetPrimName(make_tuple_next_cnode);
+    if (!IsParallelCareNode(make_tuple_next_cnode)) {
+      MS_LOG(INFO) << "The make_tuple's user is " << make_tuple_user_prim_name << ", no need to set operator info";
+      continue;
+    }
+    if (make_tuple_next_cnode->inputs().size() != 2) {
+      MS_LOG(EXCEPTION) << "Now the make_tuple's user only support 1 input, but got "
+                        << make_tuple_next_cnode->inputs().size() - 1;
+    }
+
+    MS_LOG(INFO) << "Set the make_tuple's operator info, and the op name is " << make_tuple_user_prim_name;
+    OperatorInfoPtr op_info = GetDistributeOperator(make_tuple_next_cnode);
+    MS_EXCEPTION_IF_NULL(op_info);
+    cnode->set_user_data<OperatorInfo>(op_info);
+  }
+}
+
 bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
   MS_EXCEPTION_IF_NULL(root);
   MS_EXCEPTION_IF_NULL(optimizer);

@@ -2383,6 +2466,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
     ExtractInformation(all_nodes);
     ReshapeInit(all_nodes);
   }
+
+  HandleForwardMakeTuple(all_nodes);
+
   // save strategy as checkpoint for multi-train
   if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
     CheckpointStrategy(root);
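HandleForwardMakeTuple copies the consuming operator's OperatorInfo onto a forward make_tuple node, letting mirror insertion and redistribution treat the tuple as part of its user (Concat here). A condensed sketch of the control flow, with hypothetical graph helpers:

def handle_forward_make_tuple(all_nodes, users_of):
    for node in all_nodes:
        if not node.is_primitive("make_tuple") or not node.in_forward:
            continue
        users = users_of(node)  # (user_node, input_index) pairs
        if len(users) != 1:
            raise RuntimeError("Now the make_tuple's user must be 1, but got %d" % len(users))
        user = users[0][0]
        if not user.is_parallel_care:
            continue  # e.g. the tuple feeds a non-parallel op, nothing to set
        if len(user.inputs) != 2:
            raise RuntimeError("Now the make_tuple's user only support 1 input")
        node.operator_info = user.operator_info  # reuse the user's ConcatInfo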
mindspore/ccsrc/frontend/parallel/step_parallel.h
@@ -149,6 +149,8 @@ Status ParallelInit();
 std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
 std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
+bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);
+
 }  // namespace parallel
 }  // namespace mindspore
mindspore/ops/_grad/grad_comm_ops.py
@@ -222,9 +222,17 @@ def get_bprop_virtual_div_operator(self):
     dtype = P.DType()

     def bprop(x, out, dout):
-        if F.issubclass_(F.dtype(dout), mstype.bool_):
-            return (dout,)
-        dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
+        if F.issubclass_(F.typeof(dout), mstype.tensor):
+            if F.issubclass_(F.dtype(dout), mstype.bool_):
+                return (dout,)
+
+            dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
+            return (dx,)
+
+        dx = ()
+        input_nums = F.tuple_len(dout)
+        for i in range(input_nums):
+            ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
+            dx = dx + (ele_grad,)
         return (dx,)
     return bprop
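The _VirtualDiv backprop previously assumed a single tensor gradient; with Concat's make_tuple input, the incoming gradient can be a tuple, so the new branch scales each element. A plain-NumPy sketch of the effect (illustrative only, not the framework code):

import numpy as np

def virtual_div_grad(dout, divisor):
    # Tensor gradient: scale once (old path); tuple gradient: scale each element (new path).
    if isinstance(dout, np.ndarray):
        return dout / divisor
    return tuple(ele / divisor for ele in dout)

grads = virtual_div_grad((np.ones([2, 3]), np.ones([4, 3])), 8.0)  # each element scaled by 1/8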
tests/ut/python/parallel/test_concat.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(Cell):
    def __init__(self, weight, weight2, strategy1=None, strategy2=None, is_parameter=True):
        super().__init__()
        self.concat = P.Concat(axis=0).set_strategy(strategy1)
        if is_parameter:
            self.weight = Parameter(weight, "w1")
        else:
            self.weight = weight
        self.mul = P.Mul().set_strategy(strategy2)
        self.weight2 = Parameter(weight2, "w2")

    def construct(self, x, b):
        out = self.concat((self.weight, self.weight2))
        out = self.mul(x, out)
        return out


class Net2(Cell):
    def __init__(self, weight, strategy1=None, strategy2=None, axis=0):
        super().__init__()
        self.mul = P.Mul().set_strategy(strategy1)
        self.concat = P.Concat(axis=axis).set_strategy(strategy2)
        self.weight = Parameter(weight, "w")

    def construct(self, x, b):
        out = self.mul(x, b)
        out = self.concat((out, self.weight))
        return out


_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([96, 64, 32]), dtype=ms.float32)
_w2 = Tensor(np.ones([32, 64, 32]), dtype=ms.float32)
_w3 = Tensor(np.ones([128, 16, 32]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)


def compile_net(net):
    context.set_context(save_graphs=True)
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()


def test_concat_parameter():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True)
    compile_net(net)


def test_concat_parameter_no_full_split():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 2, 2), (1, 2, 2))
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True)
    compile_net(net)


def test_concat_tensor_and_parameter():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 2, 2), (1, 2, 2))
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net(_w1, _w2, strategy1, strategy2, is_parameter=False)
    compile_net(net)


def test_concat_output():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net2(_w1, strategy1, strategy2)
    compile_net(net)


def test_concat_output_no_full_split():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((1, 2, 2), (1, 2, 2))
    net = Net2(_w1, strategy1, strategy2)
    compile_net(net)


def test_concat_no_strategy():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = None
    net = Net2(_w3, strategy1, strategy2, axis=1)
    compile_net(net)


def test_concat_auto_parallel():
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net2(_w2)
    compile_net(net)


def test_concat_auto_parallel2():
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    strategy1 = None
    strategy2 = None
    net = Net2(_w3, strategy1, strategy2, axis=1)
    compile_net(net)
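As a sanity check on the strategies above: under test_concat_parameter's ((1, 4, 2), (1, 4, 2)) with 8 devices, each device holds a [96, 16, 16] slice of w1 and a [32, 16, 16] slice of w2, and the concat axis stays whole. A quick verification in plain Python, mirroring the expectation rather than the framework:

def slice_shape(shape, strategy):
    # Each dimension must divide evenly by its cut count.
    assert all(dim % cut == 0 for dim, cut in zip(shape, strategy))
    return [dim // cut for dim, cut in zip(shape, strategy)]

assert slice_shape([96, 64, 32], (1, 4, 2)) == [96, 16, 16]
assert slice_shape([32, 64, 32], (1, 4, 2)) == [32, 16, 16]
# Concatenated along axis 0 on every device: [96 + 32, 16, 16] == [128, 16, 16]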