magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit b0358901: allreduce fusion

Author: wangmin
Date: Sep 04, 2020
Parent: 5a2e1268

Showing 7 changed files with 148 additions and 247 deletions (+148, -247)
model_zoo/official/cv/resnet_thor/README.md                        (+1, -1)
model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py         (+64, -109)
model_zoo/official/cv/resnet_thor/src/thor.py                      (+12, -9)
model_zoo/official/cv/resnet_thor/train.py                         (+1, -5)
tests/st/networks/models/resnet50/src_thor/grad_reducer_thor.py    (+64, -114)
tests/st/networks/models/resnet50/src_thor/thor.py                 (+5, -4)
tests/st/networks/models/resnet50/test_resnet50_imagenet.py        (+1, -5)
model_zoo/official/cv/resnet_thor/README.md

@@ -217,7 +217,7 @@ Inference result will be stored in the example path, whose folder name is "eval"
 ```
 Inference result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log.
 ```
-result: {'top_5_accuracy': 0.9286771766965429, 'top_1_accuracy': 0.7613036171574904} ckpt=train_parallel/resnet-36_5004.ckpt
+result: {'top_5_accuracy': 0.9287972151088348, 'top_1_accuracy': 0.7597031049935979} ckpt=train_parallel/resnet-36_5004.ckpt
 ```

 ## Model Description
model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py

@@ -12,149 +12,109 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""grad_reducer_thor"""
+"""grad reducer cell for distributed training"""
+import mindspore.common.dtype as mstype
+from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.nn.cell import Cell
-from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp
-import mindspore.common.dtype as mstype
+from mindspore.ops.operations.comm_ops import AllReduce
 
 reduce_opt = C.MultitypeFuncGraph("reduce_opt")
 
-_all_reduce_A = AllReduce()
-
-
-def _init_optimizer_allreduce(group):
-    global _all_reduce_A
-    _all_reduce_A = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
-    _all_reduce_A.add_prim_attr('fusion', group)
+
+def _init_allreduce_operators(length, split_indices):
+    """ initialize allreduce communication operators"""
+    indices = split_indices[0]
+    fusion = split_indices[1]
+    op_list = ()
+    j = 0
+    for i in range(length):
+        if j <= len(indices) - 1:
+            temp = indices[j]
+        else:
+            temp = length
+        if i >= temp:
+            j = j + 1
+            fusion = fusion + 1
+        op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
+        op.add_prim_attr('fusion', fusion)
+        op_list = op_list + (op,)
+    return op_list
 
 
-@reduce_opt.register("Function", "Number", "Tensor")
-def _tensors_allreduce_mean(mul, degree, grad):
-    degree = F.scalar_cast(degree, F.dtype(grad))
-    grad = _all_reduce_A(grad)
-    cast_op = P.Cast()
-    return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
-
-
-@reduce_opt.register("Bool", "Tensor")
-def _tensors_allreduce(allreduce_filter, grad):
-    if allreduce_filter:
-        return _all_reduce_A(grad)
-    return grad
+@reduce_opt.register("Function", "Number", "Function", "Tensor")
+def _tensors_allreduce_mean(mul, degree, allreduce, parameters):
+    """
+    Apply allreduce on parameters.
+
+    Args:
+        mul(Primitive): The mul operator for parameters.
+        degree (int): The mean coefficient.
+        allreduce (Primitive): The communication operator for parameters.
+        parameters (Tensor): The parameters before operation.
+
+    Returns:
+        Tensor, the parameters after operation.
+    """
+    degree = F.scalar_cast(degree, F.dtype(parameters))
+    parameters = allreduce(parameters)
+    cast_op = P.Cast()
+    return mul(parameters, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(parameters)))
 
 
 _get_datatype = C.MultitypeFuncGraph("_get_datatype")
 
 
 @_get_datatype.register("Tensor")
-def _tensors_get_datatype(grad):
+def _tensors_get_datatype(parameters):
     """
-    Acquire gradient datatype.
+    Acquire parameters datatype.
 
     Args:
-        grad (Tensor): The gradient tensor before operation.
+        parameters (Tensor): The parameters before operation.
 
     Returns:
-        mstype, the datatype of gradient.
+        mstype, the datatype of parameters.
     """
-    return F.dtype(grad)
+    return F.dtype(parameters)
 
 
 _cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
 
 
 @_cast_datatype.register("TypeType", "Tensor")
-def _tensors_cast_datatype(datatype, grad):
+def _tensors_cast_datatype(datatype, parameters):
     """
-    Cast gradient to datatype.
+    Cast parameters to datatype.
 
     Args:
-        datatype (mstype): the destination datatype of gradient.
-        grad (Tensor): The gradient tensor before operation.
+        datatype (mstype): the destination datatype of parameters.
+        parameters (Tensor): The parameters before operation.
 
     Returns:
-        Tensor, the gradient tensor after operation.
+        Tensor, the parameters after operation.
     """
-    return F.cast(grad, datatype)
+    return F.cast(parameters, datatype)
 
 
 class DistributedGradReducerThor(Cell):
     """
     A distributed optimizer.
 
-    Constructs a gradient reducer Cell, which applies communication and average operations on
-    single-process gradient values.
+    Constructs a parameters reducer Cell, which applies communication and average operations on
+    single-process parameters values.
 
     Args:
-        parameters (list): the parameters to be updated.
-        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: False.
+        parameter_length (int): length of the parameters to be updated.
+        split_indices(tuple): parameter split indices.
+        mean (bool): When mean is true, the mean coefficient (degree) would apply on parameters. Default: False.
         degree (int): The mean coefficient. Usually it equals to device number. Default: None.
 
     Raises:
         ValueError: If degree is not a int or less than 0.
-
-    Examples:
-        >>> from mindspore.communication import init, get_group_size
-        >>> from mindspore.ops import composite as C
-        >>> from mindspore.ops import operations as P
-        >>> from mindspore.ops import functional as F
-        >>> from mindspore import context
-        >>> from mindspore import nn
-        >>> from mindspore import ParameterTuple
-        >>> from mindspore.context import ParallelMode
-        >>>
-        >>> device_id = int(os.environ["DEVICE_ID"])
-        >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
-        >>>                     device_id=int(device_id), enable_hccl=True)
-        >>> init()
-        >>> context.reset_auto_parallel_context()
-        >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
-        >>>
-        >>>
-        >>> class TrainingWrapper(nn.Cell):
-        >>>     def __init__(self, network, optimizer, sens=1.0):
-        >>>         super(TrainingWrapper, self).__init__(auto_prefix=False)
-        >>>         self.network = network
-        >>>         self.network.add_flags(defer_inline=True)
-        >>>         self.weights = ParameterTuple(network.trainable_params())
-        >>>         self.optimizer = optimizer
-        >>>         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        >>>         self.sens = sens
-        >>>         self.reducer_flag = False
-        >>>         self.grad_reducer = None
-        >>>         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        >>>         if self.parallel_mode in [ParallelMode.DATA_PARALLEL,
-        >>>                                   ParallelMode.HYBRID_PARALLEL]:
-        >>>             self.reducer_flag = True
-        >>>         if self.reducer_flag:
-        >>>             mean = context.get_auto_parallel_context("gradients_mean")
-        >>>             if mean.get_device_num_is_set():
-        >>>                 degree = context.get_auto_parallel_context("device_num")
-        >>>             else:
-        >>>                 degree = get_group_size()
-        >>>             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
-        >>>
-        >>>     def construct(self, *args):
-        >>>         weights = self.weights
-        >>>         loss = self.network(*args)
-        >>>         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        >>>         grads = self.grad(self.network, weights)(*args, sens)
-        >>>         if self.reducer_flag:
-        >>>             # apply grad reducer on grads
-        >>>             grads = self.grad_reducer(grads)
-        >>>         return F.depend(loss, self.optimizer(grads))
-        >>>
-        >>> network = Net()
-        >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
-        >>> train_cell = TrainingWrapper(network, optimizer)
-        >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
-        >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
-        >>> grads = train_cell(inputs, label)
     """
 
-    def __init__(self, parameters, group, mean=True, degree=None):
+    def __init__(self, parameter_length, split_indices, mean=True, degree=None):
         super(DistributedGradReducerThor, self).__init__(auto_prefix=False)
         self.hyper_map = C.HyperMap()
         self.mul = P.Mul()

@@ -165,16 +125,11 @@ class DistributedGradReducerThor(Cell):
             raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int")
         self.degree = degree
         self.mean = mean
-        self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
-        _init_optimizer_allreduce(group)
+        self.op_list = _init_allreduce_operators(parameter_length, split_indices)
 
-    def construct(self, grads):
-        # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
-        # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
-        # and cast back after the operation.
-        datatypes = self.hyper_map(F.partial(_get_datatype), grads)
-        grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
-        new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
-        new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
-        return new_grad
+    def construct(self, parameters):
+        datatypes = self.hyper_map(F.partial(_get_datatype), parameters)
+        parameters = self.hyper_map(F.partial(_cast_datatype, mstype.float32), parameters)
+        new_parameters = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.op_list, parameters)
+        new_parameters = self.hyper_map(F.partial(_cast_datatype), datatypes, new_parameters)
+        return new_parameters
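The core of this change is `_init_allreduce_operators`: instead of one shared `AllReduce` whose `fusion` attribute is set once per reducer, every parameter now gets its own `AllReduce` operator whose `fusion` attribute is derived from `split_indices`, so operators that share an id can be fused into a single collective launch. The snippet below is a minimal, MindSpore-free sketch of that index bookkeeping only; the helper name `fusion_ids` and the length/split values are made up for illustration.

```python
# Standalone sketch of the fusion-index loop in _init_allreduce_operators above.
# It does not build real AllReduce primitives; it only shows which fusion id each
# parameter would be tagged with.
def fusion_ids(length, split_indices):
    indices, fusion = split_indices
    ids = []
    j = 0
    for i in range(length):
        # next split boundary, or the end of the list once the indices are exhausted
        temp = indices[j] if j <= len(indices) - 1 else length
        if i >= temp:
            j = j + 1
            fusion = fusion + 1
        ids.append(fusion)
    return ids

# Ten parameters split at indices (3, 7) with base fusion id 0: the backend can fuse
# them into three AllReduce launches instead of ten separate ones.
print(fusion_ids(10, ((3, 7), 0)))  # [0, 0, 0, 1, 1, 1, 1, 2, 2, 2]
```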
model_zoo/official/cv/resnet_thor/src/thor.py

@@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
 from mindspore._checkparam import check_bool
 from mindspore._checkparam import Validator as validator
 from mindspore.nn.optim.optimizer import Optimizer
-from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
+from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
 from src.grad_reducer_thor import DistributedGradReducerThor
 
 _momentum_opt = C.MultitypeFuncGraph("momentum_opt")

@@ -85,10 +85,12 @@ class THOR_GPU(Optimizer):
         self.assign = P.Assign()
         self.mul = P.Mul()
-        mean = _get_gradients_mean()
+        mean = _get_mirror_mean()
         degree = _get_device_num()
-        self.grad_reducer_thorA = DistributedGradReducerThor(self.parameters, 0, mean, degree)
-        self.grad_reducer_thorG = DistributedGradReducerThor(self.parameters, 0, mean, degree)
+        parameter_length = len(self.feature_map)
+        self.grad_reducer_thorA = DistributedGradReducerThor(parameter_length, ((parameter_length,), 0), mean, degree)
+        self.grad_reducer_thorG = DistributedGradReducerThor(parameter_length, ((parameter_length,), 0), mean, degree)
         self.weight_decay = weight_decay
         self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
         self.update_gradient = P.UpdateThorGradient(split_dim=128)

@@ -191,12 +193,13 @@ class THOR(Optimizer):
                             1.0 / 196, 1.0 / 196, 1.0 / 196,
                             1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
                             1.0]
-        mean = _get_gradients_mean()
+        mean = _get_mirror_mean()
         degree = _get_device_num()
-        self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree)
-        self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree)
-        self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree)
-        self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree)
+        parameter_length = len(self.feature_map)
+        self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree)
+        self.grad_reducer_Gmax = DistributedGradReducerThor(parameter_length, ((27,), 4), mean, degree)
+        self.grad_reducer_A = DistributedGradReducerThor(parameter_length, ((27,), 6), mean, degree)
+        self.grad_reducer_G = DistributedGradReducerThor(parameter_length, ((27,), 8), mean, degree)
         self.matrix_A_inv = ()
         self.matrix_G_inv = ()
         self.matrix_max_inv = ()
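Read against `_init_allreduce_operators`, a constructor argument such as `((27,), 2)` means: the first 27 entries share one fusion id (here 2) and the remaining entries share the next id (3), and each reducer starts from a different base id (2, 4, 6, 8) so its fused AllReduce launches land in separate buckets. The snippet below only spells that mapping out; the reducer names come from the diff, while the id arithmetic is an assumption based on the loop sketched earlier.

```python
# Illustrative only: fusion ids each THOR reducer would use under the new ((27,), base)
# split, assuming the index logic of _init_allreduce_operators.
reducers = {"grad_reducer_Amax": 2, "grad_reducer_Gmax": 4,
            "grad_reducer_A": 6, "grad_reducer_G": 8}
for name, base in reducers.items():
    # parameters 0-26 -> fusion id `base`; parameters 27 and later -> fusion id `base + 1`
    print(f"{name}: fusion ids {base} and {base + 1}")
```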
model_zoo/official/cv/resnet_thor/train.py

@@ -95,11 +95,7 @@ if __name__ == '__main__':
         context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
         context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
+        auto_parallel_context().set_all_reduce_fusion_split_indices([107])
         init()
     # GPU target
     else:
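With the reducers tagging their own AllReduce fusion ids, the training script only has to configure the fusion split for the default world group, which is why the per-group calls ("hccl_world_groupsum2" through "groupsum5") could be dropped. Below is a minimal sketch of the simplified setup under the assumption of an Ascend data-parallel run; the device number is a placeholder, and error handling from the real train.py is omitted.

```python
# Minimal sketch of the simplified parallel setup after this commit (placeholder values).
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init
from mindspore.parallel._auto_parallel_context import auto_parallel_context

context.set_auto_parallel_context(device_num=8, parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True)
# One split index on the default world group replaces the five per-group calls removed above.
auto_parallel_context().set_all_reduce_fusion_split_indices([107])
init()
```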
tests/st/networks/models/resnet50/src_thor/grad_reducer_thor.py

@@ -12,150 +12,109 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""grad_reducer_thor"""
+"""grad reducer cell for distributed training"""
+import mindspore.common.dtype as mstype
+from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.nn.cell import Cell
-from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp
-import mindspore.common.dtype as mstype
+from mindspore.ops.operations.comm_ops import AllReduce
 
 reduce_opt = C.MultitypeFuncGraph("reduce_opt")
 
-_all_reduce_A = AllReduce()
-
-
-def _init_optimizer_allreduce(group):
-    global _all_reduce_A
-    _all_reduce_A = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
-    _all_reduce_A.add_prim_attr('fusion', group)
+
+def _init_allreduce_operators(length, split_indices):
+    """ initialize allreduce communication operators"""
+    indices = split_indices[0]
+    fusion = split_indices[1]
+    op_list = ()
+    j = 0
+    for i in range(length):
+        if j <= len(indices) - 1:
+            temp = indices[j]
+        else:
+            temp = length
+        if i >= temp:
+            j = j + 1
+            fusion = fusion + 1
+        op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
+        op.add_prim_attr('fusion', fusion)
+        op_list = op_list + (op,)
+    return op_list
 
 
-@reduce_opt.register("Function", "Number", "Tensor")
-def _tensors_allreduce_mean(mul, degree, grad):
-    degree = F.scalar_cast(degree, F.dtype(grad))
-    grad = _all_reduce_A(grad)
-    cast_op = P.Cast()
-    return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
-
-
-@reduce_opt.register("Bool", "Tensor")
-def _tensors_allreduce(allreduce_filter, grad):
-    if allreduce_filter:
-        return _all_reduce_A(grad)
-    return grad
+@reduce_opt.register("Function", "Number", "Function", "Tensor")
+def _tensors_allreduce_mean(mul, degree, allreduce, parameters):
+    """
+    Apply allreduce on parameters.
+
+    Args:
+        mul(Primitive): The mul operator for parameters.
+        degree (int): The mean coefficient.
+        allreduce (Primitive): The communication operator for parameters.
+        parameters (Tensor): The parameters before operation.
+
+    Returns:
+        Tensor, the parameters after operation.
+    """
+    degree = F.scalar_cast(degree, F.dtype(parameters))
+    parameters = allreduce(parameters)
+    cast_op = P.Cast()
+    return mul(parameters, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(parameters)))
 
 
 _get_datatype = C.MultitypeFuncGraph("_get_datatype")
 
 
 @_get_datatype.register("Tensor")
-def _tensors_get_datatype(grad):
+def _tensors_get_datatype(parameters):
     """
-    Acquire gradient datatype.
+    Acquire parameters datatype.
 
     Args:
-        grad (Tensor): The gradient tensor before operation.
+        parameters (Tensor): The parameters before operation.
 
     Returns:
-        mstype, the datatype of gradient.
+        mstype, the datatype of parameters.
     """
-    return F.dtype(grad)
+    return F.dtype(parameters)
 
 
 _cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
 
 
 @_cast_datatype.register("TypeType", "Tensor")
-def _tensors_cast_datatype(datatype, grad):
+def _tensors_cast_datatype(datatype, parameters):
     """
-    Cast gradient to datatype.
+    Cast parameters to datatype.
 
     Args:
-        datatype (mstype): the destination datatype of gradient.
-        grad (Tensor): The gradient tensor before operation.
+        datatype (mstype): the destination datatype of parameters.
+        parameters (Tensor): The parameters before operation.
 
     Returns:
-        Tensor, the gradient tensor after operation.
+        Tensor, the parameters after operation.
     """
-    return F.cast(grad, datatype)
+    return F.cast(parameters, datatype)
 
 
 class DistributedGradReducerThor(Cell):
     """
     A distributed optimizer.
 
-    Constructs a gradient reducer Cell, which applies communication and average operations on
-    single-process gradient values.
+    Constructs a parameters reducer Cell, which applies communication and average operations on
+    single-process parameters values.
 
     Args:
-        parameters (list): the parameters to be updated.
-        group (int): the different group to allreduce.
-        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: False.
+        parameter_length (int): length of the parameters to be updated.
+        split_indices(tuple): parameter split indices.
+        mean (bool): When mean is true, the mean coefficient (degree) would apply on parameters. Default: False.
         degree (int): The mean coefficient. Usually it equals to device number. Default: None.
 
     Raises:
         ValueError: If degree is not a int or less than 0.
-
-    Examples:
-        >>> from mindspore.communication import init, get_group_size
-        >>> from mindspore.ops import composite as C
-        >>> from mindspore.ops import operations as P
-        >>> from mindspore.ops import functional as F
-        >>> from mindspore import context
-        >>> from mindspore import nn
-        >>> from mindspore import ParameterTuple
-        >>> from mindspore.context import ParallelMode
-        >>>
-        >>> device_id = int(os.environ["DEVICE_ID"])
-        >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
-        >>>                     device_id=int(device_id), enable_hccl=True)
-        >>> init()
-        >>> context.reset_auto_parallel_context()
-        >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
-        >>>
-        >>>
-        >>> class TrainingWrapper(nn.Cell):
-        >>>     def __init__(self, network, optimizer, sens=1.0):
-        >>>         super(TrainingWrapper, self).__init__(auto_prefix=False)
-        >>>         self.network = network
-        >>>         self.network.add_flags(defer_inline=True)
-        >>>         self.weights = ParameterTuple(network.trainable_params())
-        >>>         self.optimizer = optimizer
-        >>>         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        >>>         self.sens = sens
-        >>>         self.reducer_flag = False
-        >>>         self.grad_reducer = None
-        >>>         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        >>>         if self.parallel_mode in [ParallelMode.DATA_PARALLEL,
-        >>>                                   ParallelMode.HYBRID_PARALLEL]:
-        >>>             self.reducer_flag = True
-        >>>         if self.reducer_flag:
-        >>>             mean = context.get_auto_parallel_context("gradients_mean")
-        >>>             if mean.get_device_num_is_set():
-        >>>                 degree = context.get_auto_parallel_context("device_num")
-        >>>             else:
-        >>>                 degree = get_group_size()
-        >>>             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
-        >>>
-        >>>     def construct(self, *args):
-        >>>         weights = self.weights
-        >>>         loss = self.network(*args)
-        >>>         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        >>>         grads = self.grad(self.network, weights)(*args, sens)
-        >>>         if self.reducer_flag:
-        >>>             # apply grad reducer on grads
-        >>>             grads = self.grad_reducer(grads)
-        >>>         return F.depend(loss, self.optimizer(grads))
-        >>>
-        >>> network = Net()
-        >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
-        >>> train_cell = TrainingWrapper(network, optimizer)
-        >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
-        >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
-        >>> grads = train_cell(inputs, label)
     """
 
-    def __init__(self, parameters, group, mean=True, degree=None):
+    def __init__(self, parameter_length, split_indices, mean=True, degree=None):
         super(DistributedGradReducerThor, self).__init__(auto_prefix=False)
         self.hyper_map = C.HyperMap()
         self.mul = P.Mul()

@@ -166,20 +125,11 @@ class DistributedGradReducerThor(Cell):
             raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int")
         self.degree = degree
         self.mean = mean
-        self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
-        _init_optimizer_allreduce(group)
+        self.op_list = _init_allreduce_operators(parameter_length, split_indices)
 
-    def construct(self, grads):
-        # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
-        # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
-        # and cast back after the operation.
-        datatypes = self.hyper_map(F.partial(_get_datatype), grads)
-        grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
-        if self.mean:
-            new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
-        else:
-            new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
-        new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
-        return new_grad
+    def construct(self, parameters):
+        datatypes = self.hyper_map(F.partial(_get_datatype), parameters)
+        parameters = self.hyper_map(F.partial(_cast_datatype, mstype.float32), parameters)
+        new_parameters = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.op_list, parameters)
+        new_parameters = self.hyper_map(F.partial(_cast_datatype), datatypes, new_parameters)
+        return new_parameters
tests/st/networks/models/resnet50/src_thor/thor.py

@@ -89,10 +89,11 @@ class THOR(Optimizer):
                             1.0]
         mean = _get_gradients_mean()
         degree = _get_device_num()
-        self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree)
-        self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree)
-        self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree)
-        self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree)
+        parameter_length = len(self.feature_map)
+        self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree)
+        self.grad_reducer_Gmax = DistributedGradReducerThor(parameter_length, ((27,), 4), mean, degree)
+        self.grad_reducer_A = DistributedGradReducerThor(parameter_length, ((27,), 6), mean, degree)
+        self.grad_reducer_G = DistributedGradReducerThor(parameter_length, ((27,), 8), mean, degree)
         self.matrix_A_inv = ()
         self.matrix_G_inv = ()
         self.matrix_max_inv = ()
tests/st/networks/models/resnet50/test_resnet50_imagenet.py

@@ -241,11 +241,7 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
     if enable_hccl:
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True, parameter_broadcast=True)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
+        auto_parallel_context().set_all_reduce_fusion_split_indices([107])
         init()
 
     # network