Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
40fc11e9
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
40fc11e9
编写于
8月 25, 2020
作者:
S
shibeiji
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Mimic higher batch size by accumulating gradients N times before weight update
上级
5b722a10
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
195 addition
and
7 deletion
+195
-7
mindspore/ops/functional.py
mindspore/ops/functional.py
+1
-0
model_zoo/official/nlp/bert/README.md
model_zoo/official/nlp/bert/README.md
+2
-0
model_zoo/official/nlp/bert/run_pretrain.py
model_zoo/official/nlp/bert/run_pretrain.py
+23
-3
model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini
...ts/ascend_distributed_launcher/hyper_parameter_config.ini
+2
-1
model_zoo/official/nlp/bert/scripts/run_standalone_pretrain_ascend.sh
...ficial/nlp/bert/scripts/run_standalone_pretrain_ascend.sh
+1
-0
model_zoo/official/nlp/bert/src/__init__.py
model_zoo/official/nlp/bert/src/__init__.py
+4
-2
model_zoo/official/nlp/bert/src/bert_for_pre_training.py
model_zoo/official/nlp/bert/src/bert_for_pre_training.py
+161
-0
model_zoo/official/nlp/bert/src/config.py
model_zoo/official/nlp/bert/src/config.py
+1
-1
未找到文件。
mindspore/ops/functional.py
浏览文件 @
40fc11e9
...
@@ -63,6 +63,7 @@ check_bprop = P.CheckBprop()
...
@@ -63,6 +63,7 @@ check_bprop = P.CheckBprop()
equal
=
P
.
Equal
()
equal
=
P
.
Equal
()
not_equal
=
P
.
NotEqual
()
not_equal
=
P
.
NotEqual
()
assign_sub
=
P
.
AssignSub
()
assign_sub
=
P
.
AssignSub
()
assign_add
=
P
.
AssignAdd
()
assign
=
P
.
Assign
()
assign
=
P
.
Assign
()
square
=
P
.
Square
()
square
=
P
.
Square
()
sqrt
=
P
.
Sqrt
()
sqrt
=
P
.
Sqrt
()
...
...
model_zoo/official/nlp/bert/README.md
浏览文件 @
40fc11e9
...
@@ -123,6 +123,7 @@ usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_n
...
@@ -123,6 +123,7 @@ usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_n
[--enable_save_ckpt ENABLE_SAVE_CKPT] [--device_target DEVICE_TARGET]
[--enable_save_ckpt ENABLE_SAVE_CKPT] [--device_target DEVICE_TARGET]
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N]
[--accumulation_steps N]
[--save_checkpoint_path SAVE_CHECKPOINT_PATH]
[--save_checkpoint_path SAVE_CHECKPOINT_PATH]
[--load_checkpoint_path LOAD_CHECKPOINT_PATH]
[--load_checkpoint_path LOAD_CHECKPOINT_PATH]
[--save_checkpoint_steps N] [--save_checkpoint_num N]
[--save_checkpoint_steps N] [--save_checkpoint_num N]
...
@@ -139,6 +140,7 @@ options:
...
@@ -139,6 +140,7 @@ options:
--do_shuffle enable shuffle: "true" | "false", default is "true"
--do_shuffle enable shuffle: "true" | "false", default is "true"
--enable_data_sink enable data sink: "true" | "false", default is "true"
--enable_data_sink enable data sink: "true" | "false", default is "true"
--data_sink_steps set data sink steps: N, default is 1
--data_sink_steps set data sink steps: N, default is 1
--accumulation_steps accumulate gradients N times before weight update: N, default is 1
--save_checkpoint_path path to save checkpoint files: PATH, default is ""
--save_checkpoint_path path to save checkpoint files: PATH, default is ""
--load_checkpoint_path path to load checkpoint files: PATH, default is ""
--load_checkpoint_path path to load checkpoint files: PATH, default is ""
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
...
...
model_zoo/official/nlp/bert/run_pretrain.py
浏览文件 @
40fc11e9
...
@@ -30,7 +30,8 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMoni
...
@@ -30,7 +30,8 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMoni
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
from
mindspore.nn.optim
import
Lamb
,
Momentum
,
AdamWeightDecay
from
mindspore.nn.optim
import
Lamb
,
Momentum
,
AdamWeightDecay
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
from
src
import
BertNetworkWithLoss
,
BertTrainOneStepCell
,
BertTrainOneStepWithLossScaleCell
from
src
import
BertNetworkWithLoss
,
BertTrainOneStepCell
,
BertTrainOneStepWithLossScaleCell
,
\
BertTrainAccumulateStepsWithLossScaleCell
from
src.dataset
import
create_bert_dataset
from
src.dataset
import
create_bert_dataset
from
src.config
import
cfg
,
bert_net_cfg
from
src.config
import
cfg
,
bert_net_cfg
from
src.utils
import
LossCallBack
,
BertLearningRate
from
src.utils
import
LossCallBack
,
BertLearningRate
...
@@ -51,6 +52,8 @@ def run_pretrain():
...
@@ -51,6 +52,8 @@ def run_pretrain():
parser
.
add_argument
(
"--do_shuffle"
,
type
=
str
,
default
=
"true"
,
help
=
"Enable shuffle for dataset, default is true."
)
parser
.
add_argument
(
"--do_shuffle"
,
type
=
str
,
default
=
"true"
,
help
=
"Enable shuffle for dataset, default is true."
)
parser
.
add_argument
(
"--enable_data_sink"
,
type
=
str
,
default
=
"true"
,
help
=
"Enable data sink, default is true."
)
parser
.
add_argument
(
"--enable_data_sink"
,
type
=
str
,
default
=
"true"
,
help
=
"Enable data sink, default is true."
)
parser
.
add_argument
(
"--data_sink_steps"
,
type
=
int
,
default
=
"1"
,
help
=
"Sink steps for each epoch, default is 1."
)
parser
.
add_argument
(
"--data_sink_steps"
,
type
=
int
,
default
=
"1"
,
help
=
"Sink steps for each epoch, default is 1."
)
parser
.
add_argument
(
"--accumulation_steps"
,
type
=
int
,
default
=
"1"
,
help
=
"Accumulating gradients N times before weight update, default is 1."
)
parser
.
add_argument
(
"--save_checkpoint_path"
,
type
=
str
,
default
=
""
,
help
=
"Save checkpoint path"
)
parser
.
add_argument
(
"--save_checkpoint_path"
,
type
=
str
,
default
=
""
,
help
=
"Save checkpoint path"
)
parser
.
add_argument
(
"--load_checkpoint_path"
,
type
=
str
,
default
=
""
,
help
=
"Load checkpoint file path"
)
parser
.
add_argument
(
"--load_checkpoint_path"
,
type
=
str
,
default
=
""
,
help
=
"Load checkpoint file path"
)
parser
.
add_argument
(
"--save_checkpoint_steps"
,
type
=
int
,
default
=
1000
,
help
=
"Save checkpoint steps, "
parser
.
add_argument
(
"--save_checkpoint_steps"
,
type
=
int
,
default
=
1000
,
help
=
"Save checkpoint steps, "
...
@@ -98,6 +101,16 @@ def run_pretrain():
...
@@ -98,6 +101,16 @@ def run_pretrain():
logger
.
warning
(
'Gpu only support fp32 temporarily, run with fp32.'
)
logger
.
warning
(
'Gpu only support fp32 temporarily, run with fp32.'
)
bert_net_cfg
.
compute_type
=
mstype
.
float32
bert_net_cfg
.
compute_type
=
mstype
.
float32
if
args_opt
.
accumulation_steps
>
1
:
logger
.
info
(
"accumulation steps: {}"
.
format
(
args_opt
.
accumulation_steps
))
logger
.
info
(
"global batch size: {}"
.
format
(
bert_net_cfg
.
batch_size
*
args_opt
.
accumulation_steps
))
if
args_opt
.
enable_data_sink
==
"true"
:
args_opt
.
data_sink_steps
*=
args_opt
.
accumulation_steps
logger
.
info
(
"data sink steps: {}"
.
format
(
args_opt
.
data_sink_steps
))
if
args_opt
.
enable_save_ckpt
==
"true"
:
args_opt
.
save_checkpoint_steps
*=
args_opt
.
accumulation_steps
logger
.
info
(
"save checkpoint steps: {}"
.
format
(
args_opt
.
save_checkpoint_steps
))
ds
=
create_bert_dataset
(
device_num
,
rank
,
args_opt
.
do_shuffle
,
args_opt
.
data_dir
,
args_opt
.
schema_dir
)
ds
=
create_bert_dataset
(
device_num
,
rank
,
args_opt
.
do_shuffle
,
args_opt
.
data_dir
,
args_opt
.
schema_dir
)
net_with_loss
=
BertNetworkWithLoss
(
bert_net_cfg
,
True
)
net_with_loss
=
BertNetworkWithLoss
(
bert_net_cfg
,
True
)
...
@@ -157,8 +170,15 @@ def run_pretrain():
...
@@ -157,8 +170,15 @@ def run_pretrain():
update_cell
=
DynamicLossScaleUpdateCell
(
loss_scale_value
=
cfg
.
loss_scale_value
,
update_cell
=
DynamicLossScaleUpdateCell
(
loss_scale_value
=
cfg
.
loss_scale_value
,
scale_factor
=
cfg
.
scale_factor
,
scale_factor
=
cfg
.
scale_factor
,
scale_window
=
cfg
.
scale_window
)
scale_window
=
cfg
.
scale_window
)
net_with_grads
=
BertTrainOneStepWithLossScaleCell
(
net_with_loss
,
optimizer
=
optimizer
,
scale_update_cell
=
update_cell
)
if
args_opt
.
accumulation_steps
<=
1
:
net_with_grads
=
BertTrainOneStepWithLossScaleCell
(
net_with_loss
,
optimizer
=
optimizer
,
scale_update_cell
=
update_cell
)
else
:
accumulation_steps
=
args_opt
.
accumulation_steps
net_with_grads
=
BertTrainAccumulateStepsWithLossScaleCell
(
net_with_loss
,
optimizer
=
optimizer
,
scale_update_cell
=
update_cell
,
accumulation_steps
=
accumulation_steps
)
else
:
else
:
net_with_grads
=
BertTrainOneStepCell
(
net_with_loss
,
optimizer
=
optimizer
)
net_with_grads
=
BertTrainOneStepCell
(
net_with_loss
,
optimizer
=
optimizer
)
...
...
model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini
浏览文件 @
40fc11e9
...
@@ -6,6 +6,7 @@ enable_lossscale=true
...
@@ -6,6 +6,7 @@ enable_lossscale=true
do_shuffle
=
true
do_shuffle
=
true
enable_data_sink
=
true
enable_data_sink
=
true
data_sink_steps
=
100
data_sink_steps
=
100
accumulation_steps
=
1
save_checkpoint_path
=
./checkpoint/
save_checkpoint_path
=
./checkpoint/
save_checkpoint_steps
=
10000
save_checkpoint_steps
=
10000
save_checkpoint_num
=
1
save_checkpoint_num
=
1
\ No newline at end of file
model_zoo/official/nlp/bert/scripts/run_standalone_pretrain_ascend.sh
浏览文件 @
40fc11e9
...
@@ -39,6 +39,7 @@ python ${PROJECT_DIR}/../run_pretrain.py \
...
@@ -39,6 +39,7 @@ python ${PROJECT_DIR}/../run_pretrain.py \
--do_shuffle
=
"true"
\
--do_shuffle
=
"true"
\
--enable_data_sink
=
"true"
\
--enable_data_sink
=
"true"
\
--data_sink_steps
=
1
\
--data_sink_steps
=
1
\
--accumulation_steps
=
1
\
--load_checkpoint_path
=
""
\
--load_checkpoint_path
=
""
\
--save_checkpoint_steps
=
10000
\
--save_checkpoint_steps
=
10000
\
--save_checkpoint_num
=
1
\
--save_checkpoint_num
=
1
\
...
...
model_zoo/official/nlp/bert/src/__init__.py
浏览文件 @
40fc11e9
...
@@ -15,7 +15,8 @@
...
@@ -15,7 +15,8 @@
"""Bert Init."""
"""Bert Init."""
from
.bert_for_pre_training
import
BertNetworkWithLoss
,
BertPreTraining
,
\
from
.bert_for_pre_training
import
BertNetworkWithLoss
,
BertPreTraining
,
\
BertPretrainingLoss
,
GetMaskedLMOutput
,
GetNextSentenceOutput
,
\
BertPretrainingLoss
,
GetMaskedLMOutput
,
GetNextSentenceOutput
,
\
BertTrainOneStepCell
,
BertTrainOneStepWithLossScaleCell
BertTrainOneStepCell
,
BertTrainOneStepWithLossScaleCell
,
\
BertTrainAccumulateStepsWithLossScaleCell
from
.bert_model
import
BertAttention
,
BertConfig
,
BertEncoderCell
,
BertModel
,
\
from
.bert_model
import
BertAttention
,
BertConfig
,
BertEncoderCell
,
BertModel
,
\
BertOutput
,
BertSelfAttention
,
BertTransformer
,
EmbeddingLookup
,
\
BertOutput
,
BertSelfAttention
,
BertTransformer
,
EmbeddingLookup
,
\
EmbeddingPostprocessor
,
RelaPosEmbeddingsGenerator
,
RelaPosMatrixGenerator
,
\
EmbeddingPostprocessor
,
RelaPosEmbeddingsGenerator
,
RelaPosMatrixGenerator
,
\
...
@@ -23,7 +24,8 @@ from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
...
@@ -23,7 +24,8 @@ from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
__all__
=
[
__all__
=
[
"BertNetworkWithLoss"
,
"BertPreTraining"
,
"BertPretrainingLoss"
,
"BertNetworkWithLoss"
,
"BertPreTraining"
,
"BertPretrainingLoss"
,
"GetMaskedLMOutput"
,
"GetNextSentenceOutput"
,
"BertTrainOneStepCell"
,
"BertTrainOneStepWithLossScaleCell"
,
"GetMaskedLMOutput"
,
"GetNextSentenceOutput"
,
"BertTrainOneStepCell"
,
"BertTrainOneStepWithLossScaleCell"
,
"BertTrainAccumulateStepsWithLossScaleCell"
,
"BertAttention"
,
"BertConfig"
,
"BertEncoderCell"
,
"BertModel"
,
"BertOutput"
,
"BertAttention"
,
"BertConfig"
,
"BertEncoderCell"
,
"BertModel"
,
"BertOutput"
,
"BertSelfAttention"
,
"BertTransformer"
,
"EmbeddingLookup"
,
"BertSelfAttention"
,
"BertTransformer"
,
"EmbeddingLookup"
,
"EmbeddingPostprocessor"
,
"RelaPosEmbeddingsGenerator"
,
"EmbeddingPostprocessor"
,
"RelaPosEmbeddingsGenerator"
,
...
...
model_zoo/official/nlp/bert/src/bert_for_pre_training.py
浏览文件 @
40fc11e9
...
@@ -438,3 +438,164 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
...
@@ -438,3 +438,164 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
succ
=
self
.
optimizer
(
grads
)
succ
=
self
.
optimizer
(
grads
)
ret
=
(
loss
,
cond
,
scaling_sens
)
ret
=
(
loss
,
cond
,
scaling_sens
)
return
F
.
depend
(
ret
,
succ
)
return
F
.
depend
(
ret
,
succ
)
cast
=
P
.
Cast
()
update_accu_grads
=
C
.
MultitypeFuncGraph
(
"update_accu_grads"
)
@
update_accu_grads
.
register
(
"Tensor"
,
"Tensor"
)
def
_update_accu_grads
(
accu_grad
,
grad
):
succ
=
True
return
F
.
depend
(
succ
,
F
.
assign_add
(
accu_grad
,
cast
(
grad
,
mstype
.
float32
)))
zeroslike
=
P
.
ZerosLike
()
reset_accu_grads
=
C
.
MultitypeFuncGraph
(
"reset_accu_grads"
)
@
reset_accu_grads
.
register
(
"Tensor"
)
def
_reset_accu_grads
(
accu_grad
):
succ
=
True
return
F
.
depend
(
succ
,
F
.
assign
(
accu_grad
,
zeroslike
(
accu_grad
)))
class
BertTrainAccumulateStepsWithLossScaleCell
(
nn
.
Cell
):
"""
Encapsulation class of bert network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph. To mimic higher batch size, gradients are
accumulated N times before weight update.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
batch_size * accumulation_steps. Default: 1.
"""
def
__init__
(
self
,
network
,
optimizer
,
scale_update_cell
=
None
,
accumulation_steps
=
1
):
super
(
BertTrainAccumulateStepsWithLossScaleCell
,
self
).
__init__
(
auto_prefix
=
False
)
self
.
network
=
network
self
.
weights
=
optimizer
.
parameters
self
.
optimizer
=
optimizer
self
.
accumulation_steps
=
accumulation_steps
self
.
one
=
Tensor
(
np
.
array
([
1
]).
astype
(
np
.
int32
))
self
.
zero
=
Tensor
(
np
.
array
([
0
]).
astype
(
np
.
int32
))
self
.
local_step
=
Parameter
(
initializer
(
0
,
[
1
],
mstype
.
int32
),
name
=
"local_step"
)
self
.
accu_grads
=
self
.
weights
.
clone
(
prefix
=
"accu_grads"
,
init
=
'zeros'
)
self
.
accu_overflow
=
Parameter
(
initializer
(
0
,
[
1
],
mstype
.
int32
),
name
=
"accu_overflow"
)
self
.
loss
=
Parameter
(
initializer
(
0
,
[
1
],
mstype
.
float32
),
name
=
"accu_loss"
)
self
.
grad
=
C
.
GradOperation
(
'grad'
,
get_by_list
=
True
,
sens_param
=
True
)
self
.
reducer_flag
=
False
self
.
parallel_mode
=
context
.
get_auto_parallel_context
(
"parallel_mode"
)
if
self
.
parallel_mode
in
[
ParallelMode
.
DATA_PARALLEL
,
ParallelMode
.
HYBRID_PARALLEL
]:
self
.
reducer_flag
=
True
self
.
grad_reducer
=
F
.
identity
self
.
degree
=
1
if
self
.
reducer_flag
:
self
.
degree
=
get_group_size
()
self
.
grad_reducer
=
DistributedGradReducer
(
optimizer
.
parameters
,
False
,
self
.
degree
)
self
.
is_distributed
=
(
self
.
parallel_mode
!=
ParallelMode
.
STAND_ALONE
)
self
.
overflow_reducer
=
F
.
identity
if
self
.
is_distributed
:
self
.
overflow_reducer
=
P
.
AllReduce
()
self
.
cast
=
P
.
Cast
()
self
.
alloc_status
=
P
.
NPUAllocFloatStatus
()
self
.
get_status
=
P
.
NPUGetFloatStatus
()
self
.
clear_before_grad
=
P
.
NPUClearFloatStatus
()
self
.
reduce_sum
=
P
.
ReduceSum
(
keep_dims
=
False
)
self
.
base
=
Tensor
(
1
,
mstype
.
float32
)
self
.
less_equal
=
P
.
LessEqual
()
self
.
logical_or
=
P
.
LogicalOr
()
self
.
not_equal
=
P
.
NotEqual
()
self
.
select
=
P
.
Select
()
self
.
reshape
=
P
.
Reshape
()
self
.
hyper_map
=
C
.
HyperMap
()
self
.
loss_scale
=
None
self
.
loss_scaling_manager
=
scale_update_cell
if
scale_update_cell
:
self
.
loss_scale
=
Parameter
(
Tensor
(
scale_update_cell
.
get_loss_scale
(),
dtype
=
mstype
.
float32
),
name
=
"loss_scale"
)
@
C
.
add_flags
(
has_effect
=
True
)
def
construct
(
self
,
input_ids
,
input_mask
,
token_type_id
,
next_sentence_labels
,
masked_lm_positions
,
masked_lm_ids
,
masked_lm_weights
,
sens
=
None
):
"""Defines the computation performed."""
weights
=
self
.
weights
loss
=
self
.
network
(
input_ids
,
input_mask
,
token_type_id
,
next_sentence_labels
,
masked_lm_positions
,
masked_lm_ids
,
masked_lm_weights
)
if
sens
is
None
:
scaling_sens
=
self
.
loss_scale
else
:
scaling_sens
=
sens
# update accumulation parameters
is_accu_step
=
self
.
not_equal
(
self
.
local_step
,
self
.
accumulation_steps
)
self
.
local_step
=
self
.
select
(
is_accu_step
,
self
.
local_step
+
self
.
one
,
self
.
one
)
self
.
loss
=
self
.
select
(
is_accu_step
,
self
.
loss
+
loss
,
loss
)
mean_loss
=
self
.
loss
/
self
.
local_step
is_accu_step
=
self
.
not_equal
(
self
.
local_step
,
self
.
accumulation_steps
)
# alloc status and clear should be right before gradoperation
init
=
self
.
alloc_status
()
self
.
clear_before_grad
(
init
)
grads
=
self
.
grad
(
self
.
network
,
weights
)(
input_ids
,
input_mask
,
token_type_id
,
next_sentence_labels
,
masked_lm_positions
,
masked_lm_ids
,
masked_lm_weights
,
self
.
cast
(
scaling_sens
,
mstype
.
float32
))
accu_succ
=
self
.
hyper_map
(
update_accu_grads
,
self
.
accu_grads
,
grads
)
mean_loss
=
F
.
depend
(
mean_loss
,
accu_succ
)
self
.
get_status
(
init
)
flag_sum
=
self
.
reduce_sum
(
init
,
(
0
,))
overflow
=
self
.
less_equal
(
self
.
base
,
flag_sum
)
overflow
=
self
.
logical_or
(
self
.
not_equal
(
self
.
accu_overflow
,
self
.
zero
),
overflow
)
accu_overflow
=
self
.
select
(
overflow
,
self
.
one
,
self
.
zero
)
self
.
accu_overflow
=
self
.
select
(
is_accu_step
,
accu_overflow
,
self
.
zero
)
if
is_accu_step
:
succ
=
False
else
:
# apply grad reducer on grads
grads
=
self
.
grad_reducer
(
self
.
accu_grads
)
scaling
=
scaling_sens
*
self
.
degree
*
self
.
accumulation_steps
grads
=
self
.
hyper_map
(
F
.
partial
(
grad_scale
,
scaling
),
grads
)
grads
=
self
.
hyper_map
(
F
.
partial
(
clip_grad
,
GRADIENT_CLIP_TYPE
,
GRADIENT_CLIP_VALUE
),
grads
)
accu_overflow
=
self
.
overflow_reducer
(
accu_overflow
)
F
.
control_depend
(
grads
,
accu_overflow
)
overflow
=
self
.
less_equal
(
self
.
base
,
accu_overflow
)
accu_succ
=
self
.
hyper_map
(
reset_accu_grads
,
self
.
accu_grads
)
overflow
=
F
.
depend
(
overflow
,
accu_succ
)
overflow
=
self
.
reshape
(
overflow
,
(()))
if
sens
is
None
:
overflow
=
self
.
loss_scaling_manager
(
self
.
loss_scale
,
overflow
)
if
overflow
:
succ
=
False
else
:
succ
=
self
.
optimizer
(
grads
)
ret
=
(
mean_loss
,
overflow
,
scaling_sens
)
return
F
.
depend
(
ret
,
succ
)
model_zoo/official/nlp/bert/src/config.py
浏览文件 @
40fc11e9
...
@@ -50,7 +50,7 @@ cfg = edict({
...
@@ -50,7 +50,7 @@ cfg = edict({
'''
'''
Including two kinds of network:
\
Including two kinds of network:
\
base: Goole BERT-base(the base version of BERT model).
base: Goo
g
le BERT-base(the base version of BERT model).
large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of
\
large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of
\
Functional Relative Posetional Encoding as an effective positional encoding scheme).
Functional Relative Posetional Encoding as an effective positional encoding scheme).
'''
'''
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录