magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 4e734650
Authored Sep 07, 2020 by mindspore-ci-bot; committed by Gitee on Sep 07, 2020

!5782 change allreduce fusion function

Merge pull request !5782 from wangmin0104/master

Parents: 28c42d55, b0358901
Changes: 7 changed files with 148 additions and 247 deletions (+148 −247)
Files changed:
    model_zoo/official/cv/resnet_thor/README.md                        +1   −1
    model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py         +64  −109
    model_zoo/official/cv/resnet_thor/src/thor.py                      +12  −9
    model_zoo/official/cv/resnet_thor/train.py                         +1   −5
    tests/st/networks/models/resnet50/src_thor/grad_reducer_thor.py    +64  −114
    tests/st/networks/models/resnet50/src_thor/thor.py                 +5   −4
    tests/st/networks/models/resnet50/test_resnet50_imagenet.py        +1   −5
model_zoo/official/cv/resnet_thor/README.md

@@ -217,7 +217,7 @@ Inference result will be stored in the example path, whose folder name is "eval"
 Inference result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log.
 ```
-result: {'top_5_accuracy': 0.9286771766965429, 'top_1_accuracy': 0.7613036171574904} ckpt=train_parallel/resnet-36_5004.ckpt
+result: {'top_5_accuracy': 0.9287972151088348, 'top_1_accuracy': 0.7597031049935979} ckpt=train_parallel/resnet-36_5004.ckpt
 ```
 ## Model Description
 ...
model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py

@@ -12,149 +12,109 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""grad_reducer_thor"""
-import mindspore.common.dtype as mstype
-from mindspore.communication.management import GlobalComm, get_group_size
+"""grad reducer cell for distributed training"""
 from mindspore.nn.cell import Cell
+from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp
+from mindspore.ops.operations.comm_ops import AllReduce
+import mindspore.common.dtype as mstype

 reduce_opt = C.MultitypeFuncGraph("reduce_opt")

-_all_reduce_A = AllReduce()
-
-
-def _init_optimizer_allreduce(group):
-    global _all_reduce_A
-    _all_reduce_A = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
-    _all_reduce_A.add_prim_attr('fusion', group)
-
-
-@reduce_opt.register("Function", "Number", "Tensor")
-def _tensors_allreduce_mean(mul, degree, grad):
-    degree = F.scalar_cast(degree, F.dtype(grad))
-    grad = _all_reduce_A(grad)
-    cast_op = P.Cast()
-    return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
-
-
-@reduce_opt.register("Bool", "Tensor")
-def _tensors_allreduce(allreduce_filter, grad):
-    if allreduce_filter:
-        return _all_reduce_A(grad)
-    return grad
+
+def _init_allreduce_operators(length, split_indices):
+    """ initialize allreduce communication operators"""
+    indices = split_indices[0]
+    fusion = split_indices[1]
+    op_list = ()
+    j = 0
+    for i in range(length):
+        if j <= len(indices) - 1:
+            temp = indices[j]
+        else:
+            temp = length
+        if i >= temp:
+            j = j + 1
+            fusion = fusion + 1
+        op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
+        op.add_prim_attr('fusion', fusion)
+        op_list = op_list + (op,)
+    return op_list
+
+
+@reduce_opt.register("Function", "Number", "Function", "Tensor")
+def _tensors_allreduce_mean(mul, degree, allreduce, parameters):
+    """
+    Apply allreduce on parameters.
+
+    Args:
+        mul(Primitive): The mul operator for parameters.
+        degree (int): The mean coefficient.
+        allreduce (Primitive): The communication operator for parameters.
+        parameters (Tensor): The parameters before operation.
+
+    Returns:
+        Tensor, the parameters after operation.
+    """
+    degree = F.scalar_cast(degree, F.dtype(parameters))
+    parameters = allreduce(parameters)
+    cast_op = P.Cast()
+    return mul(parameters, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(parameters)))


 _get_datatype = C.MultitypeFuncGraph("_get_datatype")


 @_get_datatype.register("Tensor")
-def _tensors_get_datatype(grad):
+def _tensors_get_datatype(parameters):
     """
-    Acquire gradient datatype.
+    Acquire parameters datatype.

     Args:
-        grad (Tensor): The gradient tensor before operation.
+        parameters (Tensor): The parameters before operation.

     Returns:
-        mstype, the datatype of gradient.
+        mstype, the datatype of parameters.
     """
-    return F.dtype(grad)
+    return F.dtype(parameters)


 _cast_datatype = C.MultitypeFuncGraph("_cast_datatype")


 @_cast_datatype.register("TypeType", "Tensor")
-def _tensors_cast_datatype(datatype, grad):
+def _tensors_cast_datatype(datatype, parameters):
     """
-    Cast gradient to datatype.
+    Cast parameters to datatype.

     Args:
-        datatype (mstype): the destination datatype of gradient.
-        grad (Tensor): The gradient tensor before operation.
+        datatype (mstype): the destination datatype of parameters.
+        parameters (Tensor): The parameters before operation.

     Returns:
-        Tensor, the gradient tensor after operation.
+        Tensor, the parameters after operation.
     """
-    return F.cast(grad, datatype)
+    return F.cast(parameters, datatype)


 class DistributedGradReducerThor(Cell):
     """
     A distributed optimizer.

-    Constructs a gradient reducer Cell, which applies communication and average operations on
-    single-process gradient values.
+    Constructs a parameters reducer Cell, which applies communication and average operations on
+    single-process parameters values.

     Args:
-        parameters (list): the parameters to be updated.
-        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: False.
+        parameter_length (int): length of the parameters to be updated.
+        split_indices(tuple): parameter split indices.
+        mean (bool): When mean is true, the mean coefficient (degree) would apply on parameters. Default: False.
         degree (int): The mean coefficient. Usually it equals to device number. Default: None.

     Raises:
         ValueError: If degree is not a int or less than 0.

     Examples:
         >>> from mindspore.communication import init, get_group_size
         >>> from mindspore.ops import composite as C
         >>> from mindspore.ops import operations as P
         >>> from mindspore.ops import functional as F
         >>> from mindspore import context
         >>> from mindspore import nn
         >>> from mindspore import ParameterTuple
         >>> from mindspore.context import ParallelMode
         >>>
         >>> device_id = int(os.environ["DEVICE_ID"])
         >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
         >>>                     device_id=int(device_id), enable_hccl=True)
         >>> init()
         >>> context.reset_auto_parallel_context()
         >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
         >>>
         >>>
         >>> class TrainingWrapper(nn.Cell):
         >>>     def __init__(self, network, optimizer, sens=1.0):
         >>>         super(TrainingWrapper, self).__init__(auto_prefix=False)
         >>>         self.network = network
         >>>         self.network.add_flags(defer_inline=True)
         >>>         self.weights = ParameterTuple(network.trainable_params())
         >>>         self.optimizer = optimizer
         >>>         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         >>>         self.sens = sens
         >>>         self.reducer_flag = False
         >>>         self.grad_reducer = None
         >>>         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
         >>>         if self.parallel_mode in [ParallelMode.DATA_PARALLEL,
         >>>                                   ParallelMode.HYBRID_PARALLEL]:
         >>>             self.reducer_flag = True
         >>>         if self.reducer_flag:
         >>>             mean = context.get_auto_parallel_context("gradients_mean")
         >>>             if mean.get_device_num_is_set():
         >>>                 degree = context.get_auto_parallel_context("device_num")
         >>>             else:
         >>>                 degree = get_group_size()
         >>>             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
         >>>
         >>>     def construct(self, *args):
         >>>         weights = self.weights
         >>>         loss = self.network(*args)
         >>>         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
         >>>         grads = self.grad(self.network, weights)(*args, sens)
         >>>         if self.reducer_flag:
         >>>             # apply grad reducer on grads
         >>>             grads = self.grad_reducer(grads)
         >>>         return F.depend(loss, self.optimizer(grads))
         >>>
         >>> network = Net()
         >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
         >>> train_cell = TrainingWrapper(network, optimizer)
         >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
         >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
         >>> grads = train_cell(inputs, label)
     """

-    def __init__(self, parameters, group, mean=True, degree=None):
+    def __init__(self, parameter_length, split_indices, mean=True, degree=None):
         super(DistributedGradReducerThor, self).__init__(auto_prefix=False)
         self.hyper_map = C.HyperMap()
         self.mul = P.Mul()
 ...

@@ -165,16 +125,11 @@ class DistributedGradReducerThor(Cell):
             raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int")
         self.degree = degree
         self.mean = mean
-        self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
-        _init_optimizer_allreduce(group)
-
-    def construct(self, grads):
-        # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
-        # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
-        # and cast back after the operation.
-        datatypes = self.hyper_map(F.partial(_get_datatype), grads)
-        grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
-        new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
-        new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
-        return new_grad
+        self.op_list = _init_allreduce_operators(parameter_length, split_indices)
+
+    def construct(self, parameters):
+        datatypes = self.hyper_map(F.partial(_get_datatype), parameters)
+        parameters = self.hyper_map(F.partial(_cast_datatype, mstype.float32), parameters)
+        new_parameters = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.op_list, parameters)
+        new_parameters = self.hyper_map(F.partial(_cast_datatype), datatypes, new_parameters)
+        return new_parameters
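The core of the change is visible in `_init_allreduce_operators` above: instead of one global `AllReduce` whose `fusion` attribute is set per named group, the reducer now builds one `AllReduce` per parameter and derives each operator's `fusion` id from `split_indices`. Below is a minimal sketch of just that index bookkeeping; the MindSpore primitive is replaced by a plain integer id so the snippet runs without MindSpore, and the function name `fusion_ids` is mine, not from the diff.

```python
# Illustrative sketch only: mirrors the loop in _init_allreduce_operators, but records
# the 'fusion' attribute each AllReduce would get instead of creating communication ops.
def fusion_ids(length, split_indices):
    """Return the fusion id assigned to each of the `length` per-parameter AllReduce ops."""
    indices, fusion = split_indices[0], split_indices[1]
    ids = []
    j = 0
    for i in range(length):
        temp = indices[j] if j <= len(indices) - 1 else length
        if i >= temp:          # crossed a split boundary: start a new fusion bucket
            j = j + 1
            fusion = fusion + 1
        ids.append(fusion)
    return ids

# Example: 6 parameters, split boundary at index 3, base fusion id 0.
# Parameters 0-2 land in bucket 0 and parameters 3-5 in bucket 1, so their
# AllReduce ops are fused into two groups.
print(fusion_ids(6, ((3,), 0)))   # [0, 0, 0, 1, 1, 1]
```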
model_zoo/official/cv/resnet_thor/src/thor.py

@@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
 from mindspore._checkparam import check_bool
 from mindspore._checkparam import Validator as validator
 from mindspore.nn.optim.optimizer import Optimizer
-from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
+from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
 from src.grad_reducer_thor import DistributedGradReducerThor

 _momentum_opt = C.MultitypeFuncGraph("momentum_opt")
 ...

@@ -85,10 +85,12 @@ class THOR_GPU(Optimizer):
         self.assign = P.Assign()
         self.mul = P.Mul()
-        mean = _get_gradients_mean()
+        mean = _get_mirror_mean()
         degree = _get_device_num()
-        self.grad_reducer_thorA = DistributedGradReducerThor(self.parameters, 0, mean, degree)
-        self.grad_reducer_thorG = DistributedGradReducerThor(self.parameters, 0, mean, degree)
+        parameter_length = len(self.feature_map)
+        self.grad_reducer_thorA = DistributedGradReducerThor(parameter_length, ((parameter_length,), 0), mean, degree)
+        self.grad_reducer_thorG = DistributedGradReducerThor(parameter_length, ((parameter_length,), 0), mean, degree)
         self.weight_decay = weight_decay
         self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
         self.update_gradient = P.UpdateThorGradient(split_dim=128)
 ...

@@ -191,12 +193,13 @@ class THOR(Optimizer):
                              1.0 / 196, 1.0 / 196, 1.0 / 196,
                              1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
                              1.0]
-        mean = _get_gradients_mean()
+        mean = _get_mirror_mean()
         degree = _get_device_num()
-        self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree)
-        self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree)
-        self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree)
-        self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree)
+        parameter_length = len(self.feature_map)
+        self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree)
+        self.grad_reducer_Gmax = DistributedGradReducerThor(parameter_length, ((27,), 4), mean, degree)
+        self.grad_reducer_A = DistributedGradReducerThor(parameter_length, ((27,), 6), mean, degree)
+        self.grad_reducer_G = DistributedGradReducerThor(parameter_length, ((27,), 8), mean, degree)
         self.matrix_A_inv = ()
         self.matrix_G_inv = ()
         self.matrix_max_inv = ()
 ...
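With the new constructor, each THOR reducer is given `parameter_length = len(self.feature_map)` plus a split tuple of the form `((27,), k)`. Read through the fusion-index logic sketched above, that puts the first 27 matrices into fusion bucket `k` and the remainder into `k + 1`, so the distinct bases 2/4/6/8 appear to keep the four reducers' AllReduce ops in non-overlapping fusion buckets. A quick self-contained check (the length 54 is a made-up value for illustration, not the real feature-map length):

```python
def fusion_ids(length, split_indices):
    # Same bookkeeping as the sketch under grad_reducer_thor.py above.
    indices, fusion = split_indices
    ids, j = [], 0
    for i in range(length):
        bound = indices[j] if j < len(indices) else length
        if i >= bound:
            j, fusion = j + 1, fusion + 1
        ids.append(fusion)
    return ids

# Hypothetical parameter_length of 54, purely for illustration.
for base in (2, 4, 6, 8):                 # grad_reducer_Amax / Gmax / A / G
    ids = fusion_ids(54, ((27,), base))
    print(base, sorted(set(ids)))         # 2 [2, 3], 4 [4, 5], 6 [6, 7], 8 [8, 9]
```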
model_zoo/official/cv/resnet_thor/train.py

@@ -95,11 +95,7 @@ if __name__ == '__main__':
         context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
         context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
+        auto_parallel_context().set_all_reduce_fusion_split_indices([107])
         init()
     # GPU target
     else:
 ...
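train.py now keeps only the split for the default communication group and drops the five `hccl_world_groupsumN` settings, since the THOR reducers carry their own per-operator fusion ids. As I read `set_all_reduce_fusion_split_indices` (my reading, not something stated in this diff), the index list marks bucket boundaries over the data-parallel gradients, so `[107]` would fuse gradients 0-106 into one AllReduce and the rest into a second. A toy illustration with a made-up gradient count:

```python
# Toy model of how a split-index list partitions gradient indices into fusion buckets.
# This reflects my reading of set_all_reduce_fusion_split_indices, not MindSpore source.
def fusion_buckets(num_grads, split_indices):
    buckets, start = [], 0
    for end in list(split_indices) + [num_grads]:
        if start < min(end, num_grads):
            buckets.append(range(start, min(end, num_grads)))
        start = end
    return buckets

# Hypothetical 160 gradients with split indices [107]:
for i, b in enumerate(fusion_buckets(160, [107])):
    print(f"bucket {i}: indices {b.start}..{b.stop - 1} ({len(b)} gradients)")
```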
tests/st/networks/models/resnet50/src_thor/grad_reducer_thor.py

@@ -12,150 +12,109 @@
(This hunk is the same rewrite as in model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py above: the module
docstring, imports, the removal of _all_reduce_A / _init_optimizer_allreduce / _tensors_allreduce, the new
_init_allreduce_operators and four-argument _tensors_allreduce_mean, the grad-to-parameters renames in
_tensors_get_datatype / _tensors_cast_datatype, the DistributedGradReducerThor docstring, and the new
__init__(self, parameter_length, split_indices, mean=True, degree=None) signature. The only difference is that this
copy's old docstring also listed "group (int): the different group to allreduce." among the Args.)

@@ -166,20 +125,11 @@ class DistributedGradReducerThor(Cell):
             raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int")
         self.degree = degree
         self.mean = mean
-        self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
-        _init_optimizer_allreduce(group)
-
-    def construct(self, grads):
-        # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
-        # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
-        # and cast back after the operation.
-        datatypes = self.hyper_map(F.partial(_get_datatype), grads)
-        grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
-        if self.mean:
-            new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
-        else:
-            new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
-        new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
-        return new_grad
+        self.op_list = _init_allreduce_operators(parameter_length, split_indices)
+
+    def construct(self, parameters):
+        datatypes = self.hyper_map(F.partial(_get_datatype), parameters)
+        parameters = self.hyper_map(F.partial(_cast_datatype, mstype.float32), parameters)
+        new_parameters = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.op_list, parameters)
+        new_parameters = self.hyper_map(F.partial(_cast_datatype), datatypes, new_parameters)
+        return new_parameters
tests/st/networks/models/resnet50/src_thor/thor.py

@@ -89,10 +89,11 @@ class THOR(Optimizer):
                              1.0]
         mean = _get_gradients_mean()
         degree = _get_device_num()
-        self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree)
-        self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree)
-        self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree)
-        self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree)
+        parameter_length = len(self.feature_map)
+        self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree)
+        self.grad_reducer_Gmax = DistributedGradReducerThor(parameter_length, ((27,), 4), mean, degree)
+        self.grad_reducer_A = DistributedGradReducerThor(parameter_length, ((27,), 6), mean, degree)
+        self.grad_reducer_G = DistributedGradReducerThor(parameter_length, ((27,), 8), mean, degree)
         self.matrix_A_inv = ()
         self.matrix_G_inv = ()
         self.matrix_max_inv = ()
 ...
tests/st/networks/models/resnet50/test_resnet50_imagenet.py

@@ -241,11 +241,7 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
     if enable_hccl:
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True, parameter_broadcast=True)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
+        auto_parallel_context().set_all_reduce_fusion_split_indices([107])
         init()
     # network
 ...