提交 5a092844 编写于 作者: C chenxuyi

Many updates

1. finetune/ pretrian fp16 amp
2. finetune/ pretrain multi processing
3. ner infer
4. py3 compat
上级 815159f4
*.pyc *.pyc
*.un~ *.un~
*.swp
...@@ -111,6 +111,7 @@ Integrating both phrase information and named entity information enables the mod ...@@ -111,6 +111,7 @@ Integrating both phrase information and named entity information enables the mod
## Release Notes ## Release Notes
- Aug 21, 2019: featuers update: fp16 finetuning, multiprocess finetining.
- July 30, 2019: release ERNIE 2.0 - July 30, 2019: release ERNIE 2.0
- Apr 10, 2019: update ERNIE_stable-1.0.1.tar.gz, update config and vocab - Apr 10, 2019: update ERNIE_stable-1.0.1.tar.gz, update config and vocab
- Mar 18, 2019: update ERNIE_stable.tgz - Mar 18, 2019: update ERNIE_stable.tgz
...@@ -339,7 +340,7 @@ XNLI is a natural language inference dataset in 15 languages. It was jointly bui ...@@ -339,7 +340,7 @@ XNLI is a natural language inference dataset in 15 languages. It was jointly bui
*\*The DRCD dataset is converted from Traditional Chinese to Simplified Chinese based on tool: https://github.com/skydark/nstools/tree/master/zhtools* *\*The DRCD dataset is converted from Traditional Chinese to Simplified Chinese based on tool: https://github.com/skydark/nstools/tree/master/zhtools*
\* *The pre-training data of ERNIE 1.0 BASE does not contain instances whose length exceeds 128, but other models is pre-trained with the instances whose length are 512. It causes poorer performance of ERNIE 1.0 BASE on long-text tasks. So We have released [ERNIE 1.0 Base(max-len-512)](https://ernie.bj.bcebos.com/ERNIE_1.0_max-len-512.tar.gz) on July 29th, 2019* \* *The pre-training data of ERNIE 1.0 BASE does not contain instances whose length exceeds 128, but other models is pre-trained with the instances whose length are 512. It causes poorer performance of ERNIE 1.0 BASE on long-text tasks. So We have released [ERNIE 1.0 Base (max-len-512)](https://ernie.bj.bcebos.com/ERNIE_1.0_max-len-512.tar.gz) on July 29th, 2019*
...@@ -371,7 +372,7 @@ DRCD is an open domain Traditional Chinese machine reading comprehension (MRC) d ...@@ -371,7 +372,7 @@ DRCD is an open domain Traditional Chinese machine reading comprehension (MRC) d
<tr> <tr>
<th><strong>Dataset</strong> <th><strong>Dataset</strong>
<br></th> <br></th>
<th colspan="2"><center><strong>MSRA-NER(SIGHAN2006)</strong></center></th> <th colspan="2"><center><strong>MSRA-NER (SIGHAN2006)</strong></center></th>
<tr> <tr>
<td rowspan="2"> <td rowspan="2">
<p> <p>
...@@ -413,10 +414,10 @@ DRCD is an open domain Traditional Chinese machine reading comprehension (MRC) d ...@@ -413,10 +414,10 @@ DRCD is an open domain Traditional Chinese machine reading comprehension (MRC) d
</tbody> </tbody>
</table> </table>
- **MSRA-NER(SIGHAN2006)** - **MSRA-NER (SIGHAN2006)**
```text ```text
MSRA-NER(SIGHAN2006) dataset is released by MSRA for recognizing the names of people, locations and organizations in text. MSRA-NER (SIGHAN2006) dataset is released by MSRA for recognizing the names of people, locations and organizations in text.
``` ```
#### Results on Sentiment Analysis Task #### Results on Sentiment Analysis Task
...@@ -622,7 +623,7 @@ LCQMC is a Chinese question semantic matching corpus published in COLING2018. [u ...@@ -622,7 +623,7 @@ LCQMC is a Chinese question semantic matching corpus published in COLING2018. [u
- **BQ Corpus** - **BQ Corpus**
```text ```text
BQ Corpus(Bank Question corpus) is a Chinese corpus for sentence semantic equivalence identification. This dataset was published in EMNLP 2018. [url: https://www.aclweb.org/anthology/D18-1536] BQ Corpus (Bank Question corpus) is a Chinese corpus for sentence semantic equivalence identification. This dataset was published in EMNLP 2018. [url: https://www.aclweb.org/anthology/D18-1536]
``` ```
...@@ -635,6 +636,7 @@ BQ Corpus(Bank Question corpus) is a Chinese corpus for sentence semantic equiva ...@@ -635,6 +636,7 @@ BQ Corpus(Bank Question corpus) is a Chinese corpus for sentence semantic equiva
* [Chinese Datasets](#chinese-datasets) * [Chinese Datasets](#chinese-datasets)
* [Fine-tuning](#fine-tuning) * [Fine-tuning](#fine-tuning)
* [Batchsize and GPU Settings](#batchsize-and-gpu-settings) * [Batchsize and GPU Settings](#batchsize-and-gpu-settings)
* [Multiprocessing and fp16 auto mix-precision finetune](#multiprocessing-and-fp16-auto-mix-precision-finetune)
* [Classification](#classification) * [Classification](#classification)
* [Single Sentence Classification Tasks](#single-sentence-classification-tasks) * [Single Sentence Classification Tasks](#single-sentence-classification-tasks)
* [Sentence Pair Classification Tasks](#sentence-pair-classification-tasks) * [Sentence Pair Classification Tasks](#sentence-pair-classification-tasks)
...@@ -705,14 +707,14 @@ In our experiments, we found that the batch size is important for different task ...@@ -705,14 +707,14 @@ In our experiments, we found that the batch size is important for different task
| Dataset | Batch Size | GPU | | Dataset | Batch Size | GPU |
| ------------ | --------------- | ------------------- | | ------------ | --------------- | ------------------- |
| CoLA | 32 / 64(base) | 1 | | CoLA | 32 / 64 (base) | 1 |
| SST-2 | 64 / 256(base) | 8 | | SST-2 | 64 / 256 (base) | 8 |
| STS-B | 128 | 8 | | STS-B | 128 | 8 |
| QQP | 256 | 8 | | QQP | 256 | 8 |
| MNLI | 256 / 512(base) | 8 | | MNLI | 256 / 512 (base) | 8 |
| QNLI | 256 | 8 | | QNLI | 256 | 8 |
| RTE | 16 / 4(base) | 1 | | RTE | 16 / 4 (base) | 1 |
| MRPC | 16 / 32(base) | 2 | | MRPC | 16 / 32 (base) | 2 |
| WNLI | 8 | 1 | | WNLI | 8 | 1 |
| XNLI | 65536 (tokens) | 8 | | XNLI | 65536 (tokens) | 8 |
| CMRC2018 | 64 | 8 (large) / 4(base) | | CMRC2018 | 64 | 8 (large) / 4(base) |
...@@ -725,6 +727,17 @@ In our experiments, we found that the batch size is important for different task ...@@ -725,6 +727,17 @@ In our experiments, we found that the batch size is important for different task
\* *For MNLI, QNLI,we used 32GB V100, for other tasks we used 22GB P40* \* *For MNLI, QNLI,we used 32GB V100, for other tasks we used 22GB P40*
### Multiprocessing and fp16 auto mix-precision finetune
multiprocessing finetuning can be simply enabled with `finetune_launch.py` in your finetune script.
with multiprocessing finetune paddle can fully utilize your CPU/GPU capacity to accelerate finetuning.
`finetune_launch.py` should place in front of your finetune command. make sure to provide number of process and device id per node by specifiying `--nproc_per_node` and `--selected_gpus`. Number of device ids should match `nproc_per_node` and `CUDA_VISIBLE_DEVICES`, and the indexing should start from 0.
fp16 finetuning can be simply enable by specifing `--use_fp16 true` in your training script (make sure you use have a Tensor Core device). ERNIE will cast computation op to fp16 precision, while maintain storage in fp32 precision. approximately 60% speedup is seen on XNLI finetuning.
dynamic loss scale is used to avoid gradient vanish.
### Classification ### Classification
#### Single Sentence Classification Tasks #### Single Sentence Classification Tasks
......
...@@ -371,10 +371,10 @@ DRCD 是台达研究院发布的繁体中文阅读理解数据集,目标是从 ...@@ -371,10 +371,10 @@ DRCD 是台达研究院发布的繁体中文阅读理解数据集,目标是从
</tbody> </tbody>
</table> </table>
- **MSRA-NER(SIGHAN2006)** - **MSRA-NER (SIGHAN2006)**
```text ```text
MSRA-NER(SIGHAN2006) 数据集由微软亚研院发布,其目标是识别文本中具有特定意义的实体,包括人名、地名、机构名。 MSRA-NER (SIGHAN2006) 数据集由微软亚研院发布,其目标是识别文本中具有特定意义的实体,包括人名、地名、机构名。
``` ```
...@@ -640,6 +640,7 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址 ...@@ -640,6 +640,7 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
* [英文数据](#英文数据) * [英文数据](#英文数据)
* [Fine-tuning 任务](#fine-tuning-任务) * [Fine-tuning 任务](#fine-tuning-任务)
* [运行参数配置](#运行参数配置) * [运行参数配置](#运行参数配置)
* [多进程训练与fp16混合精度](#多进程训练与fp16混合精度)
* [单句和句对分类任务](#单句和句对分类任务) * [单句和句对分类任务](#单句和句对分类任务)
* [单句分类任务](#单句分类任务) * [单句分类任务](#单句分类任务)
* [句对分类任务](#句对分类任务) * [句对分类任务](#句对分类任务)
...@@ -720,8 +721,8 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址 ...@@ -720,8 +721,8 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
| MRPC | 16 / 32 (base) | 2 | | MRPC | 16 / 32 (base) | 2 |
| WNLI | 8 | 1 | | WNLI | 8 | 1 |
| XNLI | 65536 (tokens) | 8 | | XNLI | 65536 (tokens) | 8 |
| CMRC2018 | 64 | 8 (large) / 4(base) | | CMRC2018 | 64 | 8 (large) / 4 (base) |
| DRCD | 64 | 8 (large) / 4(base) | | DRCD | 64 | 8 (large) / 4 (base) |
| MSRA-NER(SIGHAN 2006) | 16 | 1 | | MSRA-NER(SIGHAN 2006) | 16 | 1 |
| ChnSentiCorp | 24 | 1 | | ChnSentiCorp | 24 | 1 |
| LCQMC | 32 | 1 | | LCQMC | 32 | 1 |
...@@ -731,6 +732,12 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址 ...@@ -731,6 +732,12 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
\* *MNLI 和 QNLI 的任务中,使用了 32 GB 显存的 V100。除此之外的显卡皆为22 GB 的 P40。* \* *MNLI 和 QNLI 的任务中,使用了 32 GB 显存的 V100。除此之外的显卡皆为22 GB 的 P40。*
### 多进程训练与fp16混合精度
使用`finetune_launch.py`脚本来启动多进程训练 。多进程训练可以提升充分利用多核CPU/多卡GPU 的能力来加速finetune过程。
`finetune_launch.py` 需要放在原来finetune脚本前面, 同时指定每个节点的进程数`--nproc_per_node`, 以及每个节点上的gpu卡号`--selected_gpus`, 一般数量与进程数, `CUDA_VISIBLE_DEVICES`相同且从0开始编号 (参考`script/zh_task/ernie_base/run_xnli.sh`)
只需在训练脚本中加入`--use_fp16 true`即可启用fp16混合精度训练(确保您的硬件支持Tensor Core技术)。ERNIE会将计算Op转换成fp16精度,同时仍然使用fp32精度存储参数。ERNIE使用动态loss scale来避免梯度消失。在XNLI任务上可以观察到大约60%加速。
### 单句和句对分类任务 ### 单句和句对分类任务
......
...@@ -11,11 +11,12 @@ ...@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference by """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os import os
import time import time
...@@ -39,7 +40,7 @@ from reader.task_reader import ClassifyReader ...@@ -39,7 +40,7 @@ from reader.task_reader import ClassifyReader
from model.ernie import ErnieConfig from model.ernie import ErnieConfig
from finetune.classifier import create_model from finetune.classifier import create_model
from utils.args import ArgumentGroup, print_arguments from utils.args import print_arguments, check_cuda, prepare_logger
from utils.init import init_pretraining_params from utils.init import init_pretraining_params
from finetune_args import parser from finetune_args import parser
...@@ -66,6 +67,7 @@ run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for trai ...@@ -66,6 +67,7 @@ run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for trai
run_type_g.add_arg("do_prediction", bool, True, "Whether to do prediction on test set.") run_type_g.add_arg("do_prediction", bool, True, "Whether to do prediction on test set.")
args = parser.parse_args() args = parser.parse_args()
log = logging.getLogger()
# yapf: enable. # yapf: enable.
def main(args): def main(args):
...@@ -113,7 +115,7 @@ def main(args): ...@@ -113,7 +115,7 @@ def main(args):
_, ckpt_dir = os.path.split(args.init_checkpoint.rstrip('/')) _, ckpt_dir = os.path.split(args.init_checkpoint.rstrip('/'))
dir_name = ckpt_dir + '_inference_model' dir_name = ckpt_dir + '_inference_model'
model_path = os.path.join(args.save_inference_model_path, dir_name) model_path = os.path.join(args.save_inference_model_path, dir_name)
print("save inference model to %s" % model_path) log.info("save inference model to %s" % model_path)
fluid.io.save_inference_model( fluid.io.save_inference_model(
model_path, model_path,
feed_target_names, [probs], feed_target_names, [probs],
...@@ -125,7 +127,7 @@ def main(args): ...@@ -125,7 +127,7 @@ def main(args):
#config = AnalysisConfig(os.path.join(model_path, "__model__"), os.path.join(model_path, "")) #config = AnalysisConfig(os.path.join(model_path, "__model__"), os.path.join(model_path, ""))
config = AnalysisConfig(model_path) config = AnalysisConfig(model_path)
if not args.use_cuda: if not args.use_cuda:
print("disable gpu") log.info("disable gpu")
config.disable_gpu() config.disable_gpu()
# Create PaddlePredictor # Create PaddlePredictor
...@@ -137,7 +139,7 @@ def main(args): ...@@ -137,7 +139,7 @@ def main(args):
epoch=1, epoch=1,
shuffle=False) shuffle=False)
print("-------------- prediction results --------------") log.info("-------------- prediction results --------------")
np.set_printoptions(precision=4, suppress=True) np.set_printoptions(precision=4, suppress=True)
index = 0 index = 0
total_time = 0 total_time = 0
...@@ -156,14 +158,14 @@ def main(args): ...@@ -156,14 +158,14 @@ def main(args):
# parse outputs # parse outputs
output = outputs[0] output = outputs[0]
print(output.name) log.info(output.name)
output_data = output.data.float_data() output_data = output.data.float_data()
#assert len(output_data) == args.num_labels * args.batch_size #assert len(output_data) == args.num_labels * args.batch_size
batch_result = np.array(output_data).reshape((-1, args.num_labels)) batch_result = np.array(output_data).reshape((-1, args.num_labels))
for single_example_probs in batch_result: for single_example_probs in batch_result:
print("{} example\t{}".format(index, single_example_probs)) log.info("{} example\t{}".format(index, single_example_probs))
index += 1 index += 1
print("qps:{}\ttotal_time:{}\ttotal_example:{}\tbatch_size:{}".format(index/total_time, total_time, index, args.batch_size)) log.info("qps:{}\ttotal_time:{}\ttotal_example:{}\tbatch_size:{}".format(index/total_time, total_time, index, args.batch_size))
def array2tensor(ndarray): def array2tensor(ndarray):
...@@ -183,5 +185,6 @@ def array2tensor(ndarray): ...@@ -183,5 +185,6 @@ def array2tensor(ndarray):
return tensor return tensor
if __name__ == '__main__': if __name__ == '__main__':
prepare_logger(log)
print_arguments(args) print_arguments(args)
main(args) main(args)
...@@ -129,8 +129,6 @@ def main(args): ...@@ -129,8 +129,6 @@ def main(args):
pyreader, graph_vars = create_model( pyreader, graph_vars = create_model(
args, pyreader_name='reader', ernie_config=ernie_config) args, pyreader_name='reader', ernie_config=ernie_config)
fluid.memory_optimize(input_program=infer_program)
infer_program = infer_program.clone(for_test=True) infer_program = infer_program.clone(for_test=True)
exe.run(startup_prog) exe.run(startup_prog)
......
...@@ -16,8 +16,11 @@ ...@@ -16,8 +16,11 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import time import time
import logging
import numpy as np import numpy as np
from scipy.stats import pearsonr, spearmanr from scipy.stats import pearsonr, spearmanr
...@@ -26,6 +29,7 @@ import paddle.fluid as fluid ...@@ -26,6 +29,7 @@ import paddle.fluid as fluid
from model.ernie import ErnieModel from model.ernie import ErnieModel
log = logging.getLogger(__name__)
def create_model(args, def create_model(args,
pyreader_name, pyreader_name,
......
...@@ -16,12 +16,15 @@ ...@@ -16,12 +16,15 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import time import time
import numpy as np import numpy as np
import os import os
import math import math
import json import json
import logging
import collections import collections
import six import six
...@@ -34,6 +37,8 @@ from model.ernie import ErnieModel ...@@ -34,6 +37,8 @@ from model.ernie import ErnieModel
import tokenization import tokenization
log = logging.getLogger(__name__)
def create_model(args, pyreader_name, ernie_config, is_training): def create_model(args, pyreader_name, ernie_config, is_training):
pyreader = fluid.layers.py_reader( pyreader = fluid.layers.py_reader(
capacity=50, capacity=50,
...@@ -151,7 +156,7 @@ def evaluate(exe, ...@@ -151,7 +156,7 @@ def evaluate(exe,
program=test_program, fetch_list=fetch_list) program=test_program, fetch_list=fetch_list)
for idx in range(np_unique_ids.shape[0]): for idx in range(np_unique_ids.shape[0]):
if len(all_results) % 1000 == 0: if len(all_results) % 1000 == 0:
print("Processing example: %d" % len(all_results)) log.info("Processing example: %d" % len(all_results))
unique_id = int(np_unique_ids[idx]) unique_id = int(np_unique_ids[idx])
start_logits = [float(x) for x in np_start_logits[idx].flat] start_logits = [float(x) for x in np_start_logits[idx].flat]
end_logits = [float(x) for x in np_end_logits[idx].flat] end_logits = [float(x) for x in np_end_logits[idx].flat]
...@@ -179,7 +184,7 @@ def evaluate(exe, ...@@ -179,7 +184,7 @@ def evaluate(exe,
time_end = time.time() time_end = time.time()
elapsed_time = time_end - time_begin elapsed_time = time_end - time_begin
print( log.info(
"[%s evaluation] em: %f, f1: %f, avg: %f, questions: %d, elapsed time: %f" "[%s evaluation] em: %f, f1: %f, avg: %f, questions: %d, elapsed time: %f"
% (eval_phase, em, f1, avg, total, elapsed_time)) % (eval_phase, em, f1, avg, total, elapsed_time))
...@@ -188,8 +193,8 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -188,8 +193,8 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
max_answer_length, do_lower_case, output_prediction_file, max_answer_length, do_lower_case, output_prediction_file,
output_nbest_file): output_nbest_file):
"""Write final predictions to the json file and log-odds of null if needed.""" """Write final predictions to the json file and log-odds of null if needed."""
print("Writing predictions to: %s" % (output_prediction_file)) log.info("Writing predictions to: %s" % (output_prediction_file))
print("Writing nbest to: %s" % (output_nbest_file)) log.info("Writing nbest to: %s" % (output_nbest_file))
example_index_to_features = collections.defaultdict(list) example_index_to_features = collections.defaultdict(list)
for feature in all_features: for feature in all_features:
......
...@@ -15,6 +15,9 @@ ...@@ -15,6 +15,9 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os import os
import time import time
...@@ -23,12 +26,14 @@ import numpy as np ...@@ -23,12 +26,14 @@ import numpy as np
import multiprocessing import multiprocessing
import paddle import paddle
import logging
import paddle.fluid as fluid import paddle.fluid as fluid
from six.moves import xrange from six.moves import xrange
from model.ernie import ErnieModel from model.ernie import ErnieModel
log = logging.getLogger(__name__)
def create_model(args, pyreader_name, ernie_config, is_prediction=False): def create_model(args, pyreader_name, ernie_config, is_prediction=False):
pyreader = fluid.layers.py_reader( pyreader = fluid.layers.py_reader(
...@@ -70,9 +75,7 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False): ...@@ -70,9 +75,7 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
initializer=fluid.initializer.Constant(0.))) initializer=fluid.initializer.Constant(0.)))
infers = fluid.layers.argmax(logits, axis=2) infers = fluid.layers.argmax(logits, axis=2)
ret_labels = fluid.layers.reshape(x=labels, shape=[-1, 1])
ret_infers = fluid.layers.reshape(x=infers, shape=[-1, 1]) ret_infers = fluid.layers.reshape(x=infers, shape=[-1, 1])
lod_labels = fluid.layers.sequence_unpad(labels, seq_lens) lod_labels = fluid.layers.sequence_unpad(labels, seq_lens)
lod_infers = fluid.layers.sequence_unpad(infers, seq_lens) lod_infers = fluid.layers.sequence_unpad(infers, seq_lens)
...@@ -92,18 +95,14 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False): ...@@ -92,18 +95,14 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
ce_loss = ce_loss * input_mask ce_loss = ce_loss * input_mask
loss = fluid.layers.mean(x=ce_loss) loss = fluid.layers.mean(x=ce_loss)
if args.use_fp16 and args.loss_scaling > 1.0:
loss *= args.loss_scaling
graph_vars = { graph_vars = {
"inputs": src_ids,
"loss": loss, "loss": loss,
"probs": probs, "probs": probs,
"labels": ret_labels, "seqlen": seq_lens,
"infers": ret_infers,
"num_infer": num_infer, "num_infer": num_infer,
"num_label": num_label, "num_label": num_label,
"num_correct": num_correct, "num_correct": num_correct,
"seq_lens": seq_lens
} }
for k, v in graph_vars.items(): for k, v in graph_vars.items():
...@@ -112,91 +111,6 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False): ...@@ -112,91 +111,6 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
return pyreader, graph_vars return pyreader, graph_vars
def chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count=1):
def extract_bio_chunk(seq):
chunks = []
cur_chunk = None
null_index = tag_num - 1
for index in xrange(len(seq)):
tag = seq[index]
tag_type = tag // 2
tag_pos = tag % 2
if tag == null_index:
if cur_chunk is not None:
chunks.append(cur_chunk)
cur_chunk = None
continue
if tag_pos == 0:
if cur_chunk is not None:
chunks.append(cur_chunk)
cur_chunk = {}
cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
else:
if cur_chunk is None:
cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
continue
if cur_chunk["type"] == tag_type:
cur_chunk["en"] = index + 1
else:
chunks.append(cur_chunk)
cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
if cur_chunk is not None:
chunks.append(cur_chunk)
return chunks
null_index = tag_num - 1
num_label = 0
num_infer = 0
num_correct = 0
labels = np_labels.reshape([-1]).astype(np.int32).tolist()
infers = np_infers.reshape([-1]).astype(np.int32).tolist()
all_lens = np_lens.reshape([dev_count, -1]).astype(np.int32).tolist()
base_index = 0
for dev_index in xrange(dev_count):
lens = all_lens[dev_index]
max_len = 0
for l in lens:
max_len = max(max_len, l)
for i in xrange(len(lens)):
seq_st = base_index + i * max_len + 1
seq_en = seq_st + (lens[i] - 2)
infer_chunks = extract_bio_chunk(infers[seq_st:seq_en])
label_chunks = extract_bio_chunk(labels[seq_st:seq_en])
num_infer += len(infer_chunks)
num_label += len(label_chunks)
infer_index = 0
label_index = 0
while label_index < len(label_chunks) \
and infer_index < len(infer_chunks):
if infer_chunks[infer_index]["st"] \
< label_chunks[label_index]["st"]:
infer_index += 1
elif infer_chunks[infer_index]["st"] \
> label_chunks[label_index]["st"]:
label_index += 1
else:
if infer_chunks[infer_index]["en"] \
== label_chunks[label_index]["en"] \
and infer_chunks[infer_index]["type"] \
== label_chunks[label_index]["type"]:
num_correct += 1
infer_index += 1
label_index += 1
base_index += max_len * len(lens)
return num_label, num_infer, num_correct
def calculate_f1(num_label, num_infer, num_correct): def calculate_f1(num_label, num_infer, num_correct):
if num_infer == 0: if num_infer == 0:
precision = 0.0 precision = 0.0
...@@ -220,53 +134,85 @@ def evaluate(exe, ...@@ -220,53 +134,85 @@ def evaluate(exe,
pyreader, pyreader,
graph_vars, graph_vars,
tag_num, tag_num,
eval_phase,
dev_count=1): dev_count=1):
fetch_list = [ fetch_list = [
graph_vars["num_infer"].name, graph_vars["num_label"].name, graph_vars["num_infer"].name, graph_vars["num_label"].name,
graph_vars["num_correct"].name graph_vars["num_correct"].name
] ]
if eval_phase == "train": total_label, total_infer, total_correct = 0.0, 0.0, 0.0
fetch_list.append(graph_vars["loss"].name) time_begin = time.time()
if "learning_rate" in graph_vars: pyreader.start()
fetch_list.append(graph_vars["learning_rate"].name) while True:
outputs = exe.run(fetch_list=fetch_list) try:
np_num_infer, np_num_label, np_num_correct, np_loss = outputs[:4] np_num_infer, np_num_label, np_num_correct = exe.run(program=program,
num_label = np.sum(np_num_label) fetch_list=fetch_list)
num_infer = np.sum(np_num_infer) total_infer += np.sum(np_num_infer)
num_correct = np.sum(np_num_correct) total_label += np.sum(np_num_label)
precision, recall, f1 = calculate_f1(num_label, num_infer, num_correct) total_correct += np.sum(np_num_correct)
rets = {
"precision": precision, except fluid.core.EOFException:
"recall": recall, pyreader.reset()
"f1": f1, break
"loss": np.mean(np_loss)
} precision, recall, f1 = calculate_f1(total_label, total_infer,
if "learning_rate" in graph_vars: total_correct)
rets["lr"] = float(outputs[4][0]) time_end = time.time()
return rets return \
"[evaluation] f1: %f, precision: %f, recall: %f, elapsed time: %f s" \
% (f1, precision, recall, time_end - time_begin)
def chunk_predict(np_inputs, np_probs, np_lens, dev_count=1):
inputs = np_inputs.reshape([-1]).astype(np.int32)
probs = np_probs.reshape([-1, np_probs.shape[-1]])
all_lens = np_lens.reshape([dev_count, -1]).astype(np.int32).tolist()
base_index = 0
out = []
for dev_index in xrange(dev_count):
lens = all_lens[dev_index]
max_len = 0
for l in lens:
max_len = max(max_len, l)
for i in xrange(len(lens)):
seq_st = base_index + i * max_len + 1
seq_en = seq_st + (lens[i] - 2)
prob = probs[seq_st:seq_en, :]
infers = np.argmax(probs, -1)
out.append((
inputs[seq_st:seq_en].tolist(),
infers.tolist(),
probs.tolist()))
base_index += max_len * len(lens)
return out
def predict(exe,
test_program,
test_pyreader,
graph_vars,
dev_count=1):
fetch_list = [
graph_vars["inputs"].name,
graph_vars["probs"].name,
graph_vars["seqlen"].name,
graph_vars["probs"].name,
]
test_pyreader.start()
res = []
while True:
try:
inputs, probs, np_lens, np_probs = exe.run(program=test_program,
fetch_list=fetch_list)
r = chunk_predict(inputs, probs, np_lens, dev_count)
res += r
except fluid.core.EOFException:
test_pyreader.reset()
break
log.info(len(res))
return res
else:
total_label, total_infer, total_correct = 0.0, 0.0, 0.0
time_begin = time.time()
pyreader.start()
while True:
try:
np_num_infer, np_num_label, np_num_correct = exe.run(program=program,
fetch_list=fetch_list)
total_infer += np.sum(np_num_infer)
total_label += np.sum(np_num_label)
total_correct += np.sum(np_num_correct)
except fluid.core.EOFException:
pyreader.reset()
break
precision, recall, f1 = calculate_f1(total_label, total_infer,
total_correct)
time_end = time.time()
print(
"[%s evaluation] f1: %f, precision: %f, recall: %f, elapsed time: %f s"
% (eval_phase, f1, precision, recall, time_end - time_begin))
...@@ -11,10 +11,12 @@ ...@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os import os
import time import time
...@@ -47,10 +49,21 @@ train_g.add_arg("warmup_proportion", float, 0.1, ...@@ -47,10 +49,21 @@ train_g.add_arg("warmup_proportion", float, 0.1,
train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.") train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.") train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
train_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.") train_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.")
train_g.add_arg("loss_scaling", float, 1.0, train_g.add_arg("use_dynamic_loss_scaling", bool, True, "Whether to use dynamic loss scaling.")
train_g.add_arg("init_loss_scaling", float, 102400,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.") "Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
train_g.add_arg("test_save", str, "test_result", "test_save")
train_g.add_arg("test_save", str, "./checkpoints/test_result", "test_save")
train_g.add_arg("metric", str, "simple_accuracy", "metric") train_g.add_arg("metric", str, "simple_accuracy", "metric")
train_g.add_arg("incr_every_n_steps", int, 100, "Increases loss scaling every n consecutive.")
train_g.add_arg("decr_every_n_nan_or_inf", int, 2,
"Decreases loss scaling every n accumulated steps with nan or inf gradients.")
train_g.add_arg("incr_ratio", float, 2.0,
"The multiplier to use when increasing the loss scaling.")
train_g.add_arg("decr_ratio", float, 0.8,
"The less-than-one-multiplier to use when decreasing.")
log_g = ArgumentGroup(parser, "logging", "logging related.") log_g = ArgumentGroup(parser, "logging", "logging related.")
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.") log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
...@@ -86,6 +99,7 @@ data_g.add_arg("chunk_scheme", type=str, default="IOB", choices=["IO", "IOB", " ...@@ -86,6 +99,7 @@ data_g.add_arg("chunk_scheme", type=str, default="IOB", choices=["IO", "IOB", "
run_type_g = ArgumentGroup(parser, "run_type", "running type options.") run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.") run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("is_distributed", bool, False, "If set, then start distributed training.")
run_type_g.add_arg("use_fast_executor", bool, False, "If set, use fast parallel executor (in experiment).") run_type_g.add_arg("use_fast_executor", bool, False, "If set, use fast parallel executor (in experiment).")
run_type_g.add_arg("num_iteration_per_drop_scope", int, 10, "Iteration intervals to drop scope.") run_type_g.add_arg("num_iteration_per_drop_scope", int, 10, "Iteration intervals to drop scope.")
run_type_g.add_arg("do_train", bool, True, "Whether to perform training.") run_type_g.add_arg("do_train", bool, True, "Whether to perform training.")
......
...@@ -16,14 +16,18 @@ ...@@ -16,14 +16,18 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import json import json
import six import six
import logging
import paddle.fluid as fluid import paddle.fluid as fluid
from io import open
from model.transformer_encoder import encoder, pre_process_layer from model.transformer_encoder import encoder, pre_process_layer
log = logging.getLogger(__name__)
class ErnieConfig(object): class ErnieConfig(object):
def __init__(self, config_path): def __init__(self, config_path):
...@@ -31,7 +35,7 @@ class ErnieConfig(object): ...@@ -31,7 +35,7 @@ class ErnieConfig(object):
def _parse(self, config_path): def _parse(self, config_path):
try: try:
with open(config_path) as json_file: with open(config_path, 'r', encoding='utf8') as json_file:
config_dict = json.load(json_file) config_dict = json.load(json_file)
except Exception: except Exception:
raise IOError("Error in parsing Ernie model config file '%s'" % raise IOError("Error in parsing Ernie model config file '%s'" %
...@@ -44,8 +48,8 @@ class ErnieConfig(object): ...@@ -44,8 +48,8 @@ class ErnieConfig(object):
def print_config(self): def print_config(self):
for arg, value in sorted(six.iteritems(self._config_dict)): for arg, value in sorted(six.iteritems(self._config_dict)):
print('%s: %s' % (arg, value)) log.info('%s: %s' % (arg, value))
print('------------------------------------------------') log.info('------------------------------------------------')
class ErnieModel(object): class ErnieModel(object):
...@@ -102,7 +106,7 @@ class ErnieModel(object): ...@@ -102,7 +106,7 @@ class ErnieModel(object):
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name=self._word_emb_name, initializer=self._param_initializer), name=self._word_emb_name, initializer=self._param_initializer),
is_sparse=False) is_sparse=False)
position_emb_out = fluid.layers.embedding( position_emb_out = fluid.layers.embedding(
input=position_ids, input=position_ids,
size=[self._max_position_seq_len, self._emb_size], size=[self._max_position_seq_len, self._emb_size],
...@@ -163,6 +167,10 @@ class ErnieModel(object): ...@@ -163,6 +167,10 @@ class ErnieModel(object):
postprocess_cmd="dan", postprocess_cmd="dan",
param_initializer=self._param_initializer, param_initializer=self._param_initializer,
name='encoder') name='encoder')
if self._dtype == "float16":
self._enc_out = fluid.layers.cast(
x=self._enc_out, dtype=self._emb_dtype)
def get_sequence_output(self): def get_sequence_output(self):
return self._enc_out return self._enc_out
...@@ -171,9 +179,6 @@ class ErnieModel(object): ...@@ -171,9 +179,6 @@ class ErnieModel(object):
"""Get the first feature of each sequence for classification""" """Get the first feature of each sequence for classification"""
next_sent_feat = fluid.layers.slice( next_sent_feat = fluid.layers.slice(
input=self._enc_out, axes=[1], starts=[0], ends=[1]) input=self._enc_out, axes=[1], starts=[0], ends=[1])
if self._dtype == "float16":
next_sent_feat = fluid.layers.cast(
x=next_sent_feat, dtype=self._emb_dtype)
next_sent_feat = fluid.layers.fc( next_sent_feat = fluid.layers.fc(
input=next_sent_feat, input=next_sent_feat,
size=self._emb_size, size=self._emb_size,
...@@ -194,8 +199,6 @@ class ErnieModel(object): ...@@ -194,8 +199,6 @@ class ErnieModel(object):
x=self._enc_out, shape=[-1, self._emb_size]) x=self._enc_out, shape=[-1, self._emb_size])
# extract masked tokens' feature # extract masked tokens' feature
mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos) mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
if self._dtype == "float16":
mask_feat = fluid.layers.cast(x=mask_feat, dtype=self._emb_dtype)
# transform: fc # transform: fc
mask_trans_feat = fluid.layers.fc( mask_trans_feat = fluid.layers.fc(
...@@ -206,7 +209,7 @@ class ErnieModel(object): ...@@ -206,7 +209,7 @@ class ErnieModel(object):
name='mask_lm_trans_fc.w_0', name='mask_lm_trans_fc.w_0',
initializer=self._param_initializer), initializer=self._param_initializer),
bias_attr=fluid.ParamAttr(name='mask_lm_trans_fc.b_0')) bias_attr=fluid.ParamAttr(name='mask_lm_trans_fc.b_0'))
# transform: layer norm # transform: layer norm
mask_trans_feat = fluid.layers.layer_norm( mask_trans_feat = fluid.layers.layer_norm(
mask_trans_feat, mask_trans_feat,
......
...@@ -16,14 +16,18 @@ ...@@ -16,14 +16,18 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import json import json
import logging
import six import six
import paddle.fluid as fluid import paddle.fluid as fluid
from io import open
from model.transformer_encoder import encoder, pre_process_layer from model.transformer_encoder import encoder, pre_process_layer
log = logging.getLogger(__name__)
class ErnieConfig(object): class ErnieConfig(object):
def __init__(self, config_path): def __init__(self, config_path):
...@@ -31,7 +35,7 @@ class ErnieConfig(object): ...@@ -31,7 +35,7 @@ class ErnieConfig(object):
def _parse(self, config_path): def _parse(self, config_path):
try: try:
with open(config_path) as json_file: with open(config_path, 'r', encoding='utf8') as json_file:
config_dict = json.load(json_file) config_dict = json.load(json_file)
except Exception: except Exception:
raise IOError("Error in parsing Ernie model config file '%s'" % raise IOError("Error in parsing Ernie model config file '%s'" %
...@@ -44,8 +48,8 @@ class ErnieConfig(object): ...@@ -44,8 +48,8 @@ class ErnieConfig(object):
def print_config(self): def print_config(self):
for arg, value in sorted(six.iteritems(self._config_dict)): for arg, value in sorted(six.iteritems(self._config_dict)):
print('%s: %s' % (arg, value)) log.info('%s: %s' % (arg, value))
print('------------------------------------------------') log.info('------------------------------------------------')
class ErnieModel(object): class ErnieModel(object):
......
...@@ -16,10 +16,13 @@ ...@@ -16,10 +16,13 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from utils.fp16 import create_master_params_grads, master_param_to_train_param from utils.fp16 import create_master_params_grads, master_param_to_train_param, apply_dynamic_loss_scaling
def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps): def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps):
...@@ -101,7 +104,7 @@ def optimization(loss, ...@@ -101,7 +104,7 @@ def optimization(loss,
return False return False
param_list = dict() param_list = dict()
loss_scaling = fluid.layers.create_global_var( loss_scaling = fluid.layers.create_global_var(
name=fluid.unique_name.generate("loss_scaling"), name=fluid.unique_name.generate("loss_scaling"),
shape=[1], shape=[1],
......
...@@ -42,8 +42,18 @@ train_g.add_arg("warmup_steps", int, 5000, "Total steps to perform wa ...@@ -42,8 +42,18 @@ train_g.add_arg("warmup_steps", int, 5000, "Total steps to perform wa
train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.") train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.") train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
train_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.") train_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.")
train_g.add_arg("loss_scaling", float, 1.0, train_g.add_arg("use_dynamic_loss_scaling", bool, True, "Whether to use dynamic loss scaling.")
train_g.add_arg("init_loss_scaling", float, 102400,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.") "Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
train_g.add_arg("incr_every_n_steps", int, 100, "Increases loss scaling every n consecutive.")
train_g.add_arg("decr_every_n_nan_or_inf", int, 2,
"Decreases loss scaling every n accumulated steps with nan or inf gradients.")
train_g.add_arg("incr_ratio", float, 2.0,
"The multiplier to use when increasing the loss scaling.")
train_g.add_arg("decr_ratio", float, 0.8,
"The less-than-one-multiplier to use when decreasing.")
log_g = ArgumentGroup(parser, "logging", "logging related.") log_g = ArgumentGroup(parser, "logging", "logging related.")
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.") log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
......
...@@ -11,9 +11,11 @@ ...@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os import os
import numpy as np import numpy as np
...@@ -36,8 +38,10 @@ class ErnieDataReader(object): ...@@ -36,8 +38,10 @@ class ErnieDataReader(object):
filelist, filelist,
vocab_path, vocab_path,
batch_size=4096, batch_size=4096,
in_tokens=True,
max_seq_len=512, max_seq_len=512,
shuffle_files=True, shuffle_files=True,
random_seed=1,
epoch=100, epoch=100,
voc_size=0, voc_size=0,
is_test=False, is_test=False,
...@@ -46,6 +50,8 @@ class ErnieDataReader(object): ...@@ -46,6 +50,8 @@ class ErnieDataReader(object):
self.vocab = self.load_vocab(vocab_path) self.vocab = self.load_vocab(vocab_path)
self.filelist = filelist self.filelist = filelist
self.batch_size = batch_size self.batch_size = batch_size
self.in_tokens = in_tokens
self.random_seed = random_seed
self.shuffle_files = shuffle_files self.shuffle_files = shuffle_files
self.epoch = epoch self.epoch = epoch
self.current_epoch = 0 self.current_epoch = 0
...@@ -60,12 +66,42 @@ class ErnieDataReader(object): ...@@ -60,12 +66,42 @@ class ErnieDataReader(object):
self.mask_id = self.vocab["[MASK]"] self.mask_id = self.vocab["[MASK]"]
self.is_test = is_test self.is_test = is_test
self.generate_neg_sample = generate_neg_sample self.generate_neg_sample = generate_neg_sample
assert self.batch_size > 100, "Current batch size means total token's number, \
it should not be set to too small number." self.trainer_id = 0
self.trainer_nums = 1
self.files = open(filelist).readlines()
self.total_file = len(self.files)
if self.is_test: if self.is_test:
self.epoch = 1 self.epoch = 1
self.shuffle_files = False self.shuffle_files = False
self.global_rng = np.random.RandomState(random_seed)
if self.shuffle_files:
if os.getenv("PADDLE_TRAINER_ID"):
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
if os.getenv("PADDLE_NODES_NUM"):
self.trainer_nums = int(os.getenv("PADDLE_TRAINERS_NUM"))
#renew total_file
self.total_file = len(self.files) // self.trainer_nums * self.trainer_nums
if len(self.files) < self.trainer_nums:
raise RuntimeError('not enouph train file to shard, file:%d num_trainer:%d' % (len(self.files), self.trainer_nums))
tmp_files = []
for each in range(epoch):
each_files = self.files
self.global_rng.shuffle(each_files)
tmp_files += each_files
self.files = tmp_files
#renew epochs
self.epoch = len(self.files) // self.total_file * self.total_file
assert self.total_file > 0, \
"[Error] data_dir is empty or less than %d" % self.trainer_nums
if self.in_tokens:
assert self.batch_size > 100, "Current batch size means total token's number, \
it should not be set to too small number."
def get_progress(self): def get_progress(self):
"""return current progress of traning data """return current progress of traning data
...@@ -75,13 +111,16 @@ class ErnieDataReader(object): ...@@ -75,13 +111,16 @@ class ErnieDataReader(object):
def parse_line(self, line, max_seq_len=512): def parse_line(self, line, max_seq_len=512):
""" parse one line to token_ids, sentence_ids, pos_ids, label """ parse one line to token_ids, sentence_ids, pos_ids, label
""" """
line = line.strip().decode().split(";") line = line.strip().split(";")
assert len(line) == 5, "One sample must have 5 fields!" assert len(line) == 5, \
"One sample must have %d fields!" % 5
(token_ids, sent_ids, pos_ids, seg_labels, label) = line (token_ids, sent_ids, pos_ids, seg_labels, label) = line
token_ids = [int(token) for token in token_ids.split(" ")] token_ids = [int(token) for token in token_ids.split(" ")]
sent_ids = [int(token) for token in sent_ids.split(" ")] sent_ids = [int(token) for token in sent_ids.split(" ")]
pos_ids = [int(token) for token in pos_ids.split(" ")] pos_ids = [int(token) for token in pos_ids.split(" ")]
seg_labels = [int(seg_label) for seg_label in seg_labels.split(" ")] seg_labels = [int(seg_label) for seg_label in seg_labels.split(" ")]
assert len(token_ids) == len(sent_ids) == len(pos_ids) == len( assert len(token_ids) == len(sent_ids) == len(pos_ids) == len(
seg_labels seg_labels
), "[Must be true]len(token_ids) == len(sent_ids) == len(pos_ids) == len(seg_labels)" ), "[Must be true]len(token_ids) == len(sent_ids) == len(pos_ids) == len(seg_labels)"
...@@ -94,6 +133,7 @@ class ErnieDataReader(object): ...@@ -94,6 +133,7 @@ class ErnieDataReader(object):
assert file.endswith('.gz'), "[ERROR] %s is not a gzip file" % file assert file.endswith('.gz'), "[ERROR] %s is not a gzip file" % file
with gzip.open(file, "rb") as f: with gzip.open(file, "rb") as f:
for line in f: for line in f:
line = line.decode('utf8')
parsed_line = self.parse_line( parsed_line = self.parse_line(
line, max_seq_len=self.max_seq_len) line, max_seq_len=self.max_seq_len)
if parsed_line is None: if parsed_line is None:
...@@ -232,35 +272,63 @@ class ErnieDataReader(object): ...@@ -232,35 +272,63 @@ class ErnieDataReader(object):
print("miss_num:%d\tideal_total_sample_num:%d\tmiss_rate:%f" % print("miss_num:%d\tideal_total_sample_num:%d\tmiss_rate:%f" %
(num_total_miss, pos_sample_num * 2, (num_total_miss, pos_sample_num * 2,
num_total_miss / (pos_sample_num * 2))) num_total_miss / (pos_sample_num * 2)))
def shuffle_samples(self, sample_generator, buffer=1000):
samples = []
try:
while True:
while len(samples) < buffer:
sample = next(sample_generator)
samples.append(sample)
np.random.shuffle(samples)
for sample in samples:
yield sample
samples = []
except StopIteration:
print("stopiteration: reach end of file")
if len(samples) == 0:
yield None
else:
np.random.shuffle(samples)
for sample in samples:
yield sample
def data_generator(self): def data_generator(self):
""" """
data_generator data_generator
""" """
files = open(self.filelist).readlines()
self.total_file = len(files)
assert self.total_file > 0, "[Error] data_dir is empty"
def wrapper(): def wrapper():
def reader(): def reader():
for epoch in range(self.epoch): for epoch in range(self.epoch):
self.current_epoch = epoch + 1 self.current_epoch = epoch + 1
files = self.files
#during training, data are sliced by trainers
if self.shuffle_files: if self.shuffle_files:
np.random.shuffle(files) start = epoch * self.total_file
for index, file in enumerate(files): end = start + self.total_file
file, mask_word_prob = file.strip().split("\t") files = [file_ for index, file_ in enumerate(self.files[start:end]) \
if index % self.trainer_nums == self.trainer_id]
for index, file_ in enumerate(files):
file_, mask_word_prob = file_.strip().split("\t")
mask_word = (np.random.random() < float(mask_word_prob)) mask_word = (np.random.random() < float(mask_word_prob))
self.current_file_index = index + 1 self.current_file_index = (index + 1) * self.trainer_nums
self.current_file = file self.current_file = file_
if mask_word: if mask_word:
self.mask_type = "mask_word" self.mask_type = "mask_word"
else: else:
self.mask_type = "mask_char" self.mask_type = "mask_char"
sample_generator = self.read_file(file) sample_generator = self.read_file(file_)
if not self.is_test and self.generate_neg_sample: if not self.is_test:
sample_generator = self.mixin_negtive_samples( if self.generate_neg_sample:
sample_generator) sample_generator = self.mixin_negtive_samples(
sample_generator)
else:
#shuffle buffered sample
sample_generator = self.shuffle_samples(
sample_generator)
for sample in sample_generator: for sample in sample_generator:
if sample is None: if sample is None:
continue continue
...@@ -272,7 +340,11 @@ class ErnieDataReader(object): ...@@ -272,7 +340,11 @@ class ErnieDataReader(object):
for parsed_line in reader(): for parsed_line in reader():
token_ids, sent_ids, pos_ids, label, seg_labels, mask_word = parsed_line token_ids, sent_ids, pos_ids, label, seg_labels, mask_word = parsed_line
max_len = max(max_len, len(token_ids)) max_len = max(max_len, len(token_ids))
if (len(batch) + 1) * max_len <= batch_size: if self.in_tokens:
to_append = (len(batch) + 1) * max_len <= batch_size
else:
to_append = len(batch) < batch_size
if to_append:
batch.append(parsed_line) batch.append(parsed_line)
total_token_num += len(token_ids) total_token_num += len(token_ids)
else: else:
......
...@@ -11,18 +11,46 @@ ...@@ -11,18 +11,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import sys
import os import os
import csv
import json import json
import random import random
import logging
import numpy as np import numpy as np
import six
from io import open
from collections import namedtuple from collections import namedtuple
import tokenization import tokenization
from batching import pad_batch_data from batching import pad_batch_data
log = logging.getLogger(__name__)
if six.PY3:
from itertools import accumulate
import io
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
def csv_reader(fd, delimiter='\t'):
def gen():
for i in fd:
slots = i.rstrip('\n').split(delimiter)
if len(slots) == 1:
yield slots,
else:
yield slots
return gen()
class BaseReader(object): class BaseReader(object):
def __init__(self, def __init__(self,
vocab_path, vocab_path,
...@@ -58,7 +86,7 @@ class BaseReader(object): ...@@ -58,7 +86,7 @@ class BaseReader(object):
self.num_examples = 0 self.num_examples = 0
if label_map_config: if label_map_config:
with open(label_map_config) as f: with open(label_map_config, encoding='utf8') as f:
self.label_map = json.load(f) self.label_map = json.load(f)
else: else:
self.label_map = None self.label_map = None
...@@ -69,8 +97,8 @@ class BaseReader(object): ...@@ -69,8 +97,8 @@ class BaseReader(object):
def _read_tsv(self, input_file, quotechar=None): def _read_tsv(self, input_file, quotechar=None):
"""Reads a tab separated value file.""" """Reads a tab separated value file."""
with open(input_file, "r") as f: with open(input_file, 'r', encoding='utf8') as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar) reader = csv_reader(f)
headers = next(reader) headers = next(reader)
Example = namedtuple('Example', headers) Example = namedtuple('Example', headers)
...@@ -225,6 +253,12 @@ class BaseReader(object): ...@@ -225,6 +253,12 @@ class BaseReader(object):
phase=None): phase=None):
examples = self._read_tsv(input_file) examples = self._read_tsv(input_file)
if phase == 'train':
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_num = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
examples = examples[trainer_id: (len(examples) //trainer_num) * trainer_num : trainer_num]
log.info('apply sharding %d/%d' % (trainer_id, trainer_num))
def wrapper(): def wrapper():
all_dev_batches = [] all_dev_batches = []
for epoch_index in range(epoch): for epoch_index in range(epoch):
...@@ -242,15 +276,21 @@ class BaseReader(object): ...@@ -242,15 +276,21 @@ class BaseReader(object):
for batch in all_dev_batches: for batch in all_dev_batches:
yield batch yield batch
all_dev_batches = [] all_dev_batches = []
def f():
return wrapper try:
for i in wrapper():
yield i
except Exception as e:
import traceback
traceback.print_exc()
return f
class ClassifyReader(BaseReader): class ClassifyReader(BaseReader):
def _read_tsv(self, input_file, quotechar=None): def _read_tsv(self, input_file, quotechar=None):
"""Reads a tab separated value file.""" """Reads a tab separated value file."""
with open(input_file, "r") as f: with open(input_file, 'r', encoding='utf8') as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar) reader = csv_reader(f)
headers = next(reader) headers = next(reader)
text_indices = [ text_indices = [
index for index, h in enumerate(headers) if h != "label" index for index, h in enumerate(headers) if h != "label"
...@@ -472,7 +512,7 @@ class MRCReader(BaseReader): ...@@ -472,7 +512,7 @@ class MRCReader(BaseReader):
def _read_json(self, input_file, is_training): def _read_json(self, input_file, is_training):
examples = [] examples = []
with open(input_file, "r") as f: with open(input_file, "r", encoding='utf8') as f:
input_data = json.load(f)["data"] input_data = json.load(f)["data"]
for entry in input_data: for entry in input_data:
for paragraph in entry["paragraphs"]: for paragraph in entry["paragraphs"]:
...@@ -507,7 +547,7 @@ class MRCReader(BaseReader): ...@@ -507,7 +547,7 @@ class MRCReader(BaseReader):
actual_text = " ".join(doc_tokens[start_pos:(end_pos actual_text = " ".join(doc_tokens[start_pos:(end_pos
+ 1)]) + 1)])
if actual_text.find(orig_answer_text) == -1: if actual_text.find(orig_answer_text) == -1:
print("Could not find answer: '%s' vs. '%s'", log.info("Could not find answer: '%s' vs. '%s'",
actual_text, orig_answer_text) actual_text, orig_answer_text)
continue continue
else: else:
......
...@@ -16,9 +16,12 @@ ...@@ -16,9 +16,12 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os import os
import time import time
import logging
import multiprocessing import multiprocessing
# NOTE(paddle-dev): All of these flags should be # NOTE(paddle-dev): All of these flags should be
...@@ -32,12 +35,13 @@ import reader.task_reader as task_reader ...@@ -32,12 +35,13 @@ import reader.task_reader as task_reader
from model.ernie import ErnieConfig from model.ernie import ErnieConfig
from finetune.classifier import create_model, evaluate, predict from finetune.classifier import create_model, evaluate, predict
from optimization import optimization from optimization import optimization
from utils.args import print_arguments, check_cuda from utils.args import print_arguments, check_cuda, prepare_logger
from utils.init import init_pretraining_params, init_checkpoint from utils.init import init_pretraining_params, init_checkpoint
from utils.cards import get_cards from utils.cards import get_cards
from finetune_args import parser from finetune_args import parser
args = parser.parse_args() args = parser.parse_args()
log = logging.getLogger()
def main(args): def main(args):
...@@ -45,8 +49,9 @@ def main(args): ...@@ -45,8 +49,9 @@ def main(args):
ernie_config.print_config() ernie_config.print_config()
if args.use_cuda: if args.use_cuda:
place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0'))) dev_list = fluid.cuda_places()
dev_count = fluid.core.get_cuda_device_count() place = dev_list[0]
dev_count = len(dev_list)
else: else:
place = fluid.CPUPlace() place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
...@@ -95,10 +100,10 @@ def main(args): ...@@ -95,10 +100,10 @@ def main(args):
max_train_steps = args.epoch * num_train_examples // args.batch_size // dev_count max_train_steps = args.epoch * num_train_examples // args.batch_size // dev_count
warmup_steps = int(max_train_steps * args.warmup_proportion) warmup_steps = int(max_train_steps * args.warmup_proportion)
print("Device count: %d" % dev_count) log.info("Device count: %d" % dev_count)
print("Num train examples: %d" % num_train_examples) log.info("Num train examples: %d" % num_train_examples)
print("Max train steps: %d" % max_train_steps) log.info("Max train steps: %d" % max_train_steps)
print("Num warmup steps: %d" % warmup_steps) log.info("Num warmup steps: %d" % warmup_steps)
train_program = fluid.Program() train_program = fluid.Program()
if args.random_seed is not None and args.enable_ce: if args.random_seed is not None and args.enable_ce:
...@@ -121,7 +126,13 @@ def main(args): ...@@ -121,7 +126,13 @@ def main(args):
startup_prog=startup_prog, startup_prog=startup_prog,
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
scheduler=args.lr_scheduler, scheduler=args.lr_scheduler,
use_fp16=args.use_fp16) use_fp16=args.use_fp16,
use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
init_loss_scaling=args.init_loss_scaling,
incr_every_n_steps=args.incr_every_n_steps,
decr_every_n_nan_or_inf=args.decr_every_n_nan_or_inf,
incr_ratio=args.incr_ratio,
decr_ratio=args.decr_ratio)
if args.verbose: if args.verbose:
if args.in_tokens: if args.in_tokens:
...@@ -131,7 +142,7 @@ def main(args): ...@@ -131,7 +142,7 @@ def main(args):
else: else:
lower_mem, upper_mem, unit = fluid.contrib.memory_usage( lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
program=train_program, batch_size=args.batch_size) program=train_program, batch_size=args.batch_size)
print("Theoretical memory usage in training: %.3f - %.3f %s" % log.info("Theoretical memory usage in training: %.3f - %.3f %s" %
(lower_mem, upper_mem, unit)) (lower_mem, upper_mem, unit))
if args.do_val or args.do_test: if args.do_val or args.do_test:
...@@ -148,11 +159,36 @@ def main(args): ...@@ -148,11 +159,36 @@ def main(args):
test_prog = test_prog.clone(for_test=True) test_prog = test_prog.clone(for_test=True)
nccl2_num_trainers = 1 nccl2_num_trainers = 1
nccl2_trainer_id = 0 nccl2_trainer_id = 0
if args.is_distributed:
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS")
current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
worker_endpoints = worker_endpoints_env.split(",")
trainers_num = len(worker_endpoints)
log.info("worker_endpoints:{} trainers_num:{} current_endpoint:{} \
trainer_id:{}".format(worker_endpoints, trainers_num,
current_endpoint, trainer_id))
# prepare nccl2 env.
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
trainers=worker_endpoints_env,
current_endpoint=current_endpoint,
program=train_program if args.do_train else test_prog,
startup_program=startup_prog)
nccl2_num_trainers = trainers_num
nccl2_trainer_id = trainer_id
exe = fluid.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
if args.do_train: if args.do_train:
if args.init_checkpoint and args.init_pretraining_params: if args.init_checkpoint and args.init_pretraining_params:
print( log.warning(
"WARNING: args 'init_checkpoint' and 'init_pretraining_params' " "WARNING: args 'init_checkpoint' and 'init_pretraining_params' "
"both are set! Only arg 'init_checkpoint' is made valid.") "both are set! Only arg 'init_checkpoint' is made valid.")
if args.init_checkpoint: if args.init_checkpoint:
...@@ -236,14 +272,14 @@ def main(args): ...@@ -236,14 +272,14 @@ def main(args):
verbose += "learning rate: %f" % ( verbose += "learning rate: %f" % (
outputs["learning_rate"] outputs["learning_rate"]
if warmup_steps > 0 else args.learning_rate) if warmup_steps > 0 else args.learning_rate)
print(verbose) log.info(verbose)
current_example, current_epoch = reader.get_train_progress() current_example, current_epoch = reader.get_train_progress()
time_end = time.time() time_end = time.time()
used_time = time_end - time_begin used_time = time_end - time_begin
if args.is_classify: if args.is_classify:
print( log.info(
"epoch: %d, progress: %d/%d, step: %d, ave loss: %f, " "epoch: %d, progress: %d/%d, step: %d, ave loss: %f, "
"ave acc: %f, speed: %f steps/s" % "ave acc: %f, speed: %f steps/s" %
(current_epoch, current_example, num_train_examples, (current_epoch, current_example, num_train_examples,
...@@ -252,7 +288,7 @@ def main(args): ...@@ -252,7 +288,7 @@ def main(args):
ce_info.append( ce_info.append(
[outputs["loss"], outputs["accuracy"], used_time]) [outputs["loss"], outputs["accuracy"], used_time])
if args.is_regression: if args.is_regression:
print( log.info(
"epoch: %d, progress: %d/%d, step: %d, ave loss: %f, " "epoch: %d, progress: %d/%d, step: %d, ave loss: %f, "
" speed: %f steps/s" % " speed: %f steps/s" %
(current_epoch, current_example, num_train_examples, (current_epoch, current_example, num_train_examples,
...@@ -260,22 +296,23 @@ def main(args): ...@@ -260,22 +296,23 @@ def main(args):
args.skip_steps / used_time)) args.skip_steps / used_time))
time_begin = time.time() time_begin = time.time()
if steps % args.save_steps == 0: if nccl2_trainer_id == 0:
save_path = os.path.join(args.checkpoints, if steps % args.save_steps == 0:
"step_" + str(steps)) save_path = os.path.join(args.checkpoints,
fluid.io.save_persistables(exe, save_path, train_program) "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
if steps % args.validation_steps == 0 or last_epoch != current_epoch: if steps % args.validation_steps == 0 or last_epoch != current_epoch:
# evaluate dev set # evaluate dev set
if args.do_val: if args.do_val:
evaluate_wrapper(args, reader, exe, test_prog, evaluate_wrapper(args, reader, exe, test_prog,
test_pyreader, graph_vars, test_pyreader, graph_vars,
current_epoch, steps) current_epoch, steps)
if args.do_test: if args.do_test:
predict_wrapper(args, reader, exe, test_prog, predict_wrapper(args, reader, exe, test_prog,
test_pyreader, graph_vars, test_pyreader, graph_vars,
current_epoch, steps) current_epoch, steps)
if last_epoch != current_epoch: if last_epoch != current_epoch:
last_epoch = current_epoch last_epoch = current_epoch
...@@ -295,10 +332,10 @@ def main(args): ...@@ -295,10 +332,10 @@ def main(args):
ce_acc = ce_info[-2][1] ce_acc = ce_info[-2][1]
ce_time = ce_info[-2][2] ce_time = ce_info[-2][2]
except: except:
print("ce info error") log.info("ce info error")
print("kpis\ttrain_duration_card%s\t%s" % (card_num, ce_time)) log.info("kpis\ttrain_duration_card%s\t%s" % (card_num, ce_time))
print("kpis\ttrain_loss_card%s\t%f" % (card_num, ce_loss)) log.info("kpis\ttrain_loss_card%s\t%f" % (card_num, ce_loss))
print("kpis\ttrain_acc_card%s\t%f" % (card_num, ce_acc)) log.info("kpis\ttrain_acc_card%s\t%f" % (card_num, ce_acc))
# final eval on dev set # final eval on dev set
if args.do_val: if args.do_val:
...@@ -320,7 +357,7 @@ def main(args): ...@@ -320,7 +357,7 @@ def main(args):
dev_count=1, dev_count=1,
shuffle=False)) shuffle=False))
print("Final diagnostic") log.info("Final diagnostic")
qids, preds, probs = predict( qids, preds, probs = predict(
test_exe, test_exe,
test_prog, test_prog,
...@@ -334,7 +371,7 @@ def main(args): ...@@ -334,7 +371,7 @@ def main(args):
for id, s, p in zip(qids, preds, probs): for id, s, p in zip(qids, preds, probs):
f.write('{}\t{}\t{}\n'.format(id, s, p)) f.write('{}\t{}\t{}\n'.format(id, s, p))
print("Done final diagnostic, saving to {}".format( log.info("Done final diagnostic, saving to {}".format(
args.diagnostic_save)) args.diagnostic_save))
...@@ -349,7 +386,7 @@ def evaluate_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars, ...@@ -349,7 +386,7 @@ def evaluate_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
epoch=1, epoch=1,
dev_count=1, dev_count=1,
shuffle=False)) shuffle=False))
print("validation result of dataset {}:".format(ds)) log.info("validation result of dataset {}:".format(ds))
evaluate_info = evaluate( evaluate_info = evaluate(
exe, exe,
test_prog, test_prog,
...@@ -359,7 +396,7 @@ def evaluate_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars, ...@@ -359,7 +396,7 @@ def evaluate_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
metric=args.metric, metric=args.metric,
is_classify=args.is_classify, is_classify=args.is_classify,
is_regression=args.is_regression) is_regression=args.is_regression)
print(evaluate_info + ', file: {}, epoch: {}, steps: {}'.format( log.info(evaluate_info + ', file: {}, epoch: {}, steps: {}'.format(
ds, epoch, steps)) ds, epoch, steps))
...@@ -379,7 +416,7 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars, ...@@ -379,7 +416,7 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
shuffle=False)) shuffle=False))
save_path = save_f + '.' + str(epoch) + '.' + str(steps) save_path = save_f + '.' + str(epoch) + '.' + str(steps)
print("testing {}, save to {}".format(test_f, save_path)) log.info("testing {}, save to {}".format(test_f, save_path))
qids, preds, probs = predict( qids, preds, probs = predict(
exe, exe,
test_prog, test_prog,
...@@ -391,6 +428,9 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars, ...@@ -391,6 +428,9 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
save_dir = os.path.dirname(save_path) save_dir = os.path.dirname(save_path)
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
else:
log.warning('save dir exsits: %s, will skip saving' % save_dir)
with open(save_path, 'w') as f: with open(save_path, 'w') as f:
for id, s, p in zip(qids, preds, probs): for id, s, p in zip(qids, preds, probs):
...@@ -398,6 +438,7 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars, ...@@ -398,6 +438,7 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
if __name__ == '__main__': if __name__ == '__main__':
prepare_logger(log)
print_arguments(args) print_arguments(args)
check_cuda(args.use_cuda) check_cuda(args.use_cuda)
main(args) main(args)
...@@ -16,9 +16,11 @@ ...@@ -16,9 +16,11 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
import os import os
import time import time
import logging
import multiprocessing import multiprocessing
# NOTE(paddle-dev): All of these flags should be # NOTE(paddle-dev): All of these flags should be
...@@ -32,11 +34,12 @@ import reader.task_reader as task_reader ...@@ -32,11 +34,12 @@ import reader.task_reader as task_reader
from model.ernie import ErnieConfig from model.ernie import ErnieConfig
from finetune.mrc import create_model, evaluate from finetune.mrc import create_model, evaluate
from optimization import optimization from optimization import optimization
from utils.args import print_arguments from utils.args import print_arguments, prepare_logger
from utils.init import init_pretraining_params, init_checkpoint from utils.init import init_pretraining_params, init_checkpoint
from finetune_args import parser from finetune_args import parser
args = parser.parse_args() args = parser.parse_args()
log = logging.getLogger()
def main(args): def main(args):
...@@ -44,8 +47,9 @@ def main(args): ...@@ -44,8 +47,9 @@ def main(args):
ernie_config.print_config() ernie_config.print_config()
if args.use_cuda: if args.use_cuda:
place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0'))) dev_list = fluid.cuda_places()
dev_count = fluid.core.get_cuda_device_count() place = dev_list[0]
dev_count = len(dev_list)
else: else:
place = fluid.CPUPlace() place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
...@@ -70,6 +74,8 @@ def main(args): ...@@ -70,6 +74,8 @@ def main(args):
raise ValueError("For args `do_train`, `do_val` and `do_test`, at " raise ValueError("For args `do_train`, `do_val` and `do_test`, at "
"least one of them must be True.") "least one of them must be True.")
if args.do_test:
assert args.test_save is not None
startup_prog = fluid.Program() startup_prog = fluid.Program()
if args.random_seed is not None: if args.random_seed is not None:
startup_prog.random_seed = args.random_seed startup_prog.random_seed = args.random_seed
...@@ -77,11 +83,12 @@ def main(args): ...@@ -77,11 +83,12 @@ def main(args):
if args.predict_batch_size == None: if args.predict_batch_size == None:
args.predict_batch_size = args.batch_size args.predict_batch_size = args.batch_size
if args.do_train: if args.do_train:
trainers_num = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
train_data_generator = reader.data_generator( train_data_generator = reader.data_generator(
input_file=args.train_set, input_file=args.train_set,
batch_size=args.batch_size, batch_size=args.batch_size,
epoch=args.epoch, epoch=args.epoch,
dev_count=dev_count, dev_count=trainers_num,
shuffle=True, shuffle=True,
phase="train") phase="train")
...@@ -94,10 +101,10 @@ def main(args): ...@@ -94,10 +101,10 @@ def main(args):
max_train_steps = args.epoch * num_train_examples // args.batch_size // dev_count max_train_steps = args.epoch * num_train_examples // args.batch_size // dev_count
warmup_steps = int(max_train_steps * args.warmup_proportion) warmup_steps = int(max_train_steps * args.warmup_proportion)
print("Device count: %d" % dev_count) log.info("Device count: %d" % dev_count)
print("Num train examples: %d" % num_train_examples) log.info("Num train examples: %d" % num_train_examples)
print("Max train steps: %d" % max_train_steps) log.info("Max train steps: %d" % max_train_steps)
print("Num warmup steps: %d" % warmup_steps) log.info("Num warmup steps: %d" % warmup_steps)
train_program = fluid.Program() train_program = fluid.Program()
...@@ -108,7 +115,7 @@ def main(args): ...@@ -108,7 +115,7 @@ def main(args):
pyreader_name='train_reader', pyreader_name='train_reader',
ernie_config=ernie_config, ernie_config=ernie_config,
is_training=True) is_training=True)
scheduled_lr, loss_scaling = optimization( scheduled_lr, _ = optimization(
loss=graph_vars["loss"], loss=graph_vars["loss"],
warmup_steps=warmup_steps, warmup_steps=warmup_steps,
num_train_steps=max_train_steps, num_train_steps=max_train_steps,
...@@ -117,7 +124,13 @@ def main(args): ...@@ -117,7 +124,13 @@ def main(args):
startup_prog=startup_prog, startup_prog=startup_prog,
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
scheduler=args.lr_scheduler, scheduler=args.lr_scheduler,
use_fp16=args.use_fp16) use_fp16=args.use_fp16,
use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
init_loss_scaling=args.init_loss_scaling,
incr_every_n_steps=args.incr_every_n_steps,
decr_every_n_nan_or_inf=args.decr_every_n_nan_or_inf,
incr_ratio=args.incr_ratio,
decr_ratio=args.decr_ratio)
if args.verbose: if args.verbose:
if args.in_tokens: if args.in_tokens:
...@@ -127,7 +140,7 @@ def main(args): ...@@ -127,7 +140,7 @@ def main(args):
else: else:
lower_mem, upper_mem, unit = fluid.contrib.memory_usage( lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
program=train_program, batch_size=args.batch_size) program=train_program, batch_size=args.batch_size)
print("Theoretical memory usage in training: %.3f - %.3f %s" % log.info("Theoretical memory usage in training: %.3f - %.3f %s" %
(lower_mem, upper_mem, unit)) (lower_mem, upper_mem, unit))
if args.do_val or args.do_test: if args.do_val or args.do_test:
...@@ -144,11 +157,36 @@ def main(args): ...@@ -144,11 +157,36 @@ def main(args):
nccl2_num_trainers = 1 nccl2_num_trainers = 1
nccl2_trainer_id = 0 nccl2_trainer_id = 0
if args.is_distributed:
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS")
current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
worker_endpoints = worker_endpoints_env.split(",")
trainers_num = len(worker_endpoints)
log.info("worker_endpoints:{} trainers_num:{} current_endpoint:{} \
trainer_id:{}".format(worker_endpoints, trainers_num,
current_endpoint, trainer_id))
# prepare nccl2 env.
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
trainers=worker_endpoints_env,
current_endpoint=current_endpoint,
program=train_program if args.do_train else test_prog,
startup_program=startup_prog)
nccl2_num_trainers = trainers_num
nccl2_trainer_id = trainer_id
exe = fluid.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
if args.do_train: if args.do_train:
if args.init_checkpoint and args.init_pretraining_params: if args.init_checkpoint and args.init_pretraining_params:
print( log.warning(
"WARNING: args 'init_checkpoint' and 'init_pretraining_params' " "WARNING: args 'init_checkpoint' and 'init_pretraining_params' "
"both are set! Only arg 'init_checkpoint' is made valid.") "both are set! Only arg 'init_checkpoint' is made valid.")
if args.init_checkpoint: if args.init_checkpoint:
...@@ -214,12 +252,12 @@ def main(args): ...@@ -214,12 +252,12 @@ def main(args):
verbose += "learning rate: %f" % ( verbose += "learning rate: %f" % (
outputs["learning_rate"] outputs["learning_rate"]
if warmup_steps > 0 else args.learning_rate) if warmup_steps > 0 else args.learning_rate)
print(verbose) log.info(verbose)
current_example, current_epoch = reader.get_train_progress() current_example, current_epoch = reader.get_train_progress()
time_end = time.time() time_end = time.time()
used_time = time_end - time_begin used_time = time_end - time_begin
print("epoch: %d, progress: %d/%d, step: %d, ave loss: %f, " log.info("epoch: %d, progress: %d/%d, step: %d, ave loss: %f, "
"speed: %f steps/s" % "speed: %f steps/s" %
(current_epoch, current_example, num_train_examples, (current_epoch, current_example, num_train_examples,
steps, outputs["loss"], args.skip_steps / used_time)) steps, outputs["loss"], args.skip_steps / used_time))
...@@ -277,7 +315,7 @@ def main(args): ...@@ -277,7 +315,7 @@ def main(args):
# final eval on dev set # final eval on dev set
if args.do_val: if args.do_val:
print("Final validation result:") log.info("Final validation result:")
test_pyreader.decorate_tensor_provider( test_pyreader.decorate_tensor_provider(
reader.data_generator( reader.data_generator(
args.dev_set, args.dev_set,
...@@ -298,7 +336,7 @@ def main(args): ...@@ -298,7 +336,7 @@ def main(args):
# final eval on test set # final eval on test set
if args.do_test: if args.do_test:
print("Final test result:") log.info("Final test result:")
test_pyreader.decorate_tensor_provider( test_pyreader.decorate_tensor_provider(
reader.data_generator( reader.data_generator(
args.test_set, args.test_set,
...@@ -319,6 +357,8 @@ def main(args): ...@@ -319,6 +357,8 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
prepare_logger(log)
print_arguments(args)
while True: while True:
scope = fluid.core.Scope() scope = fluid.core.Scope()
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
......
...@@ -16,10 +16,15 @@ ...@@ -16,10 +16,15 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os import os
import time import time
import six
import logging
import multiprocessing import multiprocessing
from io import open
# NOTE(paddle-dev): All of these flags should be # NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would # set before `import paddle`. Otherwise, it would
...@@ -32,11 +37,12 @@ import reader.task_reader as task_reader ...@@ -32,11 +37,12 @@ import reader.task_reader as task_reader
from model.ernie import ErnieConfig from model.ernie import ErnieConfig
from optimization import optimization from optimization import optimization
from utils.init import init_pretraining_params, init_checkpoint from utils.init import init_pretraining_params, init_checkpoint
from utils.args import print_arguments, check_cuda from utils.args import print_arguments, check_cuda, prepare_logger
from finetune.sequence_label import create_model, evaluate from finetune.sequence_label import create_model, evaluate, predict, calculate_f1
from finetune_args import parser from finetune_args import parser
args = parser.parse_args() args = parser.parse_args()
log = logging.getLogger()
def main(args): def main(args):
...@@ -44,12 +50,12 @@ def main(args): ...@@ -44,12 +50,12 @@ def main(args):
ernie_config.print_config() ernie_config.print_config()
if args.use_cuda: if args.use_cuda:
place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0'))) dev_list = fluid.cuda_places()
dev_count = fluid.core.get_cuda_device_count() place = dev_list[0]
dev_count = len(dev_list)
else: else:
place = fluid.CPUPlace() place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
exe = fluid.Executor(place)
reader = task_reader.SequenceLabelReader( reader = task_reader.SequenceLabelReader(
vocab_path=args.vocab_path, vocab_path=args.vocab_path,
...@@ -85,10 +91,10 @@ def main(args): ...@@ -85,10 +91,10 @@ def main(args):
max_train_steps = args.epoch * num_train_examples // args.batch_size // dev_count max_train_steps = args.epoch * num_train_examples // args.batch_size // dev_count
warmup_steps = int(max_train_steps * args.warmup_proportion) warmup_steps = int(max_train_steps * args.warmup_proportion)
print("Device count: %d" % dev_count) log.info("Device count: %d" % dev_count)
print("Num train examples: %d" % num_train_examples) log.info("Num train examples: %d" % num_train_examples)
print("Max train steps: %d" % max_train_steps) log.info("Max train steps: %d" % max_train_steps)
print("Num warmup steps: %d" % warmup_steps) log.info("Num warmup steps: %d" % warmup_steps)
train_program = fluid.Program() train_program = fluid.Program()
...@@ -107,7 +113,13 @@ def main(args): ...@@ -107,7 +113,13 @@ def main(args):
startup_prog=startup_prog, startup_prog=startup_prog,
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
scheduler=args.lr_scheduler, scheduler=args.lr_scheduler,
use_fp16=args.use_fp16) use_fp16=args.use_fp16,
use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
init_loss_scaling=args.init_loss_scaling,
incr_every_n_steps=args.incr_every_n_steps,
decr_every_n_nan_or_inf=args.decr_every_n_nan_or_inf,
incr_ratio=args.incr_ratio,
decr_ratio=args.decr_ratio)
if args.verbose: if args.verbose:
if args.in_tokens: if args.in_tokens:
...@@ -117,7 +129,7 @@ def main(args): ...@@ -117,7 +129,7 @@ def main(args):
else: else:
lower_mem, upper_mem, unit = fluid.contrib.memory_usage( lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
program=train_program, batch_size=args.batch_size) program=train_program, batch_size=args.batch_size)
print("Theoretical memory usage in training: %.3f - %.3f %s" % log.info("Theoretical memory usage in training: %.3f - %.3f %s" %
(lower_mem, upper_mem, unit)) (lower_mem, upper_mem, unit))
if args.do_val or args.do_test: if args.do_val or args.do_test:
...@@ -131,11 +143,38 @@ def main(args): ...@@ -131,11 +143,38 @@ def main(args):
test_prog = test_prog.clone(for_test=True) test_prog = test_prog.clone(for_test=True)
nccl2_num_trainers = 1
nccl2_trainer_id = 0
if args.is_distributed:
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS")
current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
worker_endpoints = worker_endpoints_env.split(",")
trainers_num = len(worker_endpoints)
log.info("worker_endpoints:{} trainers_num:{} current_endpoint:{} \
trainer_id:{}".format(worker_endpoints, trainers_num,
current_endpoint, trainer_id))
# prepare nccl2 env.
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
trainers=worker_endpoints_env,
current_endpoint=current_endpoint,
program=train_program if args.do_train else test_prog,
startup_program=startup_prog)
nccl2_num_trainers = trainers_num
nccl2_trainer_id = trainer_id
exe = fluid.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
if args.do_train: if args.do_train:
if args.init_checkpoint and args.init_pretraining_params: if args.init_checkpoint and args.init_pretraining_params:
print( log.info(
"WARNING: args 'init_checkpoint' and 'init_pretraining_params' " "WARNING: args 'init_checkpoint' and 'init_pretraining_params' "
"both are set! Only arg 'init_checkpoint' is made valid.") "both are set! Only arg 'init_checkpoint' is made valid.")
if args.init_checkpoint: if args.init_checkpoint:
...@@ -171,7 +210,9 @@ def main(args): ...@@ -171,7 +210,9 @@ def main(args):
use_cuda=args.use_cuda, use_cuda=args.use_cuda,
loss_name=graph_vars["loss"].name, loss_name=graph_vars["loss"].name,
exec_strategy=exec_strategy, exec_strategy=exec_strategy,
main_program=train_program) main_program=train_program,
num_trainers=nccl2_num_trainers,
trainer_id=nccl2_trainer_id)
train_pyreader.decorate_tensor_provider(train_data_generator) train_pyreader.decorate_tensor_provider(train_data_generator)
else: else:
...@@ -186,8 +227,7 @@ def main(args): ...@@ -186,8 +227,7 @@ def main(args):
if args.do_train: if args.do_train:
train_pyreader.start() train_pyreader.start()
steps = 0 steps = 0
if warmup_steps > 0: graph_vars["learning_rate"] = scheduled_lr
graph_vars["learning_rate"] = scheduled_lr
time_begin = time.time() time_begin = time.time()
while True: while True:
...@@ -196,54 +236,47 @@ def main(args): ...@@ -196,54 +236,47 @@ def main(args):
if steps % args.skip_steps != 0: if steps % args.skip_steps != 0:
train_exe.run(fetch_list=[]) train_exe.run(fetch_list=[])
else: else:
outputs = evaluate(train_exe, train_program, train_pyreader, fetch_list = [
graph_vars, args.num_labels, "train", graph_vars["num_infer"].name, graph_vars["num_label"].name,
dev_count) graph_vars["num_correct"].name,
graph_vars["loss"].name,
graph_vars['learning_rate'].name,
]
out = train_exe.run(fetch_list=fetch_list)
num_infer, num_label, num_correct, np_loss, np_lr = out
lr = float(np_lr[0])
loss = np_loss.mean()
precision, recall, f1 = calculate_f1(num_label, num_infer, num_correct)
if args.verbose: if args.verbose:
verbose = "train pyreader queue size: %d, " % train_pyreader.queue.size( log.info("train pyreader queue size: %d, learning rate: %f" % (train_pyreader.queue.size(),
) lr if warmup_steps > 0 else args.learning_rate))
verbose += "learning rate: %f" % (
outputs["lr"]
if warmup_steps > 0 else args.learning_rate)
print(verbose)
current_example, current_epoch = reader.get_train_progress() current_example, current_epoch = reader.get_train_progress()
time_end = time.time() time_end = time.time()
used_time = time_end - time_begin used_time = time_end - time_begin
print("epoch: %d, progress: %d/%d, step: %d, loss: %f, " log.info("epoch: %d, progress: %d/%d, step: %d, loss: %f, "
"f1: %f, precision: %f, recall: %f, speed: %f steps/s" "f1: %f, precision: %f, recall: %f, speed: %f steps/s"
% (current_epoch, current_example, num_train_examples, % (current_epoch, current_example, num_train_examples,
steps, outputs["loss"], outputs["f1"], steps, loss, f1, precision, recall,
outputs["precision"], outputs["recall"],
args.skip_steps / used_time)) args.skip_steps / used_time))
time_begin = time.time() time_begin = time.time()
if steps % args.save_steps == 0: if nccl2_trainer_id == 0 and steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints, save_path = os.path.join(args.checkpoints,
"step_" + str(steps)) "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program) fluid.io.save_persistables(exe, save_path, train_program)
if steps % args.validation_steps == 0: if nccl2_trainer_id == 0 and steps % args.validation_steps == 0:
# evaluate dev set # evaluate dev set
if args.do_val: if args.do_val:
test_pyreader.decorate_tensor_provider( evaluate_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
reader.data_generator( current_epoch, steps)
args.dev_set,
batch_size=args.batch_size,
epoch=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader, graph_vars,
args.num_labels, "dev")
# evaluate test set # evaluate test set
if args.do_test: if args.do_test:
test_pyreader.decorate_tensor_provider( predict_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
reader.data_generator( current_epoch, steps)
args.test_set,
batch_size=args.batch_size,
epoch=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader, graph_vars,
args.num_labels, "test")
except fluid.core.EOFException: except fluid.core.EOFException:
save_path = os.path.join(args.checkpoints, "step_" + str(steps)) save_path = os.path.join(args.checkpoints, "step_" + str(steps))
...@@ -252,31 +285,65 @@ def main(args): ...@@ -252,31 +285,65 @@ def main(args):
break break
# final eval on dev set # final eval on dev set
if args.do_val: if nccl2_trainer_id ==0 and args.do_val:
test_pyreader.decorate_tensor_provider( evaluate_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
reader.data_generator( current_epoch, 'final')
args.dev_set,
batch_size=args.batch_size,
epoch=1,
shuffle=False))
print("Final validation result:")
evaluate(exe, test_prog, test_pyreader, graph_vars, args.num_labels,
"dev")
# final eval on test set if nccl2_trainer_id == 0 and args.do_test:
if args.do_test: predict_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
current_epoch, 'final')
def evaluate_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
epoch, steps):
# evaluate dev set
for ds in args.dev_set.split(','): #single card eval
test_pyreader.decorate_tensor_provider( test_pyreader.decorate_tensor_provider(
reader.data_generator( reader.data_generator(
args.test_set, ds,
batch_size=args.batch_size, batch_size=args.predict_batch_size,
epoch=1, epoch=1,
dev_count=1,
shuffle=False)) shuffle=False))
print("Final test result:") log.info("validation result of dataset {}:".format(ds))
evaluate(exe, test_prog, test_pyreader, graph_vars, args.num_labels, info = evaluate(exe, test_prog, test_pyreader, graph_vars,
"test") args.num_labels)
log.info(info + ', file: {}, epoch: {}, steps: {}'.format(
ds, epoch, steps))
def predict_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
epoch, steps):
test_sets = args.test_set.split(',')
save_dirs = args.test_save.split(',')
assert len(test_sets) == len(save_dirs), 'number of test_sets & test_save not match, got %d vs %d' % (len(test_sets), len(save_dirs))
for test_f, save_f in zip(test_sets, save_dirs):
test_pyreader.decorate_tensor_provider(reader.data_generator(
test_f,
batch_size=args.predict_batch_size,
epoch=1,
dev_count=1,
shuffle=False))
save_path = save_f + '.' + str(epoch) + '.' + str(steps)
log.info("testing {}, save to {}".format(test_f, save_path))
res = predict(exe, test_prog, test_pyreader, graph_vars, dev_count=1)
save_dir = os.path.dirname(save_path)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
tokenizer = reader.tokenizer
rev_label_map = {v: k for k, v in six.iteritems(reader.label_map)}
with open(save_path, 'w', encoding='utf8') as f:
for id, s, p in res:
id = ' '.join(tokenizer.convert_ids_to_tokens(id))
p = ' '.join(['%.5f' % pp[ss] for ss, pp in zip(s, p)])
s = ' '.join([rev_label_map[ss]for ss in s])
f.write('{}\t{}\t{}\n'.format(id, s, p))
if __name__ == '__main__': if __name__ == '__main__':
prepare_logger(log)
print_arguments(args) print_arguments(args)
check_cuda(args.use_cuda) check_cuda(args.use_cuda)
main(args) main(args)
...@@ -4,7 +4,12 @@ export FLAGS_eager_delete_tensor_gb=0 ...@@ -4,7 +4,12 @@ export FLAGS_eager_delete_tensor_gb=0
export FLAGS_sync_nccl_allreduce=1 export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3 export CUDA_VISIBLE_DEVICES=0,1,2,3
python -u run_mrc.py --use_cuda true\ python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_mrc.py --use_cuda true\
--batch_size 16 \ --batch_size 16 \
--in_tokens false\ --in_tokens false\
--use_fast_executor true \ --use_fast_executor true \
......
...@@ -4,7 +4,13 @@ export FLAGS_eager_delete_tensor_gb=0 ...@@ -4,7 +4,13 @@ export FLAGS_eager_delete_tensor_gb=0
export FLAGS_sync_nccl_allreduce=1 export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -u run_classifier.py \
python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_classifier.py \
--use_cuda true \ --use_cuda true \
--verbose true \ --verbose true \
--do_train true \ --do_train true \
......
...@@ -4,7 +4,12 @@ export FLAGS_eager_delete_tensor_gb=0 ...@@ -4,7 +4,12 @@ export FLAGS_eager_delete_tensor_gb=0
export FLAGS_sync_nccl_allreduce=1 export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3 export CUDA_VISIBLE_DEVICES=0,1,2,3
python -u run_mrc.py --use_cuda true\ python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_mrc.py --use_cuda true\
--batch_size 16 \ --batch_size 16 \
--in_tokens false\ --in_tokens false\
--use_fast_executor true \ --use_fast_executor true \
......
...@@ -2,7 +2,7 @@ set -eux ...@@ -2,7 +2,7 @@ set -eux
export FLAGS_eager_delete_tensor_gb=0 export FLAGS_eager_delete_tensor_gb=0
export FLAGS_sync_nccl_allreduce=1 export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLDE_DEVICES=0
python -u run_sequence_labeling.py \ python -u run_sequence_labeling.py \
--use_cuda true \ --use_cuda true \
...@@ -15,7 +15,7 @@ python -u run_sequence_labeling.py \ ...@@ -15,7 +15,7 @@ python -u run_sequence_labeling.py \
--chunk_scheme "IOB" \ --chunk_scheme "IOB" \
--label_map_config ${TASK_DATA_PATH}/msra_ner/label_map.json \ --label_map_config ${TASK_DATA_PATH}/msra_ner/label_map.json \
--train_set ${TASK_DATA_PATH}/msra_ner/train.tsv \ --train_set ${TASK_DATA_PATH}/msra_ner/train.tsv \
--dev_set ${TASK_DATA_PATH}/msra_ner/dev.tsv \ --dev_set ${TASK_DATA_PATH}/msra_ner/dev.tsv,${TASK_DATA_PATH}/msra_ner/test.tsv \
--test_set ${TASK_DATA_PATH}/msra_ner/test.tsv \ --test_set ${TASK_DATA_PATH}/msra_ner/test.tsv \
--vocab_path ${MODEL_PATH}/vocab.txt \ --vocab_path ${MODEL_PATH}/vocab.txt \
--ernie_config_path ${MODEL_PATH}/ernie_config.json \ --ernie_config_path ${MODEL_PATH}/ernie_config.json \
...@@ -24,6 +24,7 @@ python -u run_sequence_labeling.py \ ...@@ -24,6 +24,7 @@ python -u run_sequence_labeling.py \
--weight_decay 0.01 \ --weight_decay 0.01 \
--warmup_proportion 0.0 \ --warmup_proportion 0.0 \
--validation_steps 100 \ --validation_steps 100 \
--use_fp16 false \
--epoch 6 \ --epoch 6 \
--max_seq_len 256 \ --max_seq_len 256 \
--learning_rate 5e-5 \ --learning_rate 5e-5 \
......
...@@ -4,29 +4,36 @@ export FLAGS_eager_delete_tensor_gb=0 ...@@ -4,29 +4,36 @@ export FLAGS_eager_delete_tensor_gb=0
export FLAGS_sync_nccl_allreduce=1 export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -u run_classifier.py \ python ./finetune_launch.py \
--use_cuda true \ --nproc_per_node 8 \
--do_train true \ --selected_gpus 0,1,2,3,4,5,6,7 \
--do_val true \ --node_ips $(hostname -i) \
--do_test false \ --node_id 0 \
--verbose true \ run_classifier.py \
--batch_size 8192 \ --use_cuda true \
--in_tokens true \ --do_train true \
--init_pretraining_params ${MODEL_PATH}/params \ --do_val true \
--train_set ${TASK_DATA_PATH}/xnli/train.tsv \ --do_test false \
--dev_set ${TASK_DATA_PATH}/xnli/dev.tsv,${TASK_DATA_PATH}/xnli/test.tsv \ --verbose true \
--vocab_path ${MODEL_PATH}/vocab.txt \ --in_tokens true \
--label_map ${TASK_DATA_PATH}/xnli/label_map.json \ --batch_size 8192 \
--ernie_config_path ${MODEL_PATH}/ernie_config.json \ --train_set ${TASK_DATA_PATH}/xnli/train.tsv \
--checkpoints ./checkpoints \ --dev_set ${TASK_DATA_PATH}/xnli/dev.tsv,${TASK_DATA_PATH}/xnli/test.tsv \
--save_steps 1000 \ --label_map ${TASK_DATA_PATH}/xnli/label_map.json \
--weight_decay 0.01 \ --vocab_path ${MODEL_PATH}/vocab.txt \
--warmup_proportion 0.0 \ --ernie_config_path ${MODEL_PATH}/ernie_config.json \
--validation_steps 25 \ --init_pretraining_params ${MODEL_PATH}/params \
--epoch 3 \ --checkpoints ./checkpoints \
--max_seq_len 512 \ --save_steps 1000 \
--learning_rate 1e-4 \ --weight_decay 0.01 \
--skip_steps 10 \ --warmup_proportion 0.0 \
--num_iteration_per_drop_scope 1 \ --use_fp16 false \
--num_labels 3 \ --validation_steps 100 \
--random_seed 1 --epoch 3 \
--max_seq_len 512 \
--learning_rate 1e-4 \
--skip_steps 10 \
--num_iteration_per_drop_scope 1 \
--num_labels 3 \
--random_seed 1
...@@ -4,7 +4,12 @@ export FLAGS_eager_delete_tensor_gb=0.0 ...@@ -4,7 +4,12 @@ export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_sync_nccl_allreduce=1 export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -u run_mrc.py --use_cuda true\ python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_mrc.py --use_cuda true\
--batch_size 8 \ --batch_size 8 \
--in_tokens false\ --in_tokens false\
--use_fast_executor true \ --use_fast_executor true \
......
...@@ -3,7 +3,12 @@ set -eux ...@@ -3,7 +3,12 @@ set -eux
export FLAGS_sync_nccl_allreduce=1 export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -u run_classifier.py \ python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_classifier.py \
--use_cuda true \ --use_cuda true \
--verbose true \ --verbose true \
--do_train true \ --do_train true \
......
...@@ -4,7 +4,12 @@ export FLAGS_eager_delete_tensor_gb=0.0 ...@@ -4,7 +4,12 @@ export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_sync_nccl_allreduce=1 export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -u run_mrc.py --use_cuda true\ python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_mrc.py --use_cuda true\
--batch_size 8 \ --batch_size 8 \
--in_tokens false\ --in_tokens false\
--use_fast_executor true \ --use_fast_executor true \
......
...@@ -14,15 +14,16 @@ python -u run_sequence_labeling.py \ ...@@ -14,15 +14,16 @@ python -u run_sequence_labeling.py \
--chunk_scheme "IOB" \ --chunk_scheme "IOB" \
--label_map_config ${TASK_DATA_PATH}/msra_ner/label_map.json \ --label_map_config ${TASK_DATA_PATH}/msra_ner/label_map.json \
--train_set ${TASK_DATA_PATH}/msra_ner/train.tsv \ --train_set ${TASK_DATA_PATH}/msra_ner/train.tsv \
--dev_set ${TASK_DATA_PATH}/msra_ner/dev.tsv \ --dev_set ${TASK_DATA_PATH}/msra_ner/dev.tsv,${TASK_DATA_PATH}/msra_ner/test.tsv \
--test_set ${TASK_DATA_PATH}/msra_ner/test.tsv \ --test_set ${TASK_DATA_PATH}/msra_ner/test.tsv \
--vocab_path config/vocab.txt \ --vocab_path ${MODEL_PATH}/vocab.txt \
--ernie_config_path config/ernie_config.json \ --ernie_config_path ${MODEL_PATH}/ernie_config.json \
--checkpoints ./checkpoints \ --checkpoints ./checkpoints \
--save_steps 100000 \ --save_steps 100000 \
--weight_decay 0.01 \ --weight_decay 0.01 \
--warmup_proportion 0.0 \ --warmup_proportion 0.0 \
--validation_steps 100 \ --validation_steps 100 \
--use_fp16 false \
--epoch 6 \ --epoch 6 \
--max_seq_len 256 \ --max_seq_len 256 \
--learning_rate 1e-5 \ --learning_rate 1e-5 \
......
...@@ -3,7 +3,13 @@ set -eux ...@@ -3,7 +3,13 @@ set -eux
export FLAGS_sync_nccl_allreduce=1 export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -u run_classifier.py \
python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_classifier.py \
--use_cuda true \ --use_cuda true \
--do_train true \ --do_train true \
--do_val true \ --do_val true \
......
...@@ -3,8 +3,12 @@ set -eux ...@@ -3,8 +3,12 @@ set -eux
export FLAGS_eager_delete_tensor_gb=0 export FLAGS_eager_delete_tensor_gb=0
export FLAGS_sync_nccl_allreduce=1 export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python ./pretrain_launch.py \
python -u ./train.py --use_cuda True \ --nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
./train.py --use_cuda True \
--is_distributed False\ --is_distributed False\
--use_fast_executor True \ --use_fast_executor True \
--weight_sharing True \ --weight_sharing True \
...@@ -19,6 +23,7 @@ python -u ./train.py --use_cuda True \ ...@@ -19,6 +23,7 @@ python -u ./train.py --use_cuda True \
--save_steps 10000 \ --save_steps 10000 \
--ernie_config_path ./config/ernie_config.json \ --ernie_config_path ./config/ernie_config.json \
--learning_rate 1e-4 \ --learning_rate 1e-4 \
--use_fp16 false \
--weight_decay 0.01 \ --weight_decay 0.01 \
--max_seq_len 512 \ --max_seq_len 512 \
--skip_steps 10 --skip_steps 10
...@@ -17,6 +17,10 @@ ...@@ -17,6 +17,10 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
from io import open
import collections import collections
import unicodedata import unicodedata
...@@ -69,15 +73,15 @@ def printable_text(text): ...@@ -69,15 +73,15 @@ def printable_text(text):
def load_vocab(vocab_file): def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary.""" """Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict() vocab = collections.OrderedDict()
fin = open(vocab_file) with open(vocab_file, encoding='utf8') as fin:
for num, line in enumerate(fin): for num, line in enumerate(fin):
items = convert_to_unicode(line.strip()).split("\t") items = convert_to_unicode(line.strip()).split("\t")
if len(items) > 2: if len(items) > 2:
break break
token = items[0] token = items[0]
index = items[1] if len(items) == 2 else num index = items[1] if len(items) == 2 else num
token = token.strip() token = token.strip()
vocab[token] = int(index) vocab[token] = int(index)
return vocab return vocab
......
...@@ -12,14 +12,16 @@ ...@@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""ERNIE pretraining.""" """ERNIE pretraining."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os import os
import time import time
import multiprocessing import multiprocessing
import logging
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -27,11 +29,12 @@ import paddle.fluid as fluid ...@@ -27,11 +29,12 @@ import paddle.fluid as fluid
from reader.pretraining import ErnieDataReader from reader.pretraining import ErnieDataReader
from model.ernie_v1 import ErnieModel, ErnieConfig from model.ernie_v1 import ErnieModel, ErnieConfig
from optimization import optimization from optimization import optimization
from utils.args import print_arguments, check_cuda from utils.args import print_arguments, check_cuda, prepare_logger
from utils.init import init_checkpoint, init_pretraining_params from utils.init import init_checkpoint, init_pretraining_params
from pretrain_args import parser from pretrain_args import parser
log = logging.getLogger()
args = parser.parse_args() args = parser.parse_args()
# yapf: enable. # yapf: enable.
...@@ -65,9 +68,6 @@ def create_model(pyreader_name, ernie_config): ...@@ -65,9 +68,6 @@ def create_model(pyreader_name, ernie_config):
next_sent_acc, mask_lm_loss, total_loss = ernie.get_pretraining_output( next_sent_acc, mask_lm_loss, total_loss = ernie.get_pretraining_output(
mask_label, mask_pos, labels) mask_label, mask_pos, labels)
if args.use_fp16 and args.loss_scaling > 1.0:
total_loss *= args.loss_scaling
return pyreader, next_sent_acc, mask_lm_loss, total_loss return pyreader, next_sent_acc, mask_lm_loss, total_loss
...@@ -114,7 +114,7 @@ def predict_wrapper(args, ...@@ -114,7 +114,7 @@ def predict_wrapper(args,
cost += each_total_cost cost += each_total_cost
steps += 1 steps += 1
if args.do_test and steps % args.skip_steps == 0: if args.do_test and steps % args.skip_steps == 0:
print("[test_set] steps: %d" % steps) log.info("[test_set] steps: %d" % steps)
except fluid.core.EOFException: except fluid.core.EOFException:
pyreader.reset() pyreader.reset()
...@@ -151,9 +151,9 @@ def test(args): ...@@ -151,9 +151,9 @@ def test(args):
pyreader=test_pyreader, pyreader=test_pyreader,
fetch_list=[next_sent_acc.name, mask_lm_loss.name, total_loss.name]) fetch_list=[next_sent_acc.name, mask_lm_loss.name, total_loss.name])
print("test begin") log.info("test begin")
loss, lm_loss, acc, steps, speed = predict() loss, lm_loss, acc, steps, speed = predict()
print( log.info(
"[test_set] loss: %f, global ppl: %f, next_sent_acc: %f, speed: %f steps/s" "[test_set] loss: %f, global ppl: %f, next_sent_acc: %f, speed: %f steps/s"
% (np.mean(np.array(loss) / steps), % (np.mean(np.array(loss) / steps),
np.exp(np.mean(np.array(lm_loss) / steps)), np.exp(np.mean(np.array(lm_loss) / steps)),
...@@ -161,7 +161,7 @@ def test(args): ...@@ -161,7 +161,7 @@ def test(args):
def train(args): def train(args):
print("pretraining start") log.info("pretraining start")
ernie_config = ErnieConfig(args.ernie_config_path) ernie_config = ErnieConfig(args.ernie_config_path)
ernie_config.print_config() ernie_config.print_config()
...@@ -171,7 +171,7 @@ def train(args): ...@@ -171,7 +171,7 @@ def train(args):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
train_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model( train_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
pyreader_name='train_reader', ernie_config=ernie_config) pyreader_name='train_reader', ernie_config=ernie_config)
scheduled_lr, loss_scaling = optimization( scheduled_lr, _ = optimization(
loss=total_loss, loss=total_loss,
warmup_steps=args.warmup_steps, warmup_steps=args.warmup_steps,
num_train_steps=args.num_train_steps, num_train_steps=args.num_train_steps,
...@@ -180,7 +180,14 @@ def train(args): ...@@ -180,7 +180,14 @@ def train(args):
startup_prog=startup_prog, startup_prog=startup_prog,
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
scheduler=args.lr_scheduler, scheduler=args.lr_scheduler,
use_fp16=args.use_fp16) use_fp16=args.use_fp16,
use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
init_loss_scaling=args.init_loss_scaling,
incr_every_n_steps=args.incr_every_n_steps,
decr_every_n_nan_or_inf=args.decr_every_n_nan_or_inf,
incr_ratio=args.incr_ratio,
decr_ratio=args.decr_ratio)
fluid.memory_optimize( fluid.memory_optimize(
input_program=train_program, input_program=train_program,
...@@ -196,31 +203,34 @@ def train(args): ...@@ -196,31 +203,34 @@ def train(args):
test_prog = test_prog.clone(for_test=True) test_prog = test_prog.clone(for_test=True)
if len(fluid.cuda_places()) == 0:
raise RuntimeError('not cuda device cound, check ur env setting')
if args.use_cuda: if args.use_cuda:
place = fluid.CUDAPlace(0) place = fluid.cuda_places()[0]
dev_count = fluid.core.get_cuda_device_count() dev_count = fluid.core.get_cuda_device_count()
else: else:
place = fluid.CPUPlace() place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
print("Device count %d" % dev_count) log.info("Device count %d" % dev_count)
print("theoretical memory usage: ") log.info("theoretical memory usage: ")
print(fluid.contrib.memory_usage( log.info(fluid.contrib.memory_usage(
program=train_program, batch_size=args.batch_size // args.max_seq_len)) program=train_program, batch_size=args.batch_size // args.max_seq_len))
nccl2_num_trainers = 1 nccl2_num_trainers = 1
nccl2_trainer_id = 0 nccl2_trainer_id = 0
print("args.is_distributed:", args.is_distributed) log.info("args.is_distributed: %s" % args.is_distributed)
if args.is_distributed: if args.is_distributed:
worker_endpoints_env = os.getenv("worker_endpoints") worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS")
worker_endpoints = worker_endpoints_env.split(",") worker_endpoints = worker_endpoints_env.split(",")
trainers_num = len(worker_endpoints) trainers_num = len(worker_endpoints)
current_endpoint = os.getenv("current_endpoint") current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
trainer_id = worker_endpoints.index(current_endpoint) trainer_id = worker_endpoints.index(current_endpoint)
if trainer_id == 0: if trainer_id == 0:
print("train_id == 0, sleep 60s") log.info("train_id == 0, sleep 60s")
time.sleep(60) time.sleep(60)
print("worker_endpoints:{} trainers_num:{} current_endpoint:{} \ log.info("worker_endpoints:{} trainers_num:{} current_endpoint:{} \
trainer_id:{}".format(worker_endpoints, trainers_num, trainer_id:{}".format(worker_endpoints, trainers_num,
current_endpoint, trainer_id)) current_endpoint, trainer_id))
...@@ -309,13 +319,13 @@ def train(args): ...@@ -309,13 +319,13 @@ def train(args):
lm_cost.extend(each_mask_lm_cost) lm_cost.extend(each_mask_lm_cost)
cost.extend(each_total_cost) cost.extend(each_total_cost)
print("feed_queue size", train_pyreader.queue.size()) log.info("feed_queue size %d" % train_pyreader.queue.size())
time_end = time.time() time_end = time.time()
used_time = time_end - time_begin used_time = time_end - time_begin
epoch, current_file_index, total_file, current_file, mask_type = data_reader.get_progress( epoch, current_file_index, total_file, current_file, mask_type = data_reader.get_progress(
) )
print("current learning_rate:%f" % np_lr[0]) log.info("current learning_rate:%f" % np_lr[0])
print( log.info(
"epoch: %d, progress: %d/%d, step: %d, loss: %f, " "epoch: %d, progress: %d/%d, step: %d, loss: %f, "
"ppl: %f, next_sent_acc: %f, speed: %f steps/s, file: %s, mask_type: %s" "ppl: %f, next_sent_acc: %f, speed: %f steps/s, file: %s, mask_type: %s"
% (epoch, current_file_index, total_file, steps, % (epoch, current_file_index, total_file, steps,
...@@ -335,7 +345,7 @@ def train(args): ...@@ -335,7 +345,7 @@ def train(args):
if args.valid_filelist and steps % args.validation_steps == 0: if args.valid_filelist and steps % args.validation_steps == 0:
vali_cost, vali_lm_cost, vali_acc, vali_steps, vali_speed = predict( vali_cost, vali_lm_cost, vali_acc, vali_steps, vali_speed = predict(
) )
print("[validation_set] epoch: %d, step: %d, " log.info("[validation_set] epoch: %d, step: %d, "
"loss: %f, global ppl: %f, batch-averged ppl: %f, " "loss: %f, global ppl: %f, batch-averged ppl: %f, "
"next_sent_acc: %f, speed: %f steps/s" % "next_sent_acc: %f, speed: %f steps/s" %
(epoch, steps, np.mean(np.array(vali_cost) / vali_steps), (epoch, steps, np.mean(np.array(vali_cost) / vali_steps),
...@@ -349,6 +359,7 @@ def train(args): ...@@ -349,6 +359,7 @@ def train(args):
if __name__ == '__main__': if __name__ == '__main__':
prepare_logger(log)
print_arguments(args) print_arguments(args)
check_cuda(args.use_cuda) check_cuda(args.use_cuda)
if args.do_test: if args.do_test:
......
...@@ -12,17 +12,35 @@ ...@@ -12,17 +12,35 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Arguments for configuration.""" """Arguments for configuration."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import six import six
import argparse import argparse
import logging
import paddle.fluid as fluid import paddle.fluid as fluid
log = logging.getLogger(__name__)
def prepare_logger(logger, debug=False, save_to_file=None):
formatter = logging.Formatter(fmt='[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s')
console_hdl = logging.StreamHandler()
console_hdl.setFormatter(formatter)
logger.addHandler(console_hdl)
if save_to_file is not None and not os.path.exits(save_to_file):
file_hdl = logging.FileHandler(save_to_file)
file_hdl.setFormatter(formatter)
logger.addHandler(file_hdl)
logger.setLevel(logging.DEBUG)
logger.propagate = False
def str2bool(v): def str2bool(v):
# because argparse does not support to parse "true, False" as python # because argparse does not support to parse "true, False" as python
# boolean directly # boolean directly
...@@ -33,10 +51,11 @@ class ArgumentGroup(object): ...@@ -33,10 +51,11 @@ class ArgumentGroup(object):
def __init__(self, parser, title, des): def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des) self._group = parser.add_argument_group(title=title, description=des)
def add_arg(self, name, type, default, help, **kwargs): def add_arg(self, name, type, default, help, positional_arg=False, **kwargs):
prefix = "" if positional_arg else "--"
type = str2bool if type == bool else type type = str2bool if type == bool else type
self._group.add_argument( self._group.add_argument(
"--" + name, prefix + name,
default=default, default=default,
type=type, type=type,
help=help + ' Default: %(default)s.', help=help + ' Default: %(default)s.',
...@@ -44,10 +63,10 @@ class ArgumentGroup(object): ...@@ -44,10 +63,10 @@ class ArgumentGroup(object):
def print_arguments(args): def print_arguments(args):
print('----------- Configuration Arguments -----------') log.info('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))): for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value)) log.info('%s: %s' % (arg, value))
print('------------------------------------------------') log.info('------------------------------------------------')
def check_cuda(use_cuda, err = \ def check_cuda(use_cuda, err = \
...@@ -56,7 +75,7 @@ def check_cuda(use_cuda, err = \ ...@@ -56,7 +75,7 @@ def check_cuda(use_cuda, err = \
): ):
try: try:
if use_cuda == True and fluid.is_compiled_with_cuda() == False: if use_cuda == True and fluid.is_compiled_with_cuda() == False:
print(err) log.error(err)
sys.exit(1) sys.exit(1)
except Exception as e: except Exception as e:
pass pass
...@@ -11,7 +11,11 @@ ...@@ -11,7 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os import os
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
''' '''
Evaluation script for CMRC 2018 Evaluation script for CMRC 2018
version: v5 version: v5
...@@ -6,22 +19,25 @@ Note: ...@@ -6,22 +19,25 @@ Note:
v5 formatted output, add usage description v5 formatted output, add usage description
v4 fixed segmentation issues v4 fixed segmentation issues
''' '''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
from collections import Counter, OrderedDict from collections import Counter, OrderedDict
import string import string
import re import re
import argparse import argparse
import json import json
import sys import sys
reload(sys)
sys.setdefaultencoding('utf8')
import nltk import nltk
import pdb import pdb
# split Chinese with English # split Chinese with English
def mixed_segmentation(in_str, rm_punc=False): def mixed_segmentation(in_str, rm_punc=False):
in_str = str(in_str).decode('utf-8').lower().strip() in_str = in_str.lower().strip()
segs_out = [] segs_out = []
temp_str = "" temp_str = ""
sp_char = [ sp_char = [
...@@ -32,7 +48,7 @@ def mixed_segmentation(in_str, rm_punc=False): ...@@ -32,7 +48,7 @@ def mixed_segmentation(in_str, rm_punc=False):
for char in in_str: for char in in_str:
if rm_punc and char in sp_char: if rm_punc and char in sp_char:
continue continue
if re.search(ur'[\u4e00-\u9fa5]', char) or char in sp_char: if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
if temp_str != "": if temp_str != "":
ss = nltk.word_tokenize(temp_str) ss = nltk.word_tokenize(temp_str)
segs_out.extend(ss) segs_out.extend(ss)
...@@ -51,7 +67,7 @@ def mixed_segmentation(in_str, rm_punc=False): ...@@ -51,7 +67,7 @@ def mixed_segmentation(in_str, rm_punc=False):
# remove punctuation # remove punctuation
def remove_punctuation(in_str): def remove_punctuation(in_str):
in_str = str(in_str).decode('utf-8').lower().strip() in_str = in_str.lower().strip()
sp_char = [ sp_char = [
'-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':', '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
'?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(', '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
...@@ -102,7 +118,7 @@ def evaluate(ground_truth_file, prediction_file): ...@@ -102,7 +118,7 @@ def evaluate(ground_truth_file, prediction_file):
skip_count += 1 skip_count += 1
continue continue
prediction = str(prediction_file[query_id]) prediction = prediction_file[query_id]
f1 += calc_f1_score(answers, prediction) f1 += calc_f1_score(answers, prediction)
em += calc_em_score(answers, prediction) em += calc_em_score(answers, prediction)
......
...@@ -16,27 +16,20 @@ from __future__ import print_function ...@@ -16,27 +16,20 @@ from __future__ import print_function
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
def append_cast_op(i, o, prog):
def cast_fp16_to_fp32(i, o, prog): """
prog.global_block().append_op( Append a cast op in a given Program to cast input `i` to data type `o.dtype`.
type="cast", Args:
inputs={"X": i}, i (Variable): The input Variable.
outputs={"Out": o}, o (Variable): The output Variable.
attrs={ prog (Program): The Program to append cast op.
"in_dtype": fluid.core.VarDesc.VarType.FP16, """
"out_dtype": fluid.core.VarDesc.VarType.FP32
})
def cast_fp32_to_fp16(i, o, prog):
prog.global_block().append_op( prog.global_block().append_op(
type="cast", type="cast",
inputs={"X": i}, inputs={"X": i},
outputs={"Out": o}, outputs={"Out": o},
attrs={ attrs={"in_dtype": i.dtype,
"in_dtype": fluid.core.VarDesc.VarType.FP32, "out_dtype": o.dtype})
"out_dtype": fluid.core.VarDesc.VarType.FP16
})
def copy_to_master_param(p, block): def copy_to_master_param(p, block):
...@@ -59,32 +52,66 @@ def copy_to_master_param(p, block): ...@@ -59,32 +52,66 @@ def copy_to_master_param(p, block):
return new_p return new_p
def apply_dynamic_loss_scaling(loss_scaling, master_params_grads,
incr_every_n_steps, decr_every_n_nan_or_inf,
incr_ratio, decr_ratio):
_incr_every_n_steps = fluid.layers.fill_constant(
shape=[1], dtype='int32', value=incr_every_n_steps)
_decr_every_n_nan_or_inf = fluid.layers.fill_constant(
shape=[1], dtype='int32', value=decr_every_n_nan_or_inf)
_num_good_steps = fluid.layers.create_global_var(
name=fluid.unique_name.generate("num_good_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
_num_bad_steps = fluid.layers.create_global_var(
name=fluid.unique_name.generate("num_bad_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
grads = [fluid.layers.reduce_sum(g) for [_, g] in master_params_grads]
all_grads = fluid.layers.concat(grads)
all_grads_sum = fluid.layers.reduce_sum(all_grads)
is_overall_finite = fluid.layers.isfinite(all_grads_sum)
update_loss_scaling(is_overall_finite, loss_scaling, _num_good_steps,
_num_bad_steps, _incr_every_n_steps,
_decr_every_n_nan_or_inf, incr_ratio, decr_ratio)
# apply_gradient append all ops in global block, thus we shouldn't
# apply gradient in the switch branch.
with fluid.layers.Switch() as switch:
with switch.case(is_overall_finite):
pass
with switch.default():
for _, g in master_params_grads:
fluid.layers.assign(fluid.layers.zeros_like(g), g)
def create_master_params_grads(params_grads, main_prog, startup_prog, def create_master_params_grads(params_grads, main_prog, startup_prog,
loss_scaling): loss_scaling):
master_params_grads = [] master_params_grads = []
tmp_role = main_prog._current_role
OpRole = fluid.core.op_proto_and_checker_maker.OpRole
main_prog._current_role = OpRole.Backward
for p, g in params_grads: for p, g in params_grads:
with main_prog._optimized_guard([p, g]):
# create master parameters # create master parameters
master_param = copy_to_master_param(p, main_prog.global_block()) master_param = copy_to_master_param(p, main_prog.global_block())
startup_master_param = startup_prog.global_block()._clone_variable( startup_master_param = startup_prog.global_block()._clone_variable(
master_param) master_param)
startup_p = startup_prog.global_block().var(p.name) startup_p = startup_prog.global_block().var(p.name)
cast_fp16_to_fp32(startup_p, startup_master_param, startup_prog) append_cast_op(startup_p, startup_master_param, startup_prog)
# cast fp16 gradients to fp32 before apply gradients # cast fp16 gradients to fp32 before apply gradients
if g.name.find("layer_norm") > -1: if g.name.find("layer_norm") > -1:
if loss_scaling > 1: scaled_g = g / loss_scaling
scaled_g = g / float(loss_scaling) master_params_grads.append([p, scaled_g])
else: continue
scaled_g = g master_grad = fluid.layers.cast(g, "float32")
master_params_grads.append([p, scaled_g]) master_grad = master_grad / loss_scaling
continue master_params_grads.append([master_param, master_grad])
master_grad = fluid.layers.cast(g, "float32")
if loss_scaling > 1:
master_grad = master_grad / float(loss_scaling)
master_params_grads.append([master_param, master_grad])
main_prog._current_role = tmp_role
return master_params_grads return master_params_grads
...@@ -94,4 +121,80 @@ def master_param_to_train_param(master_params_grads, params_grads, main_prog): ...@@ -94,4 +121,80 @@ def master_param_to_train_param(master_params_grads, params_grads, main_prog):
if train_p.name.find("layer_norm") > -1: if train_p.name.find("layer_norm") > -1:
continue continue
with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]): with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]):
cast_fp32_to_fp16(m_p_g[0], train_p, main_prog) append_cast_op(m_p_g[0], train_p, main_prog)
def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps,
num_bad_steps, incr_every_n_steps,
decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
"""
Update loss scaling according to overall gradients. If all gradients is
finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
Otherwisw, loss scaling will decrease by decr_ratio after
decr_every_n_nan_or_inf steps and each step some gradients are infinite.
Args:
is_overall_finite (Variable): A boolean variable indicates whether
all gradients are finite.
prev_loss_scaling (Variable): Previous loss scaling.
num_good_steps (Variable): A variable accumulates good steps in which
all gradients are finite.
num_bad_steps (Variable): A variable accumulates bad steps in which
some gradients are infinite.
incr_every_n_steps (Variable): A variable represents increasing loss
scaling every n consecutive steps with
finite gradients.
decr_every_n_nan_or_inf (Variable): A variable represents decreasing
loss scaling every n accumulated
steps with nan or inf gradients.
incr_ratio(float): The multiplier to use when increasing the loss
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
loss scaling.
"""
zero_steps = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
with fluid.layers.Switch() as switch:
with switch.case(is_overall_finite):
should_incr_loss_scaling = fluid.layers.less_than(
incr_every_n_steps, num_good_steps + 1)
with fluid.layers.Switch() as switch1:
with switch1.case(should_incr_loss_scaling):
new_loss_scaling = prev_loss_scaling * incr_ratio
loss_scaling_is_finite = fluid.layers.isfinite(
new_loss_scaling)
with fluid.layers.Switch() as switch2:
with switch2.case(loss_scaling_is_finite):
fluid.layers.assign(new_loss_scaling,
prev_loss_scaling)
with switch2.default():
pass
fluid.layers.assign(zero_steps, num_good_steps)
fluid.layers.assign(zero_steps, num_bad_steps)
with switch1.default():
fluid.layers.increment(num_good_steps)
fluid.layers.assign(zero_steps, num_bad_steps)
with switch.default():
should_decr_loss_scaling = fluid.layers.less_than(
decr_every_n_nan_or_inf, num_bad_steps + 1)
with fluid.layers.Switch() as switch3:
with switch3.case(should_decr_loss_scaling):
new_loss_scaling = prev_loss_scaling * decr_ratio
static_loss_scaling = \
fluid.layers.fill_constant(shape=[1],
dtype='float32',
value=1.0)
less_than_one = fluid.layers.less_than(new_loss_scaling,
static_loss_scaling)
with fluid.layers.Switch() as switch4:
with switch4.case(less_than_one):
fluid.layers.assign(static_loss_scaling,
prev_loss_scaling)
with switch4.default():
fluid.layers.assign(new_loss_scaling,
prev_loss_scaling)
fluid.layers.assign(zero_steps, num_good_steps)
fluid.layers.assign(zero_steps, num_bad_steps)
with switch3.default():
fluid.layers.assign(zero_steps, num_good_steps)
fluid.layers.increment(num_bad_steps)
...@@ -12,27 +12,37 @@ ...@@ -12,27 +12,37 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os import os
import six import six
import ast import ast
import copy import copy
import logging
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
log = logging.getLogger(__name__)
def cast_fp32_to_fp16(exe, main_program): def cast_fp32_to_fp16(exe, main_program):
print("Cast parameters to float16 data format.") log.info("Cast parameters to float16 data format.")
for param in main_program.global_block().all_parameters(): for param in main_program.global_block().all_parameters():
if not param.name.endswith(".master"): if not param.name.endswith(".master"):
param_t = fluid.global_scope().find_var(param.name).get_tensor() param_t = fluid.global_scope().find_var(param.name).get_tensor()
data = np.array(param_t) data = np.array(param_t)
if param.name.find("layer_norm") == -1: if param.name.startswith("encoder_layer") \
and "layer_norm" not in param.name:
param_t.set(np.float16(data).view(np.uint16), exe.place) param_t.set(np.float16(data).view(np.uint16), exe.place)
master_param_var = fluid.global_scope().find_var(param.name +
".master") #load fp32
master_param_var = fluid.global_scope().find_var(param.name +
".master")
if master_param_var is not None: if master_param_var is not None:
master_param_var.get_tensor().set(data, exe.place) master_param_var.get_tensor().set(data, exe.place)
...@@ -40,7 +50,7 @@ def cast_fp32_to_fp16(exe, main_program): ...@@ -40,7 +50,7 @@ def cast_fp32_to_fp16(exe, main_program):
def init_checkpoint(exe, init_checkpoint_path, main_program, use_fp16=False): def init_checkpoint(exe, init_checkpoint_path, main_program, use_fp16=False):
assert os.path.exists( assert os.path.exists(
init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path
def existed_persitables(var): def existed_persitables(var):
if not fluid.io.is_persistable(var): if not fluid.io.is_persistable(var):
return False return False
...@@ -51,7 +61,7 @@ def init_checkpoint(exe, init_checkpoint_path, main_program, use_fp16=False): ...@@ -51,7 +61,7 @@ def init_checkpoint(exe, init_checkpoint_path, main_program, use_fp16=False):
init_checkpoint_path, init_checkpoint_path,
main_program=main_program, main_program=main_program,
predicate=existed_persitables) predicate=existed_persitables)
print("Load model from {}".format(init_checkpoint_path)) log.info("Load model from {}".format(init_checkpoint_path))
if use_fp16: if use_fp16:
cast_fp32_to_fp16(exe, main_program) cast_fp32_to_fp16(exe, main_program)
...@@ -74,7 +84,7 @@ def init_pretraining_params(exe, ...@@ -74,7 +84,7 @@ def init_pretraining_params(exe,
pretraining_params_path, pretraining_params_path,
main_program=main_program, main_program=main_program,
predicate=existed_params) predicate=existed_params)
print("Load pretraining parameters from {}.".format( log.info("Load pretraining parameters from {}.".format(
pretraining_params_path)) pretraining_params_path))
if use_fp16: if use_fp16:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册