magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit ac958362
Authored on July 14, 2020 by chenhaozhe
Parent: 7371cedd

add global norm in bert
Showing 4 changed files with 69 additions and 8 deletions (+69 -8)
model_zoo/official/nlp/bert/run_pretrain.py                 +4  -2
model_zoo/official/nlp/bert/src/bert_for_pre_training.py    +13 -4
model_zoo/official/nlp/bert/src/config.py                   +2  -2
model_zoo/official/nlp/bert/src/utils.py                    +50 -0
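For orientation before the hunks: instead of clipping each gradient tensor independently, global-norm clipping treats all gradients as one vector and rescales them by a shared factor when their joint L2 norm exceeds a threshold, which preserves the relative direction of the update. A minimal NumPy sketch of the idea (illustrative only, not part of the commit):

    import numpy as np

    def clip_by_global_norm(grads, clip_norm=1.0):
        # L2 norm of the concatenation of all gradient tensors.
        global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
        # One shared scale factor; a no-op when the norm is already small.
        scale = clip_norm / max(global_norm, clip_norm)
        return [g * scale for g in grads]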
model_zoo/official/nlp/bert/run_pretrain.py

@@ -178,12 +178,14 @@ def run_pretrain():
         if args_opt.accumulation_steps <= 1:
-            net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
-                                                               scale_update_cell=update_cell)
+            net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
+                                                               scale_update_cell=update_cell,
+                                                               enable_global_norm=cfg.enable_global_norm)
         else:
             accumulation_steps = args_opt.accumulation_steps
-            net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
-                                                                       scale_update_cell=update_cell,
-                                                                       accumulation_steps=accumulation_steps)
+            net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
+                                                                       scale_update_cell=update_cell,
+                                                                       accumulation_steps=accumulation_steps,
+                                                                       enable_global_norm=cfg.enable_global_norm)
     else:
         net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
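Both loss-scale training cells now take the clipping mode from the config rather than hard-coding it. A condensed sketch of the new call shape (net_with_loss, optimizer, update_cell and cfg are built earlier in run_pretrain() and are assumed here):

    # Condensed from the hunk above; the surrounding setup is elided.
    net_with_grads = BertTrainOneStepWithLossScaleCell(
        net_with_loss,
        optimizer=optimizer,
        scale_update_cell=update_cell,
        enable_global_norm=cfg.enable_global_norm)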
model_zoo/official/nlp/bert/src/bert_for_pre_training.py

@@ -29,6 +29,7 @@ from mindspore.communication.management import get_group_size
 from mindspore import context
 from mindspore.ops import _selected_ops
 from .bert_model import BertModel
+from .utils import ClipByGlobalNorm

 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0

@@ -348,11 +349,12 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         optimizer (Optimizer): Optimizer for updating the weights.
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
-    def __init__(self, network, optimizer, scale_update_cell=None):
+    def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
         super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
+        self.enable_global_norm = enable_global_norm
         self.grad = C.GradOperation(get_by_list=True,
                                     sens_param=True)
         self.reducer_flag = False
@@ -419,6 +421,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)
             grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
-        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        if self.enable_global_norm:
+            grads = ClipByGlobalNorm()(grads)
+        else:
+            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
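At gradient-processing time the cell now dispatches between the existing per-gradient clip_grad path and the new whole-model path. A framework-free sketch of the contrast, with hypothetical stand-in helpers so it runs outside MindSpore (the per-gradient branch here is assumed to clip each tensor's own norm, which is what GRADIENT_CLIP_TYPE = 1 selects in this file):

    import numpy as np

    def clip_each(grads, clip_value=1.0):
        # Stand-in for hyper_map(partial(clip_grad, ...)): each tensor is clipped on its own.
        return [g * clip_value / max(np.sqrt(np.sum(g * g)), clip_value) for g in grads]

    def clip_global(grads, clip_norm=1.0):
        # Stand-in for ClipByGlobalNorm: one shared scale for all tensors.
        norm = np.sqrt(sum(np.sum(g * g) for g in grads))
        return [g * clip_norm / max(norm, clip_norm) for g in grads]

    def process_grads(grads, enable_global_norm=False):
        return clip_global(grads) if enable_global_norm else clip_each(grads)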
@@ -474,12 +479,13 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
         accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
             batch_size * accumulation_steps. Default: 1.
     """
-    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1):
+    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
         super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.accumulation_steps = accumulation_steps
+        self.enable_global_norm = enable_global_norm
         self.one = Tensor(np.array([1]).astype(np.int32))
         self.zero = Tensor(np.array([0]).astype(np.int32))
         self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
@@ -582,6 +588,9 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
             grads = self.grad_reducer(self.accu_grads)
             scaling = scaling_sens * self.degree * self.accumulation_steps
             grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
-        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        if self.enable_global_norm:
+            grads = ClipByGlobalNorm()(grads)
+        else:
+            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         accu_overflow = self.overflow_reducer(accu_overflow)
         F.control_depend(grads, accu_overflow)
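One detail worth spelling out in the accumulation cell: the unscaling factor folds the loss scale, the data-parallel degree, and the number of accumulation steps into a single divisor, since grad_scale (defined earlier in this file) multiplies each gradient by the reciprocal of its scale argument. A worked example with illustrative values:

    scaling_sens = 1024.0    # current loss scale
    degree = 8               # data-parallel devices
    accumulation_steps = 4   # micro-batches accumulated per update
    scaling = scaling_sens * degree * accumulation_steps
    assert scaling == 32768.0  # each accumulated gradient is divided once by this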
model_zoo/official/nlp/bert/src/config.py

@@ -24,6 +24,7 @@ cfg = edict({
     'scale_factor': 2,
     'scale_window': 1000,
     'optimizer': 'Lamb',
+    'enable_global_norm': False,
     'AdamWeightDecay': edict({
         'learning_rate': 3e-5,
         'end_learning_rate': 0.0,

@@ -115,6 +116,5 @@ if cfg.bert_network == 'large':
         input_mask_from_dataset=True,
         token_type_ids_from_dataset=True,
         dtype=mstype.float32,
-        compute_type=mstype.float16,
-        enable_fused_layernorm=True)
+        compute_type=mstype.float16)
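Turning the feature on is then a one-line config change. A sketch of the relevant slice of cfg in src/config.py (other keys elided; the commit itself defaults the flag to False):

    from easydict import edict

    cfg = edict({
        'optimizer': 'Lamb',
        'enable_global_norm': True,   # flip the default added above
        # ... remaining keys unchanged ...
    })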
model_zoo/official/nlp/bert/src/utils.py

@@ -23,12 +23,62 @@ import numpy as np
 import mindspore.nn as nn
 from mindspore import log as logger
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.train.callback import Callback
 from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR

+
+get_square_sum = C.MultitypeFuncGraph("get_square_sum")
+@get_square_sum.register("Tensor")
+def _get_square_sum(grad):
+    norm = P.ReduceSum(False)(F.square(grad), ())
+    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
+    return norm
+
+
+apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
+@apply_global_norm.register("Tensor", "Tensor", "Tensor")
+def _apply_global_norm(clip_norm, global_norm, grad):
+    grad = grad * clip_norm / global_norm
+    return grad
+
+
+class GlobalNorm(nn.Cell):
+    """
+    Calculate the global norm value of given tensors
+    """
+    def __init__(self):
+        super(GlobalNorm, self).__init__()
+        self.norm = nn.Norm()
+        self.hyper_map = C.HyperMap()
+
+    def construct(self, grads):
+        square_sum = self.hyper_map(get_square_sum, grads)
+        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
+        return global_norms
+
+
+class ClipByGlobalNorm(nn.Cell):
+    """
+    Clip grads by global norm
+    """
+    def __init__(self, clip_norm=1.0):
+        super(ClipByGlobalNorm, self).__init__()
+        self.global_norm = GlobalNorm()
+        self.clip_norm = Tensor([clip_norm], mstype.float32)
+        self.hyper_map = C.HyperMap()
+
+    def construct(self, grads):
+        global_norm = self.global_norm(grads)
+        cond = P.GreaterEqual()(global_norm, self.clip_norm)
+        global_norm = F.select(cond, global_norm, self.clip_norm)
+        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
+        return grads
+
+
 class CrossEntropyCalculation(nn.Cell):
     """
     Cross Entropy loss
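A short usage sketch for the new cell (assumes a MindSpore environment; the import path follows the file layout above). Note that GlobalNorm as written averages the squared norms over the number of tensors before taking the square root, so the value compared against clip_norm is sqrt(sum / N) rather than the plain L2 norm:

    import numpy as np
    from mindspore import Tensor

    from src.utils import ClipByGlobalNorm  # path per model_zoo/official/nlp/bert

    grads = (Tensor(np.ones((2, 3)).astype(np.float32)),
             Tensor(np.full((4,), 2.0).astype(np.float32)))
    clipped = ClipByGlobalNorm(clip_norm=1.0)(grads)  # tuple of rescaled tensors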