MindSpore / book
Commit e15cfa24
Authored Jul 22, 2020 by panfengfeng

    fix getdatasetsize changes

Parent: 0f34238a
Showing 9 changed files with 90 additions and 60 deletions (+90 -60)
  chapter03/lenet/main.py            +1   -1
  chapter04/alexnet/main.py          +1   -1
  chapter05/resnet/resnet_cifar.py   +1   -1
  chapter06/lstm/train.py            +1   -1
  chapter07/run_pretrain.py          +45  -29
  chapter07/src/bert_model.py        +8   -12
  chapter07/src/config.py            +4   -2
  chapter07/src/dataset.py           +5   -13
  chapter07/src/utils.py             +24  -0
chapter03/lenet/main.py

@@ -92,7 +92,7 @@ if __name__ == "__main__":
     network = LeNet5(cfg.num_classes)
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
-    repeat_size = cfg.epoch_size
+    repeat_size = 1
     net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
     config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
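Note: the four one-line edits in chapters 3-6 (this file and the next three) share one rationale: after the get_dataset_size changes, the dataset is built for a single pass instead of being pre-repeated per epoch, and the epoch loop belongs to Model.train. A quick arithmetic sketch of why total work is unchanged (pure Python; the numbers are hypothetical, not from this commit):

epoch_size = 10
steps_per_epoch = 1875   # e.g. 60000 MNIST samples / batch size 32

# before: dataset pre-repeated epoch_size times, trained as "one epoch"
total_before = 1 * (steps_per_epoch * epoch_size)

# after: repeat_size = 1, and Model.train runs epoch_size epochs
total_after = epoch_size * steps_per_epoch

assert total_before == total_after   # 18750 steps either way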
chapter04/alexnet/main.py

@@ -84,7 +84,7 @@ if __name__ == "__main__":
     network = AlexNet(cfg.num_classes)
     loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
-    repeat_size = cfg.epoch_size
+    repeat_size = 1
     # when batch_size=32, steps is 1562
     lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, 1562))
     opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
chapter05/resnet/resnet_cifar.py

@@ -125,7 +125,7 @@ if __name__ == '__main__':
     model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})

     if args_opt.do_train:
-        dataset = create_dataset(epoch_size)
+        dataset = create_dataset()
         batch_num = dataset.get_dataset_size()
         config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=10)
         ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck)
chapter06/lstm/train.py

@@ -77,7 +77,7 @@ if __name__ == '__main__':
     model = Model(network, loss, opt, {'acc': Accuracy()})

     print("============== Starting Training ==============")
-    ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs)
+    ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, 1)
     config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
chapter07/run_pretrain.py

@@ -28,12 +28,12 @@ from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
+from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
 from mindspore import log as logger
 from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
 from src.dataset import create_bert_dataset
 from src.config import cfg, bert_net_cfg
-from src.utils import LossCallBack
+from src.utils import LossCallBack, BertLearningRate

 _current_dir = os.path.dirname(os.path.realpath(__file__))

@@ -64,7 +64,6 @@ def run_pretrain():
     args_opt = parser.parse_args()
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                         device_id=args_opt.device_id)
     context.set_context(reserve_class_name_in_scope=False)
-    context.set_context(variable_memory_max_size="30GB")
     ckpt_save_dir = args_opt.save_checkpoint_path
     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':

@@ -99,35 +98,49 @@ def run_pretrain():
             logger.warning('Gpu only support fp32 temporarily, run with fp32.')
             bert_net_cfg.compute_type = mstype.float32

-    ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank,
-                                               args_opt.do_shuffle, args_opt.enable_data_sink,
-                                               args_opt.data_sink_steps, args_opt.data_dir,
-                                               args_opt.schema_dir)
-    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
+    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir,
+                             args_opt.schema_dir)
+    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
+
+    new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
+    if args_opt.train_steps > 0:
+        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
+    else:
+        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()

     if cfg.optimizer == 'Lamb':
-        optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size() * new_repeat_count,
-                         start_learning_rate=cfg.Lamb.start_learning_rate,
-                         end_learning_rate=cfg.Lamb.end_learning_rate, power=cfg.Lamb.power,
-                         warmup_steps=cfg.Lamb.warmup_steps, weight_decay=cfg.Lamb.weight_decay,
-                         eps=cfg.Lamb.eps)
+        lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
+                                       end_learning_rate=cfg.Lamb.end_learning_rate,
+                                       warmup_steps=cfg.Lamb.warmup_steps,
+                                       decay_steps=args_opt.train_steps,
+                                       power=cfg.Lamb.power)
+        params = net_with_loss.trainable_params()
+        decay_params = list(filter(cfg.Lamb.decay_filter, params))
+        other_params = list(filter(lambda x: x not in decay_params, params))
+        group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
+                        {'params': other_params},
+                        {'order_params': params}]
+        optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
     elif cfg.optimizer == 'Momentum':
-        optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
+        optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                              momentum=cfg.Momentum.momentum)
-    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
-        optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
-                                             decay_steps=ds.get_dataset_size() * new_repeat_count,
-                                             learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
-                                             end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
-                                             power=cfg.AdamWeightDecayDynamicLR.power,
-                                             weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
-                                             eps=cfg.AdamWeightDecayDynamicLR.eps,
-                                             warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps)
+    elif cfg.optimizer == 'AdamWeightDecay':
+        lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
+                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
+                                       warmup_steps=cfg.AdamWeightDecay.warmup_steps,
+                                       decay_steps=args_opt.train_steps,
+                                       power=cfg.AdamWeightDecay.power)
+        params = net_with_loss.trainable_params()
+        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
+        other_params = list(filter(lambda x: x not in decay_params, params))
+        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
+                        {'params': other_params, 'weight_decay': 0.0},
+                        {'order_params': params}]
+        optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule,
+                                    eps=cfg.AdamWeightDecay.eps)
     else:
-        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
+        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
                          format(cfg.optimizer))
-    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
+    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
     if args_opt.enable_save_ckpt == "true":
         config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                      keep_checkpoint_max=args_opt.save_checkpoint_num)

@@ -136,19 +149,22 @@ def run_pretrain():
     if args_opt.load_checkpoint_path:
         param_dict = load_checkpoint(args_opt.load_checkpoint_path)
-        load_param_into_net(netwithloss, param_dict)
+        load_param_into_net(net_with_loss, param_dict)

     if args_opt.enable_lossscale == "true":
         update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                  scale_factor=cfg.scale_factor,
                                                  scale_window=cfg.scale_window)
-        netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
-                                                         scale_update_cell=update_cell)
+        net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
+                                                           scale_update_cell=update_cell)
     else:
-        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
+        net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)

-    model = Model(netwithgrads)
-    model.train(new_repeat_count, ds, callbacks=callback,
-                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
+    model = Model(net_with_grads)
+    model.train(new_repeat_count, ds, callbacks=callback,
+                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
+                sink_size=args_opt.data_sink_steps)


 if __name__ == '__main__':
     numpy.random.seed(0)
     run_pretrain()
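Note: the repeat-count arithmetic above pairs with the new sink_size argument to model.train. A minimal sketch of the calculation (pure Python; the concrete numbers are hypothetical, not from this commit):

epoch_size = 40          # requested training epochs
dataset_size = 10000     # ds.get_dataset_size(): batches per epoch
data_sink_steps = 100    # steps executed per data sink

new_repeat_count = epoch_size * dataset_size // data_sink_steps   # 4000 sinks
train_steps = 30000                                               # optional cap
if train_steps > 0:
    new_repeat_count = min(new_repeat_count, train_steps // data_sink_steps)  # 300

# model.train(new_repeat_count, ds, ..., sink_size=data_sink_steps) then runs
# new_repeat_count sinks of data_sink_steps steps each.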
chapter07/src/bert_model.py

@@ -277,8 +277,8 @@ class RelaPosMatrixGenerator(nn.Cell):
     def __init__(self, length, max_relative_position):
         super(RelaPosMatrixGenerator, self).__init__()
         self._length = length
-        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
-        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
+        self._max_relative_position = max_relative_position
+        self._min_relative_position = -max_relative_position
         self.range_length = -length + 1

         self.tile = P.Tile()

@@ -336,9 +336,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                 max_relative_position=max_relative_position)
         self.reshape = P.Reshape()
-        self.one_hot = P.OneHot()
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
+        self.one_hot = nn.OneHot(depth=self.vocab_size)
         self.shape = P.Shape()
         self.gather = P.GatherV2()  # index_select
         self.matmul = P.BatchMatMul()

@@ -350,7 +348,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         if self.use_one_hot_embeddings:
             flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
             one_hot_relative_positions_matrix = self.one_hot(
-                flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
+                flat_relative_positions_matrix)
             embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
             my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
             embeddings = self.reshape(embeddings, my_shape)

@@ -372,11 +370,9 @@ class SaturateCast(nn.Cell):
     def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
         super(SaturateCast, self).__init__()
         np_type = mstype.dtype_to_nptype(dst_type)
-        min_type = np.finfo(np_type).min
-        max_type = np.finfo(np_type).max

-        self.tensor_min_type = Tensor([min_type], dtype=src_type)
-        self.tensor_max_type = Tensor([max_type], dtype=src_type)
+        self.tensor_min_type = float(np.finfo(np_type).min)
+        self.tensor_max_type = float(np.finfo(np_type).max)

         self.min_op = P.Minimum()
         self.max_op = P.Maximum()

@@ -442,7 +438,7 @@ class BertAttention(nn.Cell):
         self.has_attention_mask = has_attention_mask
         self.use_relative_positions = use_relative_positions

-        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
+        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
         self.reshape = P.Reshape()
         self.shape_from_2d = (-1, from_tensor_width)
         self.shape_to_2d = (-1, to_tensor_width)

@@ -471,7 +467,7 @@ class BertAttention(nn.Cell):
         self.trans_shape = (0, 2, 1, 3)
         self.trans_shape_relative = (2, 0, 1, 3)
         self.trans_shape_position = (1, 2, 0, 3)
-        self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
+        self.multiply_data = -10000.0
         self.batch_num = batch_size * num_attention_heads
         self.matmul = P.BatchMatMul()
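Note on the bert_model.py changes: scalars that previously had to be wrapped in Tensor (scores_mul, multiply_data, the relative-position bounds) become plain Python numbers, and the raw P.OneHot primitive is replaced by the configured nn.OneHot layer. A small sketch of the one-hot equivalence, assuming a MindSpore version matching this commit:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

indices = Tensor(np.array([0, 2, 1]), mstype.int32)

# before: raw primitive with explicit depth/on/off arguments at every call site
out_old = P.OneHot()(indices, 3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32))

# after: depth (and default on/off values) configured once on the layer
out_new = nn.OneHot(depth=3)(indices)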
chapter07/src/config.py

@@ -24,20 +24,22 @@ cfg = edict({
     'scale_factor': 2,
     'scale_window': 1000,
     'optimizer': 'Lamb',
-    'AdamWeightDecayDynamicLR': edict({
+    'AdamWeightDecay': edict({
         'learning_rate': 3e-5,
         'end_learning_rate': 1e-10,
         'power': 5.0,
         'weight_decay': 1e-5,
+        'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
         'eps': 1e-6,
         'warmup_steps': 10000,
     }),
     'Lamb': edict({
-        'start_learning_rate': 3e-5,
+        'learning_rate': 3e-5,
         'end_learning_rate': 1e-10,
         'power': 10.0,
         'warmup_steps': 10000,
         'weight_decay': 0.01,
+        'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
         'eps': 1e-6,
     }),
     'Momentum': edict({
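Note: the new decay_filter entries are consumed in run_pretrain.py to split parameters into weight-decayed and non-decayed groups. A self-contained sketch of that filtering (pure Python; Param is a stand-in for a MindSpore Parameter):

from collections import namedtuple

Param = namedtuple("Param", ["name"])
params = [Param("dense.weight"), Param("dense.bias"), Param("layernorm.gamma")]

decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
decay_params = list(filter(decay_filter, params))            # [Param(name='dense.weight')]
other_params = [p for p in params if p not in decay_params]  # bias and layernorm params
group_params = [{'params': decay_params, 'weight_decay': 0.01},
                {'params': other_params, 'weight_decay': 0.0},
                {'order_params': params}]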
chapter07/src/dataset.py

@@ -23,11 +23,9 @@ from mindspore import log as logger
 from .config import bert_net_cfg


-def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
-                        data_sink_steps=1, data_dir=None, schema_dir=None):
+def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None):
     """create train dataset"""
-    # apply repeat operations
-    repeat_count = epoch_size
     files = os.listdir(data_dir)
     data_files = []
     for file_name in files:

@@ -36,15 +34,10 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
                             columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
                                           "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
-                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
-                            shard_equal_rows=True)
-    ori_dataset_size = ds.get_dataset_size()
-    print('origin dataset size: ', ori_dataset_size)
-    new_size = ori_dataset_size
-    if enable_data_sink == "true":
-        new_size = data_sink_steps * bert_net_cfg.batch_size
-    ds.set_dataset_size(new_size)
-    new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
+                            shuffle=de.Shuffle.FILES if do_shuffle == "true" else False,
+                            num_shards=device_num, shard_id=rank, shard_equal_rows=True)
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
     ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)

@@ -54,10 +47,9 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     ds = ds.map(input_columns="input_ids", operations=type_cast_op)
     # apply batch operations
     ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
-    ds = ds.repeat(max(new_repeat_count, repeat_count))
     logger.info("data size: {}".format(ds.get_dataset_size()))
-    logger.info("repeatcount: {}".format(ds.get_repeat_count()))
-    return ds, new_repeat_count
+    logger.info("repeat count: {}".format(ds.get_repeat_count()))
+    return ds


 def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
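Note: with the slimmed-down signature, callers get one un-repeated pass over the data and handle epochs and sinking themselves (see run_pretrain.py above). A hedged usage sketch, with hypothetical paths:

from src.dataset import create_bert_dataset

ds = create_bert_dataset(device_num=1, rank=0, do_shuffle="true",
                         data_dir="/path/to/tfrecords", schema_dir="")
print(ds.get_dataset_size())   # batches in one epoch; the repeat/sink logic
                               # now lives in run_pretrain.py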
chapter07/src/utils.py

@@ -18,11 +18,13 @@ Functional Cells used in Bert finetune and evaluation.
 """
 import os
+import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.train.callback import Callback
+from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR


 class CrossEntropyCalculation(nn.Cell):

@@ -123,3 +125,25 @@ def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, pre
             max_num = int(num)
     load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
     return load_finetune_checkpoint_path
+
+
+class BertLearningRate(LearningRateSchedule):
+    """
+    Warmup-decay learning rate for Bert network.
+    """
+    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
+        super(BertLearningRate, self).__init__()
+        self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
+        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
+        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
+
+        self.greater = P.Greater()
+        self.one = Tensor(np.array([1.0]).astype(np.float32))
+        self.cast = P.Cast()
+
+    def construct(self, global_step):
+        is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
+        warmup_lr = self.warmup_lr(global_step)
+        decay_lr = self.decay_lr(global_step)
+        lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
+        return lr
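Note: BertLearningRate is handed to the optimizers in run_pretrain.py as a per-step schedule. A minimal wiring sketch, assuming a MindSpore version matching this commit (the network and step counts are placeholders, not from the diff):

import mindspore.nn as nn
from mindspore.nn.optim import Lamb
from src.utils import BertLearningRate

net = nn.Dense(4, 2)   # stand-in for the BERT-with-loss cell
lr_schedule = BertLearningRate(learning_rate=3e-5, end_learning_rate=1e-10,
                               warmup_steps=10000, decay_steps=30000, power=10.0)
optimizer = Lamb(net.trainable_params(), learning_rate=lr_schedule, eps=1e-6)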