PaddlePaddle / hapi

Commit 7114c29e
Authored Apr 27, 2020 by xyzhou-puck
Parent: 1f6a3af9

update bert and text.py
Showing 11 changed files with 309 additions and 297 deletions (+309, -297).
examples/bert/bert_classifier.py                      +6    -7
examples/bert/run_classifier_single_gpu.sh            +1    -1
examples/bert_leveldb/bert_classifier.py              +9    -10
examples/bert_leveldb/run_classifier_multi_gpu.sh     +1    -1
examples/bert_leveldb/run_classifier_single_gpu.sh    +1    -1
hapi/text/bert/__init__.py                            +3    -1
hapi/text/bert/bert.py                                +2    -2
hapi/text/bert/dygraph_optimization.py                +182  -0
hapi/text/bert/optimization.py                        +31   -168
hapi/text/bert/static_optimization.py                 +70   -104
hapi/text/text.py                                     +3    -2
examples/bert/bert_classifier.py

@@ -18,10 +18,10 @@ from hapi.metrics import Accuracy
 from hapi.configure import Config
 from hapi.text.bert import BertEncoder
 from paddle.fluid.dygraph import Linear, Layer
-from hapi.model import set_device, Model, Input
 from hapi.loss import SoftmaxWithCrossEntropy
+from hapi.model import set_device, Model, Input
 import hapi.text.tokenizer.tokenization as tokenization
-from hapi.text.bert import Optimizer, BertConfig, BertDataLoader, BertInputExample
+from hapi.text.bert import BertConfig, BertDataLoader, BertInputExample, make_optimizer


 class ClsModelLayer(Model):

@@ -128,7 +128,7 @@ def main():
         Input([None, None], 'int64', name='src_ids'),
         Input([None, None], 'int64', name='pos_ids'),
         Input([None, None], 'int64', name='sent_ids'),
-        Input([None, None], 'float32', name='input_mask')
+        Input([None, None, 1], 'float32', name='input_mask')
     ]
     labels = [Input([None, 1], 'int64', name='label')]

@@ -139,13 +139,13 @@ def main():
         len(["contradiction", "entailment", "neutral"]),
         return_pooled_out=True)

-    optimizer = Optimizer(
+    optimizer = make_optimizer(
         warmup_steps=warmup_steps,
         num_train_steps=max_train_steps,
         learning_rate=config.learning_rate,
-        model_cls=cls_model,
         weight_decay=config.weight_decay,
         scheduler=config.lr_scheduler,
+        model=cls_model,
         loss_scaling=config.loss_scaling,
         parameter_list=cls_model.parameters())

@@ -157,8 +157,7 @@ def main():
         labels,
         device=device)

-    cls_model.bert_layer.init_parameters(
-        config.init_pretraining_params, verbose=config.verbose)
+    cls_model.bert_layer.load("./bert_small", reset_optimizer=True)

     # do train
     cls_model.fit(train_data=train_dataloader.dataloader,
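Net effect for this example: the classifier now obtains its optimizer from the make_optimizer factory (which returns a dygraph or static implementation depending on the execution mode) and restores the pretrained encoder with bert_layer.load instead of init_parameters. A condensed, hedged sketch of the updated flow; constructor arguments, metric arguments, and the epoch field are assumptions taken from the surrounding example, not part of this diff:

    cls_model = ClsModelLayer(
        config, bert_config,                                  # leading constructor args assumed
        len(["contradiction", "entailment", "neutral"]),
        return_pooled_out=True)

    optimizer = make_optimizer(
        warmup_steps=warmup_steps,
        num_train_steps=max_train_steps,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        model=cls_model,                                      # replaces the old model_cls= argument
        scheduler=config.lr_scheduler,
        loss_scaling=config.loss_scaling,
        parameter_list=cls_model.parameters())

    cls_model.prepare(
        optimizer,
        SoftmaxWithCrossEntropy(),
        Accuracy(),                                           # metric arguments assumed
        inputs,
        labels,
        device=device)

    cls_model.bert_layer.load("./bert_small", reset_optimizer=True)   # was init_parameters(...)

    cls_model.fit(train_data=train_dataloader.dataloader,
                  epochs=config.epoch)                        # epoch field assumed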
examples/bert/run_classifier_single_gpu.sh

@@ -4,7 +4,7 @@ TASK_NAME='MNLI'
 DATA_PATH="./data/glue_data/MNLI/"
 CKPT_PATH="./data/saved_model/mnli_models"

-export CUDA_VISIBLE_DEVICES=0
+export CUDA_VISIBLE_DEVICES=1

 # start fine-tuning
 python3.7 bert_classifier.py \
examples/bert_leveldb/bert_classifier.py

@@ -18,10 +18,10 @@ from hapi.metrics import Accuracy
 from hapi.configure import Config
 from hapi.text.bert import BertEncoder
 from paddle.fluid.dygraph import Linear, Layer
-from hapi.model import set_device, Model, Input
 from hapi.loss import SoftmaxWithCrossEntropy
+from hapi.model import set_device, Model, Input
 import hapi.text.tokenizer.tokenization as tokenization
-from hapi.text.bert import Optimizer, BertConfig, BertDataLoader, BertInputExample
+from hapi.text.bert import BertConfig, BertDataLoader, BertInputExample, make_optimizer


 class ClsModelLayer(Model):

@@ -99,12 +99,12 @@ def main():
     train_dataloader = BertDataLoader(
         "./data/glue_data/MNLI/train.tsv",
         tokenizer, ["contradiction", "entailment", "neutral"],
         max_seq_length=config.max_seq_len,
         batch_size=config.batch_size,
         line_processor=mnli_line_processor,
         mode="leveldb",
-        )
+        phase="train")

     test_dataloader = BertDataLoader(
         "./data/glue_data/MNLI/dev_matched.tsv",

@@ -130,7 +130,7 @@ def main():
         Input([None, None], 'int64', name='src_ids'),
         Input([None, None], 'int64', name='pos_ids'),
         Input([None, None], 'int64', name='sent_ids'),
-        Input([None, None], 'float32', name='input_mask')
+        Input([None, None, 1], 'float32', name='input_mask')
     ]
     labels = [Input([None, 1], 'int64', name='label')]

@@ -141,13 +141,13 @@ def main():
         len(["contradiction", "entailment", "neutral"]),
         return_pooled_out=True)

-    optimizer = Optimizer(
+    optimizer = make_optimizer(
         warmup_steps=warmup_steps,
         num_train_steps=max_train_steps,
         learning_rate=config.learning_rate,
-        model_cls=cls_model,
         weight_decay=config.weight_decay,
         scheduler=config.lr_scheduler,
+        model=cls_model,
         loss_scaling=config.loss_scaling,
         parameter_list=cls_model.parameters())

@@ -159,8 +159,7 @@ def main():
         labels,
         device=device)

-    cls_model.bert_layer.init_parameters(
-        config.init_pretraining_params, verbose=config.verbose)
+    cls_model.bert_layer.load("./bert_small", reset_optimizer=True)

     # do train
     cls_model.fit(train_data=train_dataloader.dataloader,
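The leveldb variant now also passes an explicit phase to BertDataLoader, which lets the LevelDB cache distinguish the training split from the evaluation split. A hedged sketch of both loaders; only phase="train" appears in this diff, so the "dev" phase name for the second loader is an assumption:

    train_dataloader = BertDataLoader(
        "./data/glue_data/MNLI/train.tsv",
        tokenizer, ["contradiction", "entailment", "neutral"],
        max_seq_length=config.max_seq_len,
        batch_size=config.batch_size,
        line_processor=mnli_line_processor,
        mode="leveldb",          # cache tokenized examples in a LevelDB store
        phase="train")

    test_dataloader = BertDataLoader(
        "./data/glue_data/MNLI/dev_matched.tsv",
        tokenizer, ["contradiction", "entailment", "neutral"],
        max_seq_length=config.max_seq_len,
        batch_size=config.batch_size,
        line_processor=mnli_line_processor,
        mode="leveldb",
        phase="dev")             # phase name assumed for the evaluation split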
examples/bert_leveldb/run_classifier_multi_gpu.sh

@@ -5,7 +5,7 @@ DATA_PATH="./data/glue_data/MNLI/"
 CKPT_PATH="./data/saved_model/mnli_models"

 # start fine-tuning
-python3.7 -m paddle.distributed.launch --started_port 8899 --selected_gpus=0,1,2,3 bert_classifier.py \
+python3.7 -m paddle.distributed.launch --started_port 8899 --selected_gpus=1,2,3 bert_classifier.py \
     --use_cuda true \
     --do_train true \
     --do_test true \
examples/bert_leveldb/run_classifier_single_gpu.sh

@@ -4,7 +4,7 @@ TASK_NAME='MNLI'
 DATA_PATH="./data/glue_data/MNLI/"
 CKPT_PATH="./data/saved_model/mnli_models"

-export CUDA_VISIBLE_DEVICES=0
+export CUDA_VISIBLE_DEVICES=1

 # start fine-tuning
 python3.7 bert_classifier.py \
hapi/text/bert/__init__.py

@@ -13,7 +13,9 @@
 # limitations under the License.

 from hapi.text.bert.bert import BertConfig as BertConfig
-from hapi.text.bert.optimization import Optimizer as Optimizer
+from hapi.text.bert.dygraph_optimization import DyOptimizer as DyOptimizer
+from hapi.text.bert.static_optimization import StOptimizer as StOptimizer
+from hapi.text.bert.optimization import make_optimizer as make_optimizer
 from hapi.text.bert.dataloader import BertDataLoader as BertDataLoader
 from hapi.text.bert.dataloader import BertInputExample as BertInputExample
 from hapi.text.tokenizer import tokenization as tokenization
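With this change the package no longer exports a single Optimizer class; callers import the factory or one of the concrete optimizers instead. A small sketch of the user-facing imports after this commit:

    # Old entry point, removed in this commit:
    # from hapi.text.bert import Optimizer

    # New entry points:
    from hapi.text.bert import make_optimizer    # picks Dy/StOptimizer automatically
    from hapi.text.bert import DyOptimizer       # dygraph (imperative) training
    from hapi.text.bert import StOptimizer       # static-graph training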
hapi/text/bert/bert.py

@@ -23,8 +23,8 @@ import numpy as np
 import paddle
 import paddle.fluid as fluid
+from hapi.model import Model
 from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, to_variable, Layer, guard
 from hapi.text.text import PrePostProcessLayer, TransformerEncoder
 from hapi.text.bert.utils.init import init_from_static_model

@@ -52,7 +52,7 @@ class BertConfig(object):
         print('------------------------------------------------')


-class BertEncoder(Layer):
+class BertEncoder(Model):
     """
     bert
     """
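Rebasing BertEncoder from Layer onto hapi's Model is what lets the examples above call cls_model.bert_layer.load(...) directly: the encoder now carries Model's save/load machinery itself. A hedged sketch; the BertConfig path and the constructor arguments are assumptions for illustration only:

    from hapi.text.bert import BertConfig
    from hapi.text.bert.bert import BertEncoder

    bert_config = BertConfig("./bert_small/bert_config.json")   # config path assumed
    bert_encoder = BertEncoder(config=bert_config)               # constructor args assumed

    # As a Model subclass, the encoder can restore a checkpoint on its own,
    # which is what the updated examples do through cls_model.bert_layer:
    bert_encoder.load("./bert_small", reset_optimizer=True)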
hapi/text/bert/dygraph_optimization.py (new file, mode 100755)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization and learning rate scheduling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay


class ConstantLR(LearningRateDecay):
    def __init__(self, learning_rate, begin=0, step=1, dtype='float32'):
        super(ConstantLR, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate

    def step(self):
        return self.learning_rate


class LinearDecay(LearningRateDecay):
    def __init__(self,
                 learning_rate,
                 warmup_steps,
                 decay_steps,
                 end_learning_rate=0.0001,
                 power=1.0,
                 cycle=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(LinearDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps
        self.end_learning_rate = end_learning_rate
        self.power = power
        self.cycle = cycle

    def step(self):
        if self.step_num < self.warmup_steps:
            decayed_lr = self.learning_rate * (self.step_num /
                                               self.warmup_steps)
            decayed_lr = self.create_lr_var(decayed_lr)
        else:
            tmp_step_num = self.step_num
            tmp_decay_steps = self.decay_steps
            if self.cycle:
                div_res = fluid.layers.ceil(
                    self.create_lr_var(tmp_step_num / float(self.decay_steps)))
                if tmp_step_num == 0:
                    div_res = self.create_lr_var(1.0)
                tmp_decay_steps = self.decay_steps * div_res
            else:
                tmp_step_num = self.create_lr_var(
                    tmp_step_num
                    if tmp_step_num < self.decay_steps else self.decay_steps)
            decayed_lr = (self.learning_rate - self.end_learning_rate) * \
                ((1 - tmp_step_num / tmp_decay_steps) ** self.power) + \
                self.end_learning_rate
        return decayed_lr


class DyOptimizer(object):
    def __init__(self,
                 warmup_steps,
                 num_train_steps,
                 learning_rate,
                 model_cls,
                 weight_decay,
                 scheduler='linear_warmup_decay',
                 loss_scaling=1.0,
                 parameter_list=None):
        self.warmup_steps = warmup_steps
        self.num_train_steps = num_train_steps
        self.learning_rate = learning_rate
        self.model_cls = model_cls
        self.weight_decay = weight_decay
        self.scheduler = scheduler
        self.loss_scaling = loss_scaling
        self.parameter_list = parameter_list

        self.scheduled_lr = 0.0
        self.optimizer = self.lr_schedule()

    def lr_schedule(self):
        if self.warmup_steps > 0:
            if self.scheduler == 'noam_decay':
                self.scheduled_lr = fluid.dygraph.NoamDecay(
                    1 / (self.warmup_steps * (self.learning_rate**2)),
                    self.warmup_steps)
            elif self.scheduler == 'linear_warmup_decay':
                self.scheduled_lr = LinearDecay(self.learning_rate,
                                                self.warmup_steps,
                                                self.num_train_steps, 0.0)
            else:
                raise ValueError("Unkown learning rate scheduler, should be "
                                 "'noam_decay' or 'linear_warmup_decay'")
            optimizer = fluid.optimizer.Adam(
                learning_rate=self.scheduled_lr,
                parameter_list=self.parameter_list)
        else:
            self.scheduled_lr = ConstantLR(self.learning_rate)
            optimizer = fluid.optimizer.Adam(
                learning_rate=self.scheduled_lr,
                parameter_list=self.parameter_list)

        return optimizer

    def exclude_from_weight_decay(self, name):
        if name.find("layer_norm") > -1:
            return True
        bias_suffix = ["_bias", "_b", ".b_0"]
        for suffix in bias_suffix:
            if name.endswith(suffix):
                return True
        return False

    def state_dict(self):
        return self.optimizer.state_dict()

    def set_dict(self, state_dict):
        return self.optimizer.set_dict(state_dict)

    def get_opti_var_name_list(self):
        return self.optimizer.get_opti_var_name_list()

    def current_step_lr(self):
        return self.optimizer.current_step_lr()

    def minimize(self, loss, use_data_parallel=False, model=None):
        param_list = dict()

        clip_norm_thres = 1.0
        #grad_clip = fluid.clip.GradientClipByGlobalNorm(clip_norm_thres)

        if use_data_parallel:
            loss = model.scale_loss(loss)

        loss.backward()

        if self.weight_decay > 0:
            for param in self.model_cls.parameters():
                param_list[param.name] = param * 1.0
                param_list[param.name].stop_gradient = True

        if use_data_parallel:
            assert model is not None
            model.apply_collective_grads()

        #_, param_grads = self.optimizer.minimize(loss, grad_clip=grad_clip)
        _, param_grads = self.optimizer.minimize(loss)

        if self.weight_decay > 0:
            for param, grad in param_grads:
                if self.exclude_from_weight_decay(param.name):
                    continue
                if isinstance(self.scheduled_lr.step(), float):
                    updated_param = param.numpy() - param_list[
                        param.name].numpy(
                        ) * self.weight_decay * self.scheduled_lr.step()
                else:
                    updated_param = param.numpy() - param_list[
                        param.name].numpy(
                        ) * self.weight_decay * self.scheduled_lr.step().numpy()
                updated_param_var = fluid.dygraph.to_variable(updated_param)
                param = updated_param_var
                #param = fluid.layers.reshape(x=updated_param_var, shape=list(updated_param_var.shape))
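DyOptimizer wraps an Adam optimizer plus warmup/decay scheduling, and applies a manual decoupled weight-decay pass after each step. A minimal, hedged sketch of driving it by hand in dygraph mode; the toy Linear layer, the hyperparameter values, and the random input are invented for illustration (the hapi Model.prepare/fit path normally does all of this for you):

    import numpy as np
    import paddle.fluid as fluid
    from hapi.text.bert import DyOptimizer

    with fluid.dygraph.guard():
        layer = fluid.dygraph.Linear(8, 2)           # toy model standing in for the BERT classifier
        opt = DyOptimizer(
            warmup_steps=100,
            num_train_steps=1000,
            learning_rate=5e-5,
            model_cls=layer,                         # used to snapshot parameters for weight decay
            weight_decay=0.01,
            scheduler='linear_warmup_decay',
            parameter_list=layer.parameters())

        x = fluid.dygraph.to_variable(np.random.rand(4, 8).astype('float32'))
        loss = fluid.layers.reduce_mean(layer(x))
        opt.minimize(loss)                           # backward + Adam step + decoupled weight decay
        layer.clear_gradients()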
hapi/text/bert/optimization.py (mode 100755 → 100644)

@@ -11,172 +11,35 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Optimization and learning rate scheduling."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import paddle.fluid as fluid
-from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay
+from paddle.fluid.framework import in_dygraph_mode
+from hapi.text.bert.dygraph_optimization import DyOptimizer as DyOptimizer
+from hapi.text.bert.static_optimization import StOptimizer as StOptimizer

(The remaining removed lines are the ConstantLR and LinearDecay schedulers and the old
Optimizer class that used to live in this module; that dygraph implementation now sits,
essentially unchanged, in hapi/text/bert/dygraph_optimization.py as DyOptimizer, shown above.)

+def make_optimizer(warmup_steps,
+                   num_train_steps,
+                   learning_rate,
+                   weight_decay,
+                   model,
+                   scheduler='linear_warmup_decay',
+                   loss_scaling=1.0,
+                   parameter_list=None):
+    if in_dygraph_mode():
+        return DyOptimizer(
+            warmup_steps=warmup_steps,
+            num_train_steps=num_train_steps,
+            learning_rate=learning_rate,
+            model_cls=model,
+            weight_decay=weight_decay,
+            scheduler=scheduler,
+            loss_scaling=loss_scaling,
+            parameter_list=parameter_list)
+    else:
+        return StOptimizer(
+            warmup_steps=warmup_steps,
+            num_train_steps=num_train_steps,
+            learning_rate=learning_rate,
+            weight_decay=weight_decay,
+            scheduler=scheduler)
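make_optimizer is now the single entry point; it simply dispatches on the execution mode via in_dygraph_mode(). A hedged sketch of how the same call resolves differently; the helper function and hyperparameter values below are hypothetical:

    from hapi.text.bert import make_optimizer

    def build_optimizer(model, parameters):
        # Hypothetical helper: the call is identical in both modes, and the factory
        # checks in_dygraph_mode() internally.
        return make_optimizer(
            warmup_steps=100,
            num_train_steps=1000,
            learning_rate=5e-5,
            weight_decay=0.01,
            model=model,
            scheduler='linear_warmup_decay',
            parameter_list=parameters)

    # Under fluid.dygraph.guard() this returns a DyOptimizer; in a static
    # Program/Executor setup it returns an StOptimizer, which does not take the
    # loss_scaling/parameter_list arguments (those are only forwarded in dygraph mode).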
hapi/text/bert/static_optimization.py

@@ -19,7 +19,6 @@ from __future__ import print_function
 import numpy as np
 import paddle.fluid as fluid
-from utils.fp16 import create_master_params_grads, master_param_to_train_param, apply_dynamic_loss_scaling


 def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps):

@@ -51,128 +50,95 @@ def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps):
     return lr


-def optimization(loss,
-                 warmup_steps,
-                 num_train_steps,
-                 learning_rate,
-                 train_program,
-                 startup_prog,
-                 weight_decay,
-                 scheduler='linear_warmup_decay',
-                 use_fp16=False,
-                 use_dynamic_loss_scaling=False,
-                 init_loss_scaling=1.0,
-                 incr_every_n_steps=1000,
-                 decr_every_n_nan_or_inf=2,
-                 incr_ratio=2.0,
-                 decr_ratio=0.8):

(The rest of the removed lines are the body of the old module-level optimization() function:
the same scheduler selection and decoupled weight-decay pass that survives below, plus the
fp16 master-parameter handling and dynamic loss scaling, ending with
`return scheduled_lr, loss_scaling`.)

+class StOptimizer(fluid.optimizer.Optimizer):
+    def __init__(self,
+                 warmup_steps,
+                 num_train_steps,
+                 learning_rate,
+                 weight_decay,
+                 scheduler='linear_warmup_decay'):
+        super(StOptimizer, self).__init__(
+            learning_rate=learning_rate,
+            parameter_list=None,
+            regularization=None,
+            grad_clip=None,
+            name=None)
+        self.warmup_steps = warmup_steps
+        self.num_train_steps = num_train_steps
+        self.learning_rate = learning_rate
+        self.weight_decay = weight_decay
+        self.scheduler = scheduler
+
+    def minimize(self, loss):
+        train_program = fluid.default_main_program()
+        startup_program = fluid.default_startup_program()
+
+        if self.scheduler == 'noam_decay':
+            if self.warmup_steps > 0:
+                scheduled_lr = fluid.layers.learning_rate_scheduler \
+                    .noam_decay(1 / (self.warmup_steps * (self.learning_rate**2)),
+                                self.warmup_steps)
+            else:
+                print("WARNING: noam decay of learning rate should have postive warmup "
+                      "steps but given {}, using constant learning rate instead!"
+                      .format(self.warmup_steps))
+                scheduled_lr = fluid.layers.create_global_var(
+                    name=fluid.unique_name.generate("learning_rate"),
+                    shape=[1],
+                    value=self.learning_rate,
+                    dtype='float32',
+                    persistable=True)
+        elif self.scheduler == 'linear_warmup_decay':
+            if self.warmup_steps > 0:
+                scheduled_lr = linear_warmup_decay(self.learning_rate,
+                                                   self.warmup_steps,
+                                                   self.num_train_steps)
+            else:
+                print("WARNING: linear warmup decay of learning rate should have "
+                      "postive warmup steps but given {}, use constant learning rate "
+                      "instead!".format(self.warmup_steps))
+                scheduled_lr = fluid.layers.create_global_var(
+                    name=fluid.unique_name.generate("learning_rate"),
+                    shape=[1],
+                    value=self.learning_rate,
+                    dtype='float32',
+                    persistable=True)
+        else:
+            raise ValueError("Unkown learning rate scheduler, should be "
+                             "'noam_decay' or 'linear_warmup_decay'")
+
+        optimizer = fluid.optimizer.Adam(learning_rate=scheduled_lr)
+
+        fluid.clip.set_gradient_clip(
+            clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))
+
+        def exclude_from_weight_decay(param):
+            name = param.name.rstrip(".master")
+            if name.find("layer_norm") > -1:
+                return True
+            bias_suffix = ["_bias", "_b", ".b_0"]
+            for suffix in bias_suffix:
+                if name.endswith(suffix):
+                    return True
+            return False
+
+        param_list = dict()
+
+        if self.weight_decay > 0:
+            for param in train_program.all_parameters():
+                param_list[param.name] = param * 1.0
+                param_list[param.name].stop_gradient = True
+
+        _, param_grads = optimizer.minimize(loss)
+
+        if self.weight_decay > 0:
+            for param, grad in param_grads:
+                if exclude_from_weight_decay(param):
+                    continue
+                with param.block.program._optimized_guard(
+                        [param, grad]), fluid.framework.name_scope("weight_decay"):
+                    updated_param = param - param_list[
+                        param.name] * self.weight_decay * scheduled_lr
+                    fluid.layers.assign(output=param, input=updated_param)
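In static-graph mode, StOptimizer behaves like a regular fluid optimizer: calling minimize(loss) while building the program wires up the warmup/decay schedule, global-norm gradient clipping, and the manual weight-decay assignments. A hedged sketch of a static training setup; the toy network, shapes, and feed data are invented for illustration:

    import numpy as np
    import paddle.fluid as fluid
    from hapi.text.bert import StOptimizer

    x = fluid.data(name='x', shape=[None, 8], dtype='float32')
    y = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.reduce_mean(y)

    optimizer = StOptimizer(
        warmup_steps=100,
        num_train_steps=1000,
        learning_rate=5e-5,
        weight_decay=0.01,
        scheduler='linear_warmup_decay')
    optimizer.minimize(loss)          # builds Adam + LR schedule + weight decay into the program

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    exe.run(feed={'x': np.random.rand(4, 8).astype('float32')}, fetch_list=[loss])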
hapi/text/text.py

@@ -1096,7 +1096,8 @@ class PrePostProcessLayer(Layer):
         self.functors = []
         for cmd in self.process_cmd:
             if cmd == "a":  # add residual connection
-                self.functors.append(lambda x, y: x + y if y else x)
+                self.functors.append(
+                    lambda x, y: x + y if y is not None else x)
             elif cmd == "n":  # add layer normalization
                 if reused_layer_norm is not None:
                     layer_norm = reused_layer_norm

@@ -1218,7 +1219,7 @@ class MultiHeadAttention(Layer):
         # scale dot product attention
         product = layers.matmul(
             x=q, y=k, transpose_y=True, alpha=self.d_model**-0.5)
-        if attn_bias:
+        if attn_bias is not None:
             product += attn_bias
         weights = layers.softmax(product)
         if self.dropout_rate:
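Both text.py changes replace truthiness tests on optional tensors with explicit None checks. `if attn_bias:` (or `if y:`) asks Python to interpret the tensor itself as a boolean, which is not what is meant here and can misbehave for Variables; the intent is only "was a bias passed at all?". A small hedged sketch of the pattern, using plain Python values rather than Paddle tensors:

    # 'bias' stands in for an optional attention-bias tensor.
    def add_bias(product, bias=None):
        # Truthiness test (the old code): skips a legitimate all-zero bias and relies on
        # bool(tensor), which is unreliable for tensor-like objects.
        # if bias:
        #     product += bias

        # Explicit None check (what the commit switches to): only tests whether the
        # argument was supplied.
        if bias is not None:
            product += bias
        return product

    print(add_bias(1.0))        # 1.0 (no bias passed)
    print(add_bias(1.0, 0.0))   # 1.0 (zero bias is still applied explicitly)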