Unverified commit 1314771c, authored by Meiyim, committed by GitHub

Merge pull request #298 from Meiyim/dev4

Many updates
......@@ -111,6 +111,7 @@ Integrating both phrase information and named entity information enables the mod
## Release Notes
- Aug 21, 2019: features update: fp16 finetuning, multiprocess finetuning.
- July 30, 2019: release ERNIE 2.0
- Apr 10, 2019: update ERNIE_stable-1.0.1.tar.gz, update config and vocab
- Mar 18, 2019: update ERNIE_stable.tgz
......@@ -339,7 +340,7 @@ XNLI is a natural language inference dataset in 15 languages. It was jointly bui
*\*The DRCD dataset is converted from Traditional Chinese to Simplified Chinese based on tool: https://github.com/skydark/nstools/tree/master/zhtools*
\* *The pre-training data of ERNIE 1.0 BASE does not contain instances longer than 128 tokens, while the other models are pre-trained with instances of length 512. This causes poorer performance of ERNIE 1.0 BASE on long-text tasks, so we released [ERNIE 1.0 Base(max-len-512)](https://ernie.bj.bcebos.com/ERNIE_1.0_max-len-512.tar.gz) on July 29th, 2019*
\* *The pre-training data of ERNIE 1.0 BASE does not contain instances longer than 128 tokens, while the other models are pre-trained with instances of length 512. This causes poorer performance of ERNIE 1.0 BASE on long-text tasks, so we released [ERNIE 1.0 Base (max-len-512)](https://ernie.bj.bcebos.com/ERNIE_1.0_max-len-512.tar.gz) on July 29th, 2019*
......@@ -371,7 +372,7 @@ DRCD is an open domain Traditional Chinese machine reading comprehension (MRC) d
<tr>
<th><strong>Dataset</strong>
<br></th>
<th colspan="2"><center><strong>MSRA-NER(SIGHAN2006)</strong></center></th>
<th colspan="2"><center><strong>MSRA-NER (SIGHAN2006)</strong></center></th>
<tr>
<td rowspan="2">
<p>
......@@ -413,10 +414,10 @@ DRCD is an open domain Traditional Chinese machine reading comprehension (MRC) d
</tbody>
</table>
- **MSRA-NER(SIGHAN2006)**
- **MSRA-NER (SIGHAN2006)**
```text
The MSRA-NER(SIGHAN2006) dataset was released by MSRA for recognizing the names of people, locations and organizations in text.
The MSRA-NER (SIGHAN2006) dataset was released by MSRA for recognizing the names of people, locations and organizations in text.
```
#### Results on Sentiment Analysis Task
......@@ -622,7 +623,7 @@ LCQMC is a Chinese question semantic matching corpus published in COLING2018. [u
- **BQ Corpus**
```text
BQ Corpus(Bank Question corpus) is a Chinese corpus for sentence semantic equivalence identification. This dataset was published in EMNLP 2018. [url: https://www.aclweb.org/anthology/D18-1536]
BQ Corpus (Bank Question corpus) is a Chinese corpus for sentence semantic equivalence identification. This dataset was published in EMNLP 2018. [url: https://www.aclweb.org/anthology/D18-1536]
```
......@@ -635,6 +636,7 @@ BQ Corpus(Bank Question corpus) is a Chinese corpus for sentence semantic equiva
* [Chinese Datasets](#chinese-datasets)
* [Fine-tuning](#fine-tuning)
* [Batchsize and GPU Settings](#batchsize-and-gpu-settings)
* [Multiprocessing and fp16 auto mix-precision finetune](#multiprocessing-and-fp16-auto-mix-precision-finetune)
* [Classification](#classification)
* [Single Sentence Classification Tasks](#single-sentence-classification-tasks)
* [Sentence Pair Classification Tasks](#sentence-pair-classification-tasks)
......@@ -705,14 +707,14 @@ In our experiments, we found that the batch size is important for different task
| Dataset | Batch Size | GPU |
| ------------ | --------------- | ------------------- |
| CoLA | 32 / 64(base) | 1 |
| SST-2 | 64 / 256(base) | 8 |
| CoLA | 32 / 64 (base) | 1 |
| SST-2 | 64 / 256 (base) | 8 |
| STS-B | 128 | 8 |
| QQP | 256 | 8 |
| MNLI | 256 / 512(base) | 8 |
| MNLI | 256 / 512 (base) | 8 |
| QNLI | 256 | 8 |
| RTE | 16 / 4(base) | 1 |
| MRPC | 16 / 32(base) | 2 |
| RTE | 16 / 4 (base) | 1 |
| MRPC | 16 / 32 (base) | 2 |
| WNLI | 8 | 1 |
| XNLI | 65536 (tokens) | 8 |
| CMRC2018 | 64 | 8 (large) / 4(base) |
......@@ -725,6 +727,17 @@ In our experiments, we found that the batch size is important for different task
\* *For MNLI and QNLI we used 32 GB V100 GPUs; for the other tasks we used 22 GB P40s.*
### Multiprocessing and fp16 auto mix-precision finetune
Multiprocess finetuning can be enabled simply by adding `finetune_launch.py` to your finetune script.
With multiprocess finetuning, Paddle can fully utilize your CPU/GPU capacity to accelerate finetuning.
`finetune_launch.py` should be placed in front of your finetune command. Make sure to provide the number of processes and the device ids per node by specifying `--nproc_per_node` and `--selected_gpus`. The number of device ids should match `nproc_per_node` and `CUDA_VISIBLE_DEVICES`, and the indexing should start from 0.
fp16 finetuning can be enabled simply by specifying `--use_fp16 true` in your training script (make sure you run on a device with Tensor Cores). ERNIE will cast computation ops to fp16 precision while keeping storage in fp32 precision; a speedup of approximately 60% is observed on XNLI finetuning.
Dynamic loss scaling is used to avoid fp16 gradient underflow. A minimal launch sketch follows.
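A minimal launch sketch, assuming an 8-GPU node; the entry point and trailing flags are illustrative, not prescriptive (see `script/zh_task/ernie_base/run_xnli.sh` for the repository's own example):
```shell
# illustrative: spawn 8 finetune workers, one per visible GPU
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python ./finetune_launch.py \
    --nproc_per_node 8 \
    --selected_gpus 0,1,2,3,4,5,6,7 \
    ./run_classifier.py \
    --use_cuda true \
    --use_fp16 true \
    "$@"   # remaining task-specific flags (data, vocab, checkpoints, ...)
```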
### Classification
#### Single Sentence Classification Tasks
......
......@@ -371,10 +371,10 @@ DRCD is a Traditional Chinese machine reading comprehension dataset released by Delta Research Institute; its goal is to
</tbody>
</table>
- **MSRA-NER(SIGHAN2006)**
- **MSRA-NER (SIGHAN2006)**
```text
The MSRA-NER(SIGHAN2006) dataset was released by Microsoft Research Asia; it targets recognizing entities with specific meaning in text, including person names, place names, and organization names.
The MSRA-NER (SIGHAN2006) dataset was released by Microsoft Research Asia; it targets recognizing entities with specific meaning in text, including person names, place names, and organization names.
```
......@@ -640,6 +640,7 @@ ERNIE 2.0's English results are verified on GLUE. The official GLUE address
* [English Datasets](#英文数据)
* [Fine-tuning Tasks](#fine-tuning-任务)
* [Run-time Parameter Settings](#运行参数配置)
* [Multiprocess Training and fp16 Mixed Precision](#多进程训练与fp16混合精度)
* [Single-Sentence and Sentence-Pair Classification Tasks](#单句和句对分类任务)
* [Single-Sentence Classification Tasks](#单句分类任务)
* [Sentence-Pair Classification Tasks](#句对分类任务)
......@@ -720,8 +721,8 @@ ERNIE 2.0's English results are verified on GLUE. The official GLUE address
| MRPC | 16 / 32 (base) | 2 |
| WNLI | 8 | 1 |
| XNLI | 65536 (tokens) | 8 |
| CMRC2018 | 64 | 8 (large) / 4(base) |
| DRCD | 64 | 8 (large) / 4(base) |
| CMRC2018 | 64 | 8 (large) / 4 (base) |
| DRCD | 64 | 8 (large) / 4 (base) |
| MSRA-NER(SIGHAN 2006) | 16 | 1 |
| ChnSentiCorp | 24 | 1 |
| LCQMC | 32 | 1 |
......@@ -731,6 +732,12 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
\* *For MNLI and QNLI we used V100 GPUs with 32 GB of memory; all other tasks ran on 22 GB P40s.*
### Multiprocess Training and fp16 Mixed Precision
Use the `finetune_launch.py` script to launch multiprocess training. Multiprocess training makes full use of multi-core CPUs and multiple GPUs to accelerate finetuning.
`finetune_launch.py` must be placed in front of the original finetune command, with the number of processes per node given by `--nproc_per_node` and the GPU ids on each node given by `--selected_gpus`. The number of GPU ids should normally equal the process count and match `CUDA_VISIBLE_DEVICES`, with indexing starting from 0 (see `script/zh_task/ernie_base/run_xnli.sh`).
Adding `--use_fp16 true` to the training script enables fp16 mixed-precision training (make sure your hardware supports Tensor Cores). ERNIE casts computation ops to fp16 while still storing parameters in fp32, and uses dynamic loss scaling to avoid fp16 gradient underflow. A speedup of roughly 60% is observed on XNLI; a sketch of the new flags follows.
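A hedged sketch of the fp16 switches, using the default values registered in `finetune_args.py` in this PR (the entry point is illustrative):
```shell
# illustrative: fp16 finetuning with dynamic loss scaling
python ./run_classifier.py \
    --use_cuda true \
    --use_fp16 true \
    --use_dynamic_loss_scaling true \
    --init_loss_scaling 102400 \
    --incr_every_n_steps 100 \
    --decr_every_n_nan_or_inf 2 \
    --incr_ratio 2.0 \
    --decr_ratio 0.8 \
    "$@"   # remaining task-specific flags
```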
### Single-Sentence and Sentence-Pair Classification Tasks
......
......@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference by """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os
import time
......@@ -39,7 +40,7 @@ from reader.task_reader import ClassifyReader
from model.ernie import ErnieConfig
from finetune.classifier import create_model
from utils.args import ArgumentGroup, print_arguments
from utils.args import print_arguments, check_cuda, prepare_logger
from utils.init import init_pretraining_params
from finetune_args import parser
......@@ -66,6 +67,7 @@ run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for trai
run_type_g.add_arg("do_prediction", bool, True, "Whether to do prediction on test set.")
args = parser.parse_args()
log = logging.getLogger()
# yapf: enable.
def main(args):
......@@ -113,7 +115,7 @@ def main(args):
_, ckpt_dir = os.path.split(args.init_checkpoint.rstrip('/'))
dir_name = ckpt_dir + '_inference_model'
model_path = os.path.join(args.save_inference_model_path, dir_name)
print("save inference model to %s" % model_path)
log.info("save inference model to %s" % model_path)
fluid.io.save_inference_model(
model_path,
feed_target_names, [probs],
......@@ -125,7 +127,7 @@ def main(args):
#config = AnalysisConfig(os.path.join(model_path, "__model__"), os.path.join(model_path, ""))
config = AnalysisConfig(model_path)
if not args.use_cuda:
print("disable gpu")
log.info("disable gpu")
config.disable_gpu()
# Create PaddlePredictor
......@@ -137,7 +139,7 @@ def main(args):
epoch=1,
shuffle=False)
print("-------------- prediction results --------------")
log.info("-------------- prediction results --------------")
np.set_printoptions(precision=4, suppress=True)
index = 0
total_time = 0
......@@ -156,14 +158,14 @@ def main(args):
# parse outputs
output = outputs[0]
print(output.name)
log.info(output.name)
output_data = output.data.float_data()
#assert len(output_data) == args.num_labels * args.batch_size
batch_result = np.array(output_data).reshape((-1, args.num_labels))
for single_example_probs in batch_result:
print("{} example\t{}".format(index, single_example_probs))
log.info("{} example\t{}".format(index, single_example_probs))
index += 1
print("qps:{}\ttotal_time:{}\ttotal_example:{}\tbatch_size:{}".format(index/total_time, total_time, index, args.batch_size))
log.info("qps:{}\ttotal_time:{}\ttotal_example:{}\tbatch_size:{}".format(index/total_time, total_time, index, args.batch_size))
def array2tensor(ndarray):
......@@ -183,5 +185,6 @@ def array2tensor(ndarray):
return tensor
if __name__ == '__main__':
prepare_logger(log)
print_arguments(args)
main(args)
......@@ -129,8 +129,6 @@ def main(args):
pyreader, graph_vars = create_model(
args, pyreader_name='reader', ernie_config=ernie_config)
fluid.memory_optimize(input_program=infer_program)
infer_program = infer_program.clone(for_test=True)
exe.run(startup_prog)
......
......@@ -16,8 +16,11 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import time
import logging
import numpy as np
from scipy.stats import pearsonr, spearmanr
......@@ -26,6 +29,7 @@ import paddle.fluid as fluid
from model.ernie import ErnieModel
log = logging.getLogger(__name__)
def create_model(args,
pyreader_name,
......
......@@ -16,12 +16,15 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import time
import numpy as np
import os
import math
import json
import logging
import collections
import six
......@@ -34,6 +37,8 @@ from model.ernie import ErnieModel
import tokenization
log = logging.getLogger(__name__)
def create_model(args, pyreader_name, ernie_config, is_training):
pyreader = fluid.layers.py_reader(
capacity=50,
......@@ -151,7 +156,7 @@ def evaluate(exe,
program=test_program, fetch_list=fetch_list)
for idx in range(np_unique_ids.shape[0]):
if len(all_results) % 1000 == 0:
print("Processing example: %d" % len(all_results))
log.info("Processing example: %d" % len(all_results))
unique_id = int(np_unique_ids[idx])
start_logits = [float(x) for x in np_start_logits[idx].flat]
end_logits = [float(x) for x in np_end_logits[idx].flat]
......@@ -179,7 +184,7 @@ def evaluate(exe,
time_end = time.time()
elapsed_time = time_end - time_begin
print(
log.info(
"[%s evaluation] em: %f, f1: %f, avg: %f, questions: %d, elapsed time: %f"
% (eval_phase, em, f1, avg, total, elapsed_time))
......@@ -188,8 +193,8 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
max_answer_length, do_lower_case, output_prediction_file,
output_nbest_file):
"""Write final predictions to the json file and log-odds of null if needed."""
print("Writing predictions to: %s" % (output_prediction_file))
print("Writing nbest to: %s" % (output_nbest_file))
log.info("Writing predictions to: %s" % (output_prediction_file))
log.info("Writing nbest to: %s" % (output_nbest_file))
example_index_to_features = collections.defaultdict(list)
for feature in all_features:
......
......@@ -15,6 +15,9 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os
import time
......@@ -23,12 +26,14 @@ import numpy as np
import multiprocessing
import paddle
import logging
import paddle.fluid as fluid
from six.moves import xrange
from model.ernie import ErnieModel
log = logging.getLogger(__name__)
def create_model(args, pyreader_name, ernie_config, is_prediction=False):
pyreader = fluid.layers.py_reader(
......@@ -70,9 +75,7 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
initializer=fluid.initializer.Constant(0.)))
infers = fluid.layers.argmax(logits, axis=2)
ret_labels = fluid.layers.reshape(x=labels, shape=[-1, 1])
ret_infers = fluid.layers.reshape(x=infers, shape=[-1, 1])
lod_labels = fluid.layers.sequence_unpad(labels, seq_lens)
lod_infers = fluid.layers.sequence_unpad(infers, seq_lens)
......@@ -92,18 +95,14 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
ce_loss = ce_loss * input_mask
loss = fluid.layers.mean(x=ce_loss)
if args.use_fp16 and args.loss_scaling > 1.0:
loss *= args.loss_scaling
graph_vars = {
"inputs": src_ids,
"loss": loss,
"probs": probs,
"labels": ret_labels,
"infers": ret_infers,
"seqlen": seq_lens,
"num_infer": num_infer,
"num_label": num_label,
"num_correct": num_correct,
"seq_lens": seq_lens
}
for k, v in graph_vars.items():
......@@ -112,91 +111,6 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
return pyreader, graph_vars
def chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count=1):
def extract_bio_chunk(seq):
chunks = []
cur_chunk = None
null_index = tag_num - 1
for index in xrange(len(seq)):
tag = seq[index]
tag_type = tag // 2
tag_pos = tag % 2
if tag == null_index:
if cur_chunk is not None:
chunks.append(cur_chunk)
cur_chunk = None
continue
if tag_pos == 0:
if cur_chunk is not None:
chunks.append(cur_chunk)
cur_chunk = {}
cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
else:
if cur_chunk is None:
cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
continue
if cur_chunk["type"] == tag_type:
cur_chunk["en"] = index + 1
else:
chunks.append(cur_chunk)
cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
if cur_chunk is not None:
chunks.append(cur_chunk)
return chunks
null_index = tag_num - 1
num_label = 0
num_infer = 0
num_correct = 0
labels = np_labels.reshape([-1]).astype(np.int32).tolist()
infers = np_infers.reshape([-1]).astype(np.int32).tolist()
all_lens = np_lens.reshape([dev_count, -1]).astype(np.int32).tolist()
base_index = 0
for dev_index in xrange(dev_count):
lens = all_lens[dev_index]
max_len = 0
for l in lens:
max_len = max(max_len, l)
for i in xrange(len(lens)):
seq_st = base_index + i * max_len + 1
seq_en = seq_st + (lens[i] - 2)
infer_chunks = extract_bio_chunk(infers[seq_st:seq_en])
label_chunks = extract_bio_chunk(labels[seq_st:seq_en])
num_infer += len(infer_chunks)
num_label += len(label_chunks)
infer_index = 0
label_index = 0
while label_index < len(label_chunks) \
and infer_index < len(infer_chunks):
if infer_chunks[infer_index]["st"] \
< label_chunks[label_index]["st"]:
infer_index += 1
elif infer_chunks[infer_index]["st"] \
> label_chunks[label_index]["st"]:
label_index += 1
else:
if infer_chunks[infer_index]["en"] \
== label_chunks[label_index]["en"] \
and infer_chunks[infer_index]["type"] \
== label_chunks[label_index]["type"]:
num_correct += 1
infer_index += 1
label_index += 1
base_index += max_len * len(lens)
return num_label, num_infer, num_correct
def calculate_f1(num_label, num_infer, num_correct):
if num_infer == 0:
precision = 0.0
......@@ -220,53 +134,85 @@ def evaluate(exe,
pyreader,
graph_vars,
tag_num,
eval_phase,
dev_count=1):
fetch_list = [
graph_vars["num_infer"].name, graph_vars["num_label"].name,
graph_vars["num_correct"].name
]
if eval_phase == "train":
fetch_list.append(graph_vars["loss"].name)
if "learning_rate" in graph_vars:
fetch_list.append(graph_vars["learning_rate"].name)
outputs = exe.run(fetch_list=fetch_list)
np_num_infer, np_num_label, np_num_correct, np_loss = outputs[:4]
num_label = np.sum(np_num_label)
num_infer = np.sum(np_num_infer)
num_correct = np.sum(np_num_correct)
precision, recall, f1 = calculate_f1(num_label, num_infer, num_correct)
rets = {
"precision": precision,
"recall": recall,
"f1": f1,
"loss": np.mean(np_loss)
}
if "learning_rate" in graph_vars:
rets["lr"] = float(outputs[4][0])
return rets
total_label, total_infer, total_correct = 0.0, 0.0, 0.0
time_begin = time.time()
pyreader.start()
while True:
try:
np_num_infer, np_num_label, np_num_correct = exe.run(program=program,
fetch_list=fetch_list)
total_infer += np.sum(np_num_infer)
total_label += np.sum(np_num_label)
total_correct += np.sum(np_num_correct)
except fluid.core.EOFException:
pyreader.reset()
break
precision, recall, f1 = calculate_f1(total_label, total_infer,
total_correct)
time_end = time.time()
return \
"[evaluation] f1: %f, precision: %f, recall: %f, elapsed time: %f s" \
% (f1, precision, recall, time_end - time_begin)
def chunk_predict(np_inputs, np_probs, np_lens, dev_count=1):
inputs = np_inputs.reshape([-1]).astype(np.int32)
probs = np_probs.reshape([-1, np_probs.shape[-1]])
all_lens = np_lens.reshape([dev_count, -1]).astype(np.int32).tolist()
base_index = 0
out = []
for dev_index in xrange(dev_count):
lens = all_lens[dev_index]
max_len = 0
for l in lens:
max_len = max(max_len, l)
for i in xrange(len(lens)):
seq_st = base_index + i * max_len + 1
seq_en = seq_st + (lens[i] - 2)
prob = probs[seq_st:seq_en, :]
infers = np.argmax(prob, -1)  # argmax over this sequence only, not the whole batch
out.append((
inputs[seq_st:seq_en].tolist(),
infers.tolist(),
prob.tolist()))
base_index += max_len * len(lens)
return out
def predict(exe,
test_program,
test_pyreader,
graph_vars,
dev_count=1):
fetch_list = [
graph_vars["inputs"].name,
graph_vars["probs"].name,
graph_vars["seqlen"].name,
graph_vars["probs"].name,
]
test_pyreader.start()
res = []
while True:
try:
inputs, probs, np_lens, np_probs = exe.run(program=test_program,
fetch_list=fetch_list)
r = chunk_predict(inputs, probs, np_lens, dev_count)
res += r
except fluid.core.EOFException:
test_pyreader.reset()
break
log.info(len(res))
return res
else:
total_label, total_infer, total_correct = 0.0, 0.0, 0.0
time_begin = time.time()
pyreader.start()
while True:
try:
np_num_infer, np_num_label, np_num_correct = exe.run(program=program,
fetch_list=fetch_list)
total_infer += np.sum(np_num_infer)
total_label += np.sum(np_num_label)
total_correct += np.sum(np_num_correct)
except fluid.core.EOFException:
pyreader.reset()
break
precision, recall, f1 = calculate_f1(total_label, total_infer,
total_correct)
time_end = time.time()
print(
"[%s evaluation] f1: %f, precision: %f, recall: %f, elapsed time: %f s"
% (eval_phase, f1, precision, recall, time_end - time_begin))
......@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os
import time
......@@ -47,10 +49,21 @@ train_g.add_arg("warmup_proportion", float, 0.1,
train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
train_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.")
train_g.add_arg("loss_scaling", float, 1.0,
train_g.add_arg("use_dynamic_loss_scaling", bool, True, "Whether to use dynamic loss scaling.")
train_g.add_arg("init_loss_scaling", float, 102400,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
train_g.add_arg("test_save", str, "test_result", "test_save")
train_g.add_arg("test_save", str, "./checkpoints/test_result", "test_save")
train_g.add_arg("metric", str, "simple_accuracy", "metric")
train_g.add_arg("incr_every_n_steps", int, 100, "Increases loss scaling every n consecutive.")
train_g.add_arg("decr_every_n_nan_or_inf", int, 2,
"Decreases loss scaling every n accumulated steps with nan or inf gradients.")
train_g.add_arg("incr_ratio", float, 2.0,
"The multiplier to use when increasing the loss scaling.")
train_g.add_arg("decr_ratio", float, 0.8,
"The less-than-one-multiplier to use when decreasing.")
log_g = ArgumentGroup(parser, "logging", "logging related.")
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
......@@ -86,6 +99,7 @@ data_g.add_arg("chunk_scheme", type=str, default="IOB", choices=["IO", "IOB", "
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("is_distributed", bool, False, "If set, then start distributed training.")
run_type_g.add_arg("use_fast_executor", bool, False, "If set, use fast parallel executor (in experiment).")
run_type_g.add_arg("num_iteration_per_drop_scope", int, 10, "Iteration intervals to drop scope.")
run_type_g.add_arg("do_train", bool, True, "Whether to perform training.")
......
......@@ -16,14 +16,18 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import json
import six
import logging
import paddle.fluid as fluid
from io import open
from model.transformer_encoder import encoder, pre_process_layer
log = logging.getLogger(__name__)
class ErnieConfig(object):
def __init__(self, config_path):
......@@ -31,7 +35,7 @@ class ErnieConfig(object):
def _parse(self, config_path):
try:
with open(config_path) as json_file:
with open(config_path, 'r', encoding='utf8') as json_file:
config_dict = json.load(json_file)
except Exception:
raise IOError("Error in parsing Ernie model config file '%s'" %
......@@ -44,8 +48,8 @@ class ErnieConfig(object):
def print_config(self):
for arg, value in sorted(six.iteritems(self._config_dict)):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
log.info('%s: %s' % (arg, value))
log.info('------------------------------------------------')
class ErnieModel(object):
......@@ -102,7 +106,7 @@ class ErnieModel(object):
param_attr=fluid.ParamAttr(
name=self._word_emb_name, initializer=self._param_initializer),
is_sparse=False)
position_emb_out = fluid.layers.embedding(
input=position_ids,
size=[self._max_position_seq_len, self._emb_size],
......@@ -163,6 +167,10 @@ class ErnieModel(object):
postprocess_cmd="dan",
param_initializer=self._param_initializer,
name='encoder')
if self._dtype == "float16":
self._enc_out = fluid.layers.cast(
x=self._enc_out, dtype=self._emb_dtype)
def get_sequence_output(self):
return self._enc_out
......@@ -171,9 +179,6 @@ class ErnieModel(object):
"""Get the first feature of each sequence for classification"""
next_sent_feat = fluid.layers.slice(
input=self._enc_out, axes=[1], starts=[0], ends=[1])
if self._dtype == "float16":
next_sent_feat = fluid.layers.cast(
x=next_sent_feat, dtype=self._emb_dtype)
next_sent_feat = fluid.layers.fc(
input=next_sent_feat,
size=self._emb_size,
......@@ -194,8 +199,6 @@ class ErnieModel(object):
x=self._enc_out, shape=[-1, self._emb_size])
# extract masked tokens' feature
mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
if self._dtype == "float16":
mask_feat = fluid.layers.cast(x=mask_feat, dtype=self._emb_dtype)
# transform: fc
mask_trans_feat = fluid.layers.fc(
......@@ -206,7 +209,7 @@ class ErnieModel(object):
name='mask_lm_trans_fc.w_0',
initializer=self._param_initializer),
bias_attr=fluid.ParamAttr(name='mask_lm_trans_fc.b_0'))
# transform: layer norm
mask_trans_feat = fluid.layers.layer_norm(
mask_trans_feat,
......
......@@ -16,14 +16,18 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import json
import logging
import six
import paddle.fluid as fluid
from io import open
from model.transformer_encoder import encoder, pre_process_layer
log = logging.getLogger(__name__)
class ErnieConfig(object):
def __init__(self, config_path):
......@@ -31,7 +35,7 @@ class ErnieConfig(object):
def _parse(self, config_path):
try:
with open(config_path) as json_file:
with open(config_path, 'r', encoding='utf8') as json_file:
config_dict = json.load(json_file)
except Exception:
raise IOError("Error in parsing Ernie model config file '%s'" %
......@@ -44,8 +48,8 @@ class ErnieConfig(object):
def print_config(self):
for arg, value in sorted(six.iteritems(self._config_dict)):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
log.info('%s: %s' % (arg, value))
log.info('------------------------------------------------')
class ErnieModel(object):
......
......@@ -16,10 +16,13 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import numpy as np
import paddle.fluid as fluid
from utils.fp16 import create_master_params_grads, master_param_to_train_param
from utils.fp16 import create_master_params_grads, master_param_to_train_param, apply_dynamic_loss_scaling
def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps):
......@@ -101,7 +104,7 @@ def optimization(loss,
return False
param_list = dict()
loss_scaling = fluid.layers.create_global_var(
name=fluid.unique_name.generate("loss_scaling"),
shape=[1],
......
......@@ -42,8 +42,18 @@ train_g.add_arg("warmup_steps", int, 5000, "Total steps to perform wa
train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
train_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.")
train_g.add_arg("loss_scaling", float, 1.0,
train_g.add_arg("use_dynamic_loss_scaling", bool, True, "Whether to use dynamic loss scaling.")
train_g.add_arg("init_loss_scaling", float, 102400,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
train_g.add_arg("incr_every_n_steps", int, 100, "Increases loss scaling every n consecutive.")
train_g.add_arg("decr_every_n_nan_or_inf", int, 2,
"Decreases loss scaling every n accumulated steps with nan or inf gradients.")
train_g.add_arg("incr_ratio", float, 2.0,
"The multiplier to use when increasing the loss scaling.")
train_g.add_arg("decr_ratio", float, 0.8,
"The less-than-one-multiplier to use when decreasing.")
log_g = ArgumentGroup(parser, "logging", "logging related.")
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
......
......@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os
import numpy as np
......@@ -36,8 +38,10 @@ class ErnieDataReader(object):
filelist,
vocab_path,
batch_size=4096,
in_tokens=True,
max_seq_len=512,
shuffle_files=True,
random_seed=1,
epoch=100,
voc_size=0,
is_test=False,
......@@ -46,6 +50,8 @@ class ErnieDataReader(object):
self.vocab = self.load_vocab(vocab_path)
self.filelist = filelist
self.batch_size = batch_size
self.in_tokens = in_tokens
self.random_seed = random_seed
self.shuffle_files = shuffle_files
self.epoch = epoch
self.current_epoch = 0
......@@ -60,12 +66,42 @@ class ErnieDataReader(object):
self.mask_id = self.vocab["[MASK]"]
self.is_test = is_test
self.generate_neg_sample = generate_neg_sample
assert self.batch_size > 100, "Current batch size means total token's number, \
it should not be set to too small number."
self.trainer_id = 0
self.trainer_nums = 1
self.files = open(filelist).readlines()
self.total_file = len(self.files)
if self.is_test:
self.epoch = 1
self.shuffle_files = False
self.global_rng = np.random.RandomState(random_seed)
if self.shuffle_files:
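# All trainers build global_rng from the same random_seed, so every trainer
# shuffles the file list into the same order; data_generator later deals the
# files out by trainer_id, giving disjoint but consistent shards.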
if os.getenv("PADDLE_TRAINER_ID"):
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
if os.getenv("PADDLE_NODES_NUM"):
self.trainer_nums = int(os.getenv("PADDLE_TRAINERS_NUM"))
#renew total_file
self.total_file = len(self.files) // self.trainer_nums * self.trainer_nums
if len(self.files) < self.trainer_nums:
raise RuntimeError('not enough train files to shard, file:%d num_trainer:%d' % (len(self.files), self.trainer_nums))
tmp_files = []
for each in range(epoch):
each_files = self.files
self.global_rng.shuffle(each_files)
tmp_files += each_files
self.files = tmp_files
#renew epochs
self.epoch = len(self.files) // self.total_file * self.total_file
assert self.total_file > 0, \
"[Error] data_dir is empty or less than %d" % self.trainer_nums
if self.in_tokens:
assert self.batch_size > 100, "Current batch size means the total number of tokens; \
it should not be set too small."
def get_progress(self):
"""return current progress of traning data
......@@ -75,13 +111,16 @@ class ErnieDataReader(object):
def parse_line(self, line, max_seq_len=512):
""" parse one line to token_ids, sentence_ids, pos_ids, label
"""
line = line.strip().decode().split(";")
assert len(line) == 5, "One sample must have 5 fields!"
line = line.strip().split(";")
assert len(line) == 5, \
"One sample must have %d fields!" % 5
(token_ids, sent_ids, pos_ids, seg_labels, label) = line
token_ids = [int(token) for token in token_ids.split(" ")]
sent_ids = [int(token) for token in sent_ids.split(" ")]
pos_ids = [int(token) for token in pos_ids.split(" ")]
seg_labels = [int(seg_label) for seg_label in seg_labels.split(" ")]
assert len(token_ids) == len(sent_ids) == len(pos_ids) == len(
seg_labels
), "[Must be true]len(token_ids) == len(sent_ids) == len(pos_ids) == len(seg_labels)"
......@@ -94,6 +133,7 @@ class ErnieDataReader(object):
assert file.endswith('.gz'), "[ERROR] %s is not a gzip file" % file
with gzip.open(file, "rb") as f:
for line in f:
line = line.decode('utf8')
parsed_line = self.parse_line(
line, max_seq_len=self.max_seq_len)
if parsed_line is None:
......@@ -232,35 +272,63 @@ class ErnieDataReader(object):
print("miss_num:%d\tideal_total_sample_num:%d\tmiss_rate:%f" %
(num_total_miss, pos_sample_num * 2,
num_total_miss / (pos_sample_num * 2)))
def shuffle_samples(self, sample_generator, buffer=1000):
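# Buffered shuffle: repeatedly fill a buffer of `buffer` samples, shuffle it
# in place and yield; on exhaustion, flush whatever is left (or yield a single
# None, which the caller skips).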
samples = []
try:
while True:
while len(samples) < buffer:
sample = next(sample_generator)
samples.append(sample)
np.random.shuffle(samples)
for sample in samples:
yield sample
samples = []
except StopIteration:
print("stopiteration: reach end of file")
if len(samples) == 0:
yield None
else:
np.random.shuffle(samples)
for sample in samples:
yield sample
def data_generator(self):
"""
data_generator
"""
files = open(self.filelist).readlines()
self.total_file = len(files)
assert self.total_file > 0, "[Error] data_dir is empty"
def wrapper():
def reader():
for epoch in range(self.epoch):
self.current_epoch = epoch + 1
files = self.files
#during training, data are sliced by trainers
if self.shuffle_files:
np.random.shuffle(files)
for index, file in enumerate(files):
file, mask_word_prob = file.strip().split("\t")
start = epoch * self.total_file
end = start + self.total_file
files = [file_ for index, file_ in enumerate(self.files[start:end]) \
if index % self.trainer_nums == self.trainer_id]
for index, file_ in enumerate(files):
file_, mask_word_prob = file_.strip().split("\t")
mask_word = (np.random.random() < float(mask_word_prob))
self.current_file_index = index + 1
self.current_file = file
self.current_file_index = (index + 1) * self.trainer_nums
self.current_file = file_
if mask_word:
self.mask_type = "mask_word"
else:
self.mask_type = "mask_char"
sample_generator = self.read_file(file)
if not self.is_test and self.generate_neg_sample:
sample_generator = self.mixin_negtive_samples(
sample_generator)
sample_generator = self.read_file(file_)
if not self.is_test:
if self.generate_neg_sample:
sample_generator = self.mixin_negtive_samples(
sample_generator)
else:
#shuffle buffered sample
sample_generator = self.shuffle_samples(
sample_generator)
for sample in sample_generator:
if sample is None:
continue
......@@ -272,7 +340,11 @@ class ErnieDataReader(object):
for parsed_line in reader():
token_ids, sent_ids, pos_ids, label, seg_labels, mask_word = parsed_line
max_len = max(max_len, len(token_ids))
if (len(batch) + 1) * max_len <= batch_size:
if self.in_tokens:
to_append = (len(batch) + 1) * max_len <= batch_size
else:
to_append = len(batch) < batch_size
if to_append:
batch.append(parsed_line)
total_token_num += len(token_ids)
else:
......
......@@ -11,18 +11,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import sys
import os
import csv
import json
import random
import logging
import numpy as np
import six
from io import open
from collections import namedtuple
import tokenization
from batching import pad_batch_data
log = logging.getLogger(__name__)
if six.PY3:
from itertools import accumulate
import io
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
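# Minimal csv.reader replacement: split each line on `delimiter` with no quoting
# rules; a single-field line is wrapped in a 1-tuple so callers always receive a
# sequence of fields.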
def csv_reader(fd, delimiter='\t'):
def gen():
for i in fd:
slots = i.rstrip('\n').split(delimiter)
if len(slots) == 1:
yield slots,
else:
yield slots
return gen()
class BaseReader(object):
def __init__(self,
vocab_path,
......@@ -58,7 +86,7 @@ class BaseReader(object):
self.num_examples = 0
if label_map_config:
with open(label_map_config) as f:
with open(label_map_config, encoding='utf8') as f:
self.label_map = json.load(f)
else:
self.label_map = None
......@@ -69,8 +97,8 @@ class BaseReader(object):
def _read_tsv(self, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
with open(input_file, 'r', encoding='utf8') as f:
reader = csv_reader(f)
headers = next(reader)
Example = namedtuple('Example', headers)
......@@ -225,6 +253,12 @@ class BaseReader(object):
phase=None):
examples = self._read_tsv(input_file)
if phase == 'train':
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_num = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
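# Round-robin shard: trainer k keeps examples k, k + trainer_num, ...; the stop
# bound truncates the remainder so every trainer sees the same number of examples.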
examples = examples[trainer_id: (len(examples) //trainer_num) * trainer_num : trainer_num]
log.info('apply sharding %d/%d' % (trainer_id, trainer_num))
def wrapper():
all_dev_batches = []
for epoch_index in range(epoch):
......@@ -242,15 +276,21 @@ class BaseReader(object):
for batch in all_dev_batches:
yield batch
all_dev_batches = []
return wrapper
def f():
try:
for i in wrapper():
yield i
except Exception as e:
import traceback
traceback.print_exc()
return f
class ClassifyReader(BaseReader):
def _read_tsv(self, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
with open(input_file, 'r', encoding='utf8') as f:
reader = csv_reader(f)
headers = next(reader)
text_indices = [
index for index, h in enumerate(headers) if h != "label"
......@@ -472,7 +512,7 @@ class MRCReader(BaseReader):
def _read_json(self, input_file, is_training):
examples = []
with open(input_file, "r") as f:
with open(input_file, "r", encoding='utf8') as f:
input_data = json.load(f)["data"]
for entry in input_data:
for paragraph in entry["paragraphs"]:
......@@ -507,7 +547,7 @@ class MRCReader(BaseReader):
actual_text = " ".join(doc_tokens[start_pos:(end_pos
+ 1)])
if actual_text.find(orig_answer_text) == -1:
print("Could not find answer: '%s' vs. '%s'",
log.info("Could not find answer: '%s' vs. '%s'",
actual_text, orig_answer_text)
continue
else:
......
......@@ -16,9 +16,12 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os
import time
import logging
import multiprocessing
# NOTE(paddle-dev): All of these flags should be
......@@ -32,12 +35,13 @@ import reader.task_reader as task_reader
from model.ernie import ErnieConfig
from finetune.classifier import create_model, evaluate, predict
from optimization import optimization
from utils.args import print_arguments, check_cuda
from utils.args import print_arguments, check_cuda, prepare_logger
from utils.init import init_pretraining_params, init_checkpoint
from utils.cards import get_cards
from finetune_args import parser
args = parser.parse_args()
log = logging.getLogger()
def main(args):
......@@ -45,8 +49,9 @@ def main(args):
ernie_config.print_config()
if args.use_cuda:
place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
dev_count = fluid.core.get_cuda_device_count()
dev_list = fluid.cuda_places()
place = dev_list[0]
dev_count = len(dev_list)
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
......@@ -95,10 +100,10 @@ def main(args):
max_train_steps = args.epoch * num_train_examples // args.batch_size // dev_count
warmup_steps = int(max_train_steps * args.warmup_proportion)
print("Device count: %d" % dev_count)
print("Num train examples: %d" % num_train_examples)
print("Max train steps: %d" % max_train_steps)
print("Num warmup steps: %d" % warmup_steps)
log.info("Device count: %d" % dev_count)
log.info("Num train examples: %d" % num_train_examples)
log.info("Max train steps: %d" % max_train_steps)
log.info("Num warmup steps: %d" % warmup_steps)
train_program = fluid.Program()
if args.random_seed is not None and args.enable_ce:
......@@ -121,7 +126,13 @@ def main(args):
startup_prog=startup_prog,
weight_decay=args.weight_decay,
scheduler=args.lr_scheduler,
use_fp16=args.use_fp16)
use_fp16=args.use_fp16,
use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
init_loss_scaling=args.init_loss_scaling,
incr_every_n_steps=args.incr_every_n_steps,
decr_every_n_nan_or_inf=args.decr_every_n_nan_or_inf,
incr_ratio=args.incr_ratio,
decr_ratio=args.decr_ratio)
if args.verbose:
if args.in_tokens:
......@@ -131,7 +142,7 @@ def main(args):
else:
lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
program=train_program, batch_size=args.batch_size)
print("Theoretical memory usage in training: %.3f - %.3f %s" %
log.info("Theoretical memory usage in training: %.3f - %.3f %s" %
(lower_mem, upper_mem, unit))
if args.do_val or args.do_test:
......@@ -148,11 +159,36 @@ def main(args):
test_prog = test_prog.clone(for_test=True)
nccl2_num_trainers = 1
nccl2_trainer_id = 0
if args.is_distributed:
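# The PADDLE_* variables below are injected by the process launcher
# (e.g. finetune_launch.py): this worker's rank, the endpoints of all
# workers, and this worker's own endpoint.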
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS")
current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
worker_endpoints = worker_endpoints_env.split(",")
trainers_num = len(worker_endpoints)
log.info("worker_endpoints:{} trainers_num:{} current_endpoint:{} \
trainer_id:{}".format(worker_endpoints, trainers_num,
current_endpoint, trainer_id))
# prepare nccl2 env.
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
trainers=worker_endpoints_env,
current_endpoint=current_endpoint,
program=train_program if args.do_train else test_prog,
startup_program=startup_prog)
nccl2_num_trainers = trainers_num
nccl2_trainer_id = trainer_id
exe = fluid.Executor(place)
exe.run(startup_prog)
if args.do_train:
if args.init_checkpoint and args.init_pretraining_params:
print(
log.warning(
"WARNING: args 'init_checkpoint' and 'init_pretraining_params' "
"both are set! Only arg 'init_checkpoint' is made valid.")
if args.init_checkpoint:
......@@ -236,14 +272,14 @@ def main(args):
verbose += "learning rate: %f" % (
outputs["learning_rate"]
if warmup_steps > 0 else args.learning_rate)
print(verbose)
log.info(verbose)
current_example, current_epoch = reader.get_train_progress()
time_end = time.time()
used_time = time_end - time_begin
if args.is_classify:
print(
log.info(
"epoch: %d, progress: %d/%d, step: %d, ave loss: %f, "
"ave acc: %f, speed: %f steps/s" %
(current_epoch, current_example, num_train_examples,
......@@ -252,7 +288,7 @@ def main(args):
ce_info.append(
[outputs["loss"], outputs["accuracy"], used_time])
if args.is_regression:
print(
log.info(
"epoch: %d, progress: %d/%d, step: %d, ave loss: %f, "
" speed: %f steps/s" %
(current_epoch, current_example, num_train_examples,
......@@ -260,22 +296,23 @@ def main(args):
args.skip_steps / used_time))
time_begin = time.time()
if steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints,
"step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
if nccl2_trainer_id == 0:
if steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints,
"step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
if steps % args.validation_steps == 0 or last_epoch != current_epoch:
# evaluate dev set
if args.do_val:
evaluate_wrapper(args, reader, exe, test_prog,
test_pyreader, graph_vars,
current_epoch, steps)
if steps % args.validation_steps == 0 or last_epoch != current_epoch:
# evaluate dev set
if args.do_val:
evaluate_wrapper(args, reader, exe, test_prog,
test_pyreader, graph_vars,
current_epoch, steps)
if args.do_test:
predict_wrapper(args, reader, exe, test_prog,
test_pyreader, graph_vars,
current_epoch, steps)
if args.do_test:
predict_wrapper(args, reader, exe, test_prog,
test_pyreader, graph_vars,
current_epoch, steps)
if last_epoch != current_epoch:
last_epoch = current_epoch
......@@ -295,10 +332,10 @@ def main(args):
ce_acc = ce_info[-2][1]
ce_time = ce_info[-2][2]
except:
print("ce info error")
print("kpis\ttrain_duration_card%s\t%s" % (card_num, ce_time))
print("kpis\ttrain_loss_card%s\t%f" % (card_num, ce_loss))
print("kpis\ttrain_acc_card%s\t%f" % (card_num, ce_acc))
log.info("ce info error")
log.info("kpis\ttrain_duration_card%s\t%s" % (card_num, ce_time))
log.info("kpis\ttrain_loss_card%s\t%f" % (card_num, ce_loss))
log.info("kpis\ttrain_acc_card%s\t%f" % (card_num, ce_acc))
# final eval on dev set
if args.do_val:
......@@ -320,7 +357,7 @@ def main(args):
dev_count=1,
shuffle=False))
print("Final diagnostic")
log.info("Final diagnostic")
qids, preds, probs = predict(
test_exe,
test_prog,
......@@ -334,7 +371,7 @@ def main(args):
for id, s, p in zip(qids, preds, probs):
f.write('{}\t{}\t{}\n'.format(id, s, p))
print("Done final diagnostic, saving to {}".format(
log.info("Done final diagnostic, saving to {}".format(
args.diagnostic_save))
......@@ -349,7 +386,7 @@ def evaluate_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
epoch=1,
dev_count=1,
shuffle=False))
print("validation result of dataset {}:".format(ds))
log.info("validation result of dataset {}:".format(ds))
evaluate_info = evaluate(
exe,
test_prog,
......@@ -359,7 +396,7 @@ def evaluate_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
metric=args.metric,
is_classify=args.is_classify,
is_regression=args.is_regression)
print(evaluate_info + ', file: {}, epoch: {}, steps: {}'.format(
log.info(evaluate_info + ', file: {}, epoch: {}, steps: {}'.format(
ds, epoch, steps))
......@@ -379,7 +416,7 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
shuffle=False))
save_path = save_f + '.' + str(epoch) + '.' + str(steps)
print("testing {}, save to {}".format(test_f, save_path))
log.info("testing {}, save to {}".format(test_f, save_path))
qids, preds, probs = predict(
exe,
test_prog,
......@@ -391,6 +428,9 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
save_dir = os.path.dirname(save_path)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
else:
log.warning('save dir exists: %s' % save_dir)
with open(save_path, 'w') as f:
for id, s, p in zip(qids, preds, probs):
......@@ -398,6 +438,7 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
if __name__ == '__main__':
prepare_logger(log)
print_arguments(args)
check_cuda(args.use_cuda)
main(args)
......@@ -16,9 +16,11 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import time
import logging
import multiprocessing
# NOTE(paddle-dev): All of these flags should be
......@@ -32,11 +34,12 @@ import reader.task_reader as task_reader
from model.ernie import ErnieConfig
from finetune.mrc import create_model, evaluate
from optimization import optimization
from utils.args import print_arguments
from utils.args import print_arguments, prepare_logger
from utils.init import init_pretraining_params, init_checkpoint
from finetune_args import parser
args = parser.parse_args()
log = logging.getLogger()
def main(args):
......@@ -44,8 +47,9 @@ def main(args):
ernie_config.print_config()
if args.use_cuda:
place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
dev_count = fluid.core.get_cuda_device_count()
dev_list = fluid.cuda_places()
place = dev_list[0]
dev_count = len(dev_list)
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
......@@ -70,6 +74,8 @@ def main(args):
raise ValueError("For args `do_train`, `do_val` and `do_test`, at "
"least one of them must be True.")
if args.do_test:
assert args.test_save is not None
startup_prog = fluid.Program()
if args.random_seed is not None:
startup_prog.random_seed = args.random_seed
......@@ -77,11 +83,12 @@ def main(args):
if args.predict_batch_size == None:
args.predict_batch_size = args.batch_size
if args.do_train:
trainers_num = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
train_data_generator = reader.data_generator(
input_file=args.train_set,
batch_size=args.batch_size,
epoch=args.epoch,
dev_count=dev_count,
dev_count=trainers_num,
shuffle=True,
phase="train")
......@@ -94,10 +101,10 @@ def main(args):
max_train_steps = args.epoch * num_train_examples // args.batch_size // dev_count
warmup_steps = int(max_train_steps * args.warmup_proportion)
print("Device count: %d" % dev_count)
print("Num train examples: %d" % num_train_examples)
print("Max train steps: %d" % max_train_steps)
print("Num warmup steps: %d" % warmup_steps)
log.info("Device count: %d" % dev_count)
log.info("Num train examples: %d" % num_train_examples)
log.info("Max train steps: %d" % max_train_steps)
log.info("Num warmup steps: %d" % warmup_steps)
train_program = fluid.Program()
......@@ -108,7 +115,7 @@ def main(args):
pyreader_name='train_reader',
ernie_config=ernie_config,
is_training=True)
scheduled_lr, loss_scaling = optimization(
scheduled_lr, _ = optimization(
loss=graph_vars["loss"],
warmup_steps=warmup_steps,
num_train_steps=max_train_steps,
......@@ -117,7 +124,13 @@ def main(args):
startup_prog=startup_prog,
weight_decay=args.weight_decay,
scheduler=args.lr_scheduler,
use_fp16=args.use_fp16)
use_fp16=args.use_fp16,
use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
init_loss_scaling=args.init_loss_scaling,
incr_every_n_steps=args.incr_every_n_steps,
decr_every_n_nan_or_inf=args.decr_every_n_nan_or_inf,
incr_ratio=args.incr_ratio,
decr_ratio=args.decr_ratio)
if args.verbose:
if args.in_tokens:
......@@ -127,7 +140,7 @@ def main(args):
else:
lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
program=train_program, batch_size=args.batch_size)
print("Theoretical memory usage in training: %.3f - %.3f %s" %
log.info("Theoretical memory usage in training: %.3f - %.3f %s" %
(lower_mem, upper_mem, unit))
if args.do_val or args.do_test:
......@@ -144,11 +157,36 @@ def main(args):
nccl2_num_trainers = 1
nccl2_trainer_id = 0
if args.is_distributed:
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS")
current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
worker_endpoints = worker_endpoints_env.split(",")
trainers_num = len(worker_endpoints)
log.info("worker_endpoints:{} trainers_num:{} current_endpoint:{} \
trainer_id:{}".format(worker_endpoints, trainers_num,
current_endpoint, trainer_id))
# prepare nccl2 env.
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
trainers=worker_endpoints_env,
current_endpoint=current_endpoint,
program=train_program if args.do_train else test_prog,
startup_program=startup_prog)
nccl2_num_trainers = trainers_num
nccl2_trainer_id = trainer_id
exe = fluid.Executor(place)
exe.run(startup_prog)
if args.do_train:
if args.init_checkpoint and args.init_pretraining_params:
print(
log.warning(
"WARNING: args 'init_checkpoint' and 'init_pretraining_params' "
"both are set! Only arg 'init_checkpoint' is made valid.")
if args.init_checkpoint:
......@@ -214,12 +252,12 @@ def main(args):
verbose += "learning rate: %f" % (
outputs["learning_rate"]
if warmup_steps > 0 else args.learning_rate)
print(verbose)
log.info(verbose)
current_example, current_epoch = reader.get_train_progress()
time_end = time.time()
used_time = time_end - time_begin
print("epoch: %d, progress: %d/%d, step: %d, ave loss: %f, "
log.info("epoch: %d, progress: %d/%d, step: %d, ave loss: %f, "
"speed: %f steps/s" %
(current_epoch, current_example, num_train_examples,
steps, outputs["loss"], args.skip_steps / used_time))
......@@ -277,7 +315,7 @@ def main(args):
# final eval on dev set
if args.do_val:
print("Final validation result:")
log.info("Final validation result:")
test_pyreader.decorate_tensor_provider(
reader.data_generator(
args.dev_set,
......@@ -298,7 +336,7 @@ def main(args):
# final eval on test set
if args.do_test:
print("Final test result:")
log.info("Final test result:")
test_pyreader.decorate_tensor_provider(
reader.data_generator(
args.test_set,
......@@ -319,6 +357,8 @@ def main(args):
if __name__ == '__main__':
prepare_logger(log)
print_arguments(args)
while True:
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
......
......@@ -16,10 +16,15 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
import os
import time
import six
import logging
import multiprocessing
from io import open
# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
......@@ -32,11 +37,12 @@ import reader.task_reader as task_reader
from model.ernie import ErnieConfig
from optimization import optimization
from utils.init import init_pretraining_params, init_checkpoint
from utils.args import print_arguments, check_cuda
from finetune.sequence_label import create_model, evaluate
from utils.args import print_arguments, check_cuda, prepare_logger
from finetune.sequence_label import create_model, evaluate, predict, calculate_f1
from finetune_args import parser
args = parser.parse_args()
log = logging.getLogger()
def main(args):
......@@ -44,12 +50,12 @@ def main(args):
ernie_config.print_config()
if args.use_cuda:
place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
dev_count = fluid.core.get_cuda_device_count()
dev_list = fluid.cuda_places()
place = dev_list[0]
dev_count = len(dev_list)
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
exe = fluid.Executor(place)
reader = task_reader.SequenceLabelReader(
vocab_path=args.vocab_path,
......@@ -85,10 +91,10 @@ def main(args):
max_train_steps = args.epoch * num_train_examples // args.batch_size // dev_count
warmup_steps = int(max_train_steps * args.warmup_proportion)
print("Device count: %d" % dev_count)
print("Num train examples: %d" % num_train_examples)
print("Max train steps: %d" % max_train_steps)
print("Num warmup steps: %d" % warmup_steps)
log.info("Device count: %d" % dev_count)
log.info("Num train examples: %d" % num_train_examples)
log.info("Max train steps: %d" % max_train_steps)
log.info("Num warmup steps: %d" % warmup_steps)
train_program = fluid.Program()
......@@ -107,7 +113,13 @@ def main(args):
startup_prog=startup_prog,
weight_decay=args.weight_decay,
scheduler=args.lr_scheduler,
use_fp16=args.use_fp16)
use_fp16=args.use_fp16,
use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
init_loss_scaling=args.init_loss_scaling,
incr_every_n_steps=args.incr_every_n_steps,
decr_every_n_nan_or_inf=args.decr_every_n_nan_or_inf,
incr_ratio=args.incr_ratio,
decr_ratio=args.decr_ratio)
if args.verbose:
if args.in_tokens:
......@@ -117,7 +129,7 @@ def main(args):
else:
lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
program=train_program, batch_size=args.batch_size)
print("Theoretical memory usage in training: %.3f - %.3f %s" %
log.info("Theoretical memory usage in training: %.3f - %.3f %s" %
(lower_mem, upper_mem, unit))
if args.do_val or args.do_test:
......@@ -131,11 +143,38 @@ def main(args):
test_prog = test_prog.clone(for_test=True)
nccl2_num_trainers = 1
nccl2_trainer_id = 0
if args.is_distributed:
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS")
current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
worker_endpoints = worker_endpoints_env.split(",")
trainers_num = len(worker_endpoints)
log.info("worker_endpoints:{} trainers_num:{} current_endpoint:{} \
trainer_id:{}".format(worker_endpoints, trainers_num,
current_endpoint, trainer_id))
# prepare nccl2 env.
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
trainers=worker_endpoints_env,
current_endpoint=current_endpoint,
program=train_program if args.do_train else test_prog,
startup_program=startup_prog)
nccl2_num_trainers = trainers_num
nccl2_trainer_id = trainer_id
exe = fluid.Executor(place)
exe.run(startup_prog)
if args.do_train:
if args.init_checkpoint and args.init_pretraining_params:
print(
log.info(
"WARNING: args 'init_checkpoint' and 'init_pretraining_params' "
"both are set! Only arg 'init_checkpoint' is made valid.")
if args.init_checkpoint:
......@@ -171,7 +210,9 @@ def main(args):
use_cuda=args.use_cuda,
loss_name=graph_vars["loss"].name,
exec_strategy=exec_strategy,
main_program=train_program)
main_program=train_program,
num_trainers=nccl2_num_trainers,
trainer_id=nccl2_trainer_id)
train_pyreader.decorate_tensor_provider(train_data_generator)
else:
......@@ -186,8 +227,7 @@ def main(args):
if args.do_train:
train_pyreader.start()
steps = 0
if warmup_steps > 0:
graph_vars["learning_rate"] = scheduled_lr
graph_vars["learning_rate"] = scheduled_lr
time_begin = time.time()
while True:
......@@ -196,54 +236,47 @@ def main(args):
if steps % args.skip_steps != 0:
train_exe.run(fetch_list=[])
else:
outputs = evaluate(train_exe, train_program, train_pyreader,
graph_vars, args.num_labels, "train",
dev_count)
fetch_list = [
graph_vars["num_infer"].name, graph_vars["num_label"].name,
graph_vars["num_correct"].name,
graph_vars["loss"].name,
graph_vars['learning_rate'].name,
]
out = train_exe.run(fetch_list=fetch_list)
num_infer, num_label, num_correct, np_loss, np_lr = out
lr = float(np_lr[0])
loss = np_loss.mean()
precision, recall, f1 = calculate_f1(num_label, num_infer, num_correct)
if args.verbose:
verbose = "train pyreader queue size: %d, " % train_pyreader.queue.size(
)
verbose += "learning rate: %f" % (
outputs["lr"]
if warmup_steps > 0 else args.learning_rate)
print(verbose)
log.info("train pyreader queue size: %d, learning rate: %f" % (train_pyreader.queue.size(),
lr if warmup_steps > 0 else args.learning_rate))
current_example, current_epoch = reader.get_train_progress()
time_end = time.time()
used_time = time_end - time_begin
print("epoch: %d, progress: %d/%d, step: %d, loss: %f, "
log.info("epoch: %d, progress: %d/%d, step: %d, loss: %f, "
"f1: %f, precision: %f, recall: %f, speed: %f steps/s"
% (current_epoch, current_example, num_train_examples,
steps, outputs["loss"], outputs["f1"],
outputs["precision"], outputs["recall"],
steps, loss, f1, precision, recall,
args.skip_steps / used_time))
time_begin = time.time()
if steps % args.save_steps == 0:
if nccl2_trainer_id == 0 and steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints,
"step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
if nccl2_trainer_id == 0 and steps % args.validation_steps == 0:
# evaluate dev set
if args.do_val:
evaluate_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
current_epoch, steps)
# evaluate test set
if args.do_test:
predict_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
current_epoch, steps)
except fluid.core.EOFException:
save_path = os.path.join(args.checkpoints, "step_" + str(steps))
......@@ -252,31 +285,65 @@ def main(args):
break
# final eval on dev set
if nccl2_trainer_id == 0 and args.do_val:
evaluate_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
current_epoch, 'final')
# final eval on test set
if nccl2_trainer_id == 0 and args.do_test:
predict_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
current_epoch, 'final')
def evaluate_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
epoch, steps):
# evaluate dev set
for ds in args.dev_set.split(','):  # single-card eval
test_pyreader.decorate_tensor_provider(
reader.data_generator(
ds,
batch_size=args.predict_batch_size,
epoch=1,
dev_count=1,
shuffle=False))
print("Final test result:")
evaluate(exe, test_prog, test_pyreader, graph_vars, args.num_labels,
"test")
log.info("validation result of dataset {}:".format(ds))
info = evaluate(exe, test_prog, test_pyreader, graph_vars,
args.num_labels)
log.info(info + ', file: {}, epoch: {}, steps: {}'.format(
ds, epoch, steps))
def predict_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
epoch, steps):
test_sets = args.test_set.split(',')
save_dirs = args.test_save.split(',')
assert len(test_sets) == len(save_dirs), 'number of test_sets & test_save not match, got %d vs %d' % (len(test_sets), len(save_dirs))
for test_f, save_f in zip(test_sets, save_dirs):
test_pyreader.decorate_tensor_provider(reader.data_generator(
test_f,
batch_size=args.predict_batch_size,
epoch=1,
dev_count=1,
shuffle=False))
save_path = save_f + '.' + str(epoch) + '.' + str(steps)
log.info("testing {}, save to {}".format(test_f, save_path))
res = predict(exe, test_prog, test_pyreader, graph_vars, dev_count=1)
save_dir = os.path.dirname(save_path)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
tokenizer = reader.tokenizer
rev_label_map = {v: k for k, v in six.iteritems(reader.label_map)}
with open(save_path, 'w', encoding='utf8') as f:
for ids, s, p in res:
ids = ' '.join(tokenizer.convert_ids_to_tokens(ids))
p = ' '.join(['%.5f' % pp[ss] for ss, pp in zip(s, p)])
s = ' '.join([rev_label_map[ss] for ss in s])
f.write('{}\t{}\t{}\n'.format(ids, s, p))
if __name__ == '__main__':
prepare_logger(log)
print_arguments(args)
check_cuda(args.use_cuda)
main(args)
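The prediction files written by `predict_wrapper` above hold one example per line with three tab-separated, space-joined fields: wordpiece tokens, predicted labels, and per-token probabilities. A minimal parsing sketch under that assumption; the file name below is hypothetical, following the `save_f.epoch.steps` naming scheme:

```python
# Minimal sketch: parse a prediction file written by predict_wrapper.
# Each line: <tokens>\t<labels>\t<probs>, all space-joined within a field.
from io import open

def load_predictions(path):
    with open(path, encoding='utf8') as f:
        for line in f:
            tokens, labels, probs = line.rstrip('\n').split('\t')
            yield (tokens.split(' '), labels.split(' '),
                   [float(x) for x in probs.split(' ')])

# 'test_save.6.final' is a hypothetical output name.
for tokens, labels, probs in load_predictions('test_save.6.final'):
    assert len(tokens) == len(labels) == len(probs)
```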
......@@ -4,7 +4,12 @@ export FLAGS_eager_delete_tensor_gb=0
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3
python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_mrc.py --use_cuda true \
--batch_size 16 \
--in_tokens false \
--use_fast_executor true \
......
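The `finetune_launch.py` wrapper introduced above spawns one worker per entry in `--selected_gpus` and hands each worker its NCCL2 rendezvous information through environment variables; the `is_distributed` branch of `main` earlier in this diff then reads `PADDLE_TRAINER_ID`, `PADDLE_TRAINER_ENDPOINTS`, and `PADDLE_CURRENT_ENDPOINT`. A rough sketch of that contract — not the actual launcher; the port scheme and `FLAGS_selected_gpus` handling are assumptions:

```python
# Rough sketch of a multiprocess launcher's env-var contract with the
# worker script; ports and per-worker GPU selection are assumptions.
import os
import subprocess
import sys

def launch(script, script_args, node_ip='127.0.0.1', nproc=8, base_port=6170):
    endpoints = ','.join('%s:%d' % (node_ip, base_port + i) for i in range(nproc))
    procs = []
    for rank in range(nproc):
        env = dict(os.environ)
        env['FLAGS_selected_gpus'] = str(rank)        # one GPU per worker
        env['PADDLE_TRAINER_ID'] = str(rank)
        env['PADDLE_TRAINER_ENDPOINTS'] = endpoints
        env['PADDLE_CURRENT_ENDPOINT'] = '%s:%d' % (node_ip, base_port + rank)
        procs.append(subprocess.Popen([sys.executable, script] + script_args,
                                      env=env))
    for p in procs:
        p.wait()
```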
......@@ -4,7 +4,13 @@ export FLAGS_eager_delete_tensor_gb=0
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_classifier.py \
--use_cuda true \
--verbose true \
--do_train true \
......
......@@ -4,7 +4,12 @@ export FLAGS_eager_delete_tensor_gb=0
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3
python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_mrc.py --use_cuda true \
--batch_size 16 \
--in_tokens false \
--use_fast_executor true \
......
......@@ -2,7 +2,7 @@ set -eux
export FLAGS_eager_delete_tensor_gb=0
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0
python -u run_sequence_labeling.py \
--use_cuda true \
......@@ -15,7 +15,7 @@ python -u run_sequence_labeling.py \
--chunk_scheme "IOB" \
--label_map_config ${TASK_DATA_PATH}/msra_ner/label_map.json \
--train_set ${TASK_DATA_PATH}/msra_ner/train.tsv \
--dev_set ${TASK_DATA_PATH}/msra_ner/dev.tsv,${TASK_DATA_PATH}/msra_ner/test.tsv \
--test_set ${TASK_DATA_PATH}/msra_ner/test.tsv \
--vocab_path ${MODEL_PATH}/vocab.txt \
--ernie_config_path ${MODEL_PATH}/ernie_config.json \
......@@ -24,6 +24,7 @@ python -u run_sequence_labeling.py \
--weight_decay 0.01 \
--warmup_proportion 0.0 \
--validation_steps 100 \
--use_fp16 false \
--epoch 6 \
--max_seq_len 256 \
--learning_rate 5e-5 \
......
......@@ -4,29 +4,36 @@ export FLAGS_eager_delete_tensor_gb=0
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_classifier.py \
--use_cuda true \
--do_train true \
--do_val true \
--do_test false \
--verbose true \
--in_tokens true \
--batch_size 8192 \
--train_set ${TASK_DATA_PATH}/xnli/train.tsv \
--dev_set ${TASK_DATA_PATH}/xnli/dev.tsv,${TASK_DATA_PATH}/xnli/test.tsv \
--label_map ${TASK_DATA_PATH}/xnli/label_map.json \
--vocab_path ${MODEL_PATH}/vocab.txt \
--ernie_config_path ${MODEL_PATH}/ernie_config.json \
--init_pretraining_params ${MODEL_PATH}/params \
--checkpoints ./checkpoints \
--save_steps 1000 \
--weight_decay 0.01 \
--warmup_proportion 0.0 \
--use_fp16 false \
--validation_steps 100 \
--epoch 3 \
--max_seq_len 512 \
--learning_rate 1e-4 \
--skip_steps 10 \
--num_iteration_per_drop_scope 1 \
--num_labels 3 \
--random_seed 1
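Note that with `--in_tokens true`, the `--batch_size 8192` above is a per-device token budget rather than a sequence count, so the number of examples packed into a batch varies with sequence length. A back-of-the-envelope sketch:

```python
# With in_tokens=True the batch is packed up to a token budget; at
# max_seq_len 512, a budget of 8192 holds at most 16 full-length sequences.
token_budget, max_seq_len = 8192, 512
print(token_budget // max_seq_len)  # -> 16
```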
......@@ -4,7 +4,12 @@ export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_mrc.py --use_cuda true \
--batch_size 8 \
--in_tokens false \
--use_fast_executor true \
......
......@@ -3,7 +3,12 @@ set -eux
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_classifier.py \
--use_cuda true \
--verbose true \
--do_train true \
......
......@@ -4,7 +4,12 @@ export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_mrc.py --use_cuda true \
--batch_size 8 \
--in_tokens false \
--use_fast_executor true \
......
......@@ -14,15 +14,16 @@ python -u run_sequence_labeling.py \
--chunk_scheme "IOB" \
--label_map_config ${TASK_DATA_PATH}/msra_ner/label_map.json \
--train_set ${TASK_DATA_PATH}/msra_ner/train.tsv \
--dev_set ${TASK_DATA_PATH}/msra_ner/dev.tsv,${TASK_DATA_PATH}/msra_ner/test.tsv \
--test_set ${TASK_DATA_PATH}/msra_ner/test.tsv \
--vocab_path ${MODEL_PATH}/vocab.txt \
--ernie_config_path ${MODEL_PATH}/ernie_config.json \
--checkpoints ./checkpoints \
--save_steps 100000 \
--weight_decay 0.01 \
--warmup_proportion 0.0 \
--validation_steps 100 \
--use_fp16 false \
--epoch 6 \
--max_seq_len 256 \
--learning_rate 1e-5 \
......
......@@ -3,7 +3,13 @@ set -eux
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python ./finetune_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
run_classifier.py \
--use_cuda true \
--do_train true \
--do_val true \
......
......@@ -3,8 +3,12 @@ set -eux
export FLAGS_eager_delete_tensor_gb=0
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python ./pretrain_launch.py \
--nproc_per_node 8 \
--selected_gpus 0,1,2,3,4,5,6,7 \
--node_ips $(hostname -i) \
--node_id 0 \
./train.py --use_cuda True \
--is_distributed False\
--use_fast_executor True \
--weight_sharing True \
......@@ -19,6 +23,7 @@ python -u ./train.py --use_cuda True \
--save_steps 10000 \
--ernie_config_path ./config/ernie_config.json \
--learning_rate 1e-4 \
--use_fp16 false \
--weight_decay 0.01 \
--max_seq_len 512 \
--skip_steps 10
......@@ -17,6 +17,10 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from io import open
import collections
import unicodedata
......@@ -69,15 +73,15 @@ def printable_text(text):
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
with open(vocab_file, encoding='utf8') as fin:
for num, line in enumerate(fin):
items = convert_to_unicode(line.strip()).split("\t")
if len(items) > 2:
break
token = items[0]
index = items[1] if len(items) == 2 else num
token = token.strip()
vocab[token] = int(index)
return vocab
......
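For reference, `load_vocab` above accepts either `token<TAB>index` lines or bare tokens, in which case the line number becomes the id. A toy round-trip under that assumption, with a hypothetical file name:

```python
# Toy demonstration of the two vocab layouts load_vocab accepts.
from io import open

with open('toy_vocab.txt', 'w', encoding='utf8') as f:
    f.write(u'[PAD]\t0\n[UNK]\t1\n')  # explicit ids; bare tokens also work

vocab = load_vocab('toy_vocab.txt')
print(vocab)  # OrderedDict([('[PAD]', 0), ('[UNK]', 1)])
```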
......@@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""ERNIE pretraining."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import time
import multiprocessing
import logging
import numpy as np
import paddle.fluid as fluid
......@@ -27,11 +29,12 @@ import paddle.fluid as fluid
from reader.pretraining import ErnieDataReader
from model.ernie_v1 import ErnieModel, ErnieConfig
from optimization import optimization
from utils.args import print_arguments, check_cuda
from utils.args import print_arguments, check_cuda, prepare_logger
from utils.init import init_checkpoint, init_pretraining_params
from pretrain_args import parser
log = logging.getLogger()
args = parser.parse_args()
# yapf: enable.
......@@ -65,9 +68,6 @@ def create_model(pyreader_name, ernie_config):
next_sent_acc, mask_lm_loss, total_loss = ernie.get_pretraining_output(
mask_label, mask_pos, labels)
return pyreader, next_sent_acc, mask_lm_loss, total_loss
......@@ -114,7 +114,7 @@ def predict_wrapper(args,
cost += each_total_cost
steps += 1
if args.do_test and steps % args.skip_steps == 0:
print("[test_set] steps: %d" % steps)
log.info("[test_set] steps: %d" % steps)
except fluid.core.EOFException:
pyreader.reset()
......@@ -151,9 +151,9 @@ def test(args):
pyreader=test_pyreader,
fetch_list=[next_sent_acc.name, mask_lm_loss.name, total_loss.name])
print("test begin")
log.info("test begin")
loss, lm_loss, acc, steps, speed = predict()
log.info(
"[test_set] loss: %f, global ppl: %f, next_sent_acc: %f, speed: %f steps/s"
% (np.mean(np.array(loss) / steps),
np.exp(np.mean(np.array(lm_loss) / steps)),
......@@ -161,7 +161,7 @@ def test(args):
def train(args):
print("pretraining start")
log.info("pretraining start")
ernie_config = ErnieConfig(args.ernie_config_path)
ernie_config.print_config()
......@@ -171,7 +171,7 @@ def train(args):
with fluid.unique_name.guard():
train_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
pyreader_name='train_reader', ernie_config=ernie_config)
scheduled_lr, _ = optimization(
loss=total_loss,
warmup_steps=args.warmup_steps,
num_train_steps=args.num_train_steps,
......@@ -180,7 +180,14 @@ def train(args):
startup_prog=startup_prog,
weight_decay=args.weight_decay,
scheduler=args.lr_scheduler,
use_fp16=args.use_fp16,
use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
init_loss_scaling=args.init_loss_scaling,
incr_every_n_steps=args.incr_every_n_steps,
decr_every_n_nan_or_inf=args.decr_every_n_nan_or_inf,
incr_ratio=args.incr_ratio,
decr_ratio=args.decr_ratio)
fluid.memory_optimize(
input_program=train_program,
......@@ -196,31 +203,34 @@ def train(args):
test_prog = test_prog.clone(for_test=True)
if len(fluid.cuda_places()) == 0:
raise RuntimeError('no CUDA device found, check your env settings')
if args.use_cuda:
place = fluid.cuda_places()[0]
dev_count = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
print("Device count %d" % dev_count)
print("theoretical memory usage: ")
print(fluid.contrib.memory_usage(
log.info("Device count %d" % dev_count)
log.info("theoretical memory usage: ")
log.info(fluid.contrib.memory_usage(
program=train_program, batch_size=args.batch_size // args.max_seq_len))
nccl2_num_trainers = 1
nccl2_trainer_id = 0
print("args.is_distributed:", args.is_distributed)
log.info("args.is_distributed: %s" % args.is_distributed)
if args.is_distributed:
worker_endpoints_env = os.getenv("worker_endpoints")
worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS")
worker_endpoints = worker_endpoints_env.split(",")
trainers_num = len(worker_endpoints)
current_endpoint = os.getenv("current_endpoint")
current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
trainer_id = worker_endpoints.index(current_endpoint)
if trainer_id == 0:
print("train_id == 0, sleep 60s")
log.info("train_id == 0, sleep 60s")
time.sleep(60)
print("worker_endpoints:{} trainers_num:{} current_endpoint:{} \
log.info("worker_endpoints:{} trainers_num:{} current_endpoint:{} \
trainer_id:{}".format(worker_endpoints, trainers_num,
current_endpoint, trainer_id))
......@@ -309,13 +319,13 @@ def train(args):
lm_cost.extend(each_mask_lm_cost)
cost.extend(each_total_cost)
print("feed_queue size", train_pyreader.queue.size())
log.info("feed_queue size %d" % train_pyreader.queue.size())
time_end = time.time()
used_time = time_end - time_begin
epoch, current_file_index, total_file, current_file, mask_type = data_reader.get_progress(
)
print("current learning_rate:%f" % np_lr[0])
print(
log.info("current learning_rate:%f" % np_lr[0])
log.info(
"epoch: %d, progress: %d/%d, step: %d, loss: %f, "
"ppl: %f, next_sent_acc: %f, speed: %f steps/s, file: %s, mask_type: %s"
% (epoch, current_file_index, total_file, steps,
......@@ -335,7 +345,7 @@ def train(args):
if args.valid_filelist and steps % args.validation_steps == 0:
vali_cost, vali_lm_cost, vali_acc, vali_steps, vali_speed = predict(
)
print("[validation_set] epoch: %d, step: %d, "
log.info("[validation_set] epoch: %d, step: %d, "
"loss: %f, global ppl: %f, batch-averged ppl: %f, "
"next_sent_acc: %f, speed: %f steps/s" %
(epoch, steps, np.mean(np.array(vali_cost) / vali_steps),
......@@ -349,6 +359,7 @@ def train(args):
if __name__ == '__main__':
prepare_logger(log)
print_arguments(args)
check_cuda(args.use_cuda)
if args.do_test:
......
......@@ -12,17 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Arguments for configuration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
import six
import argparse
import logging
import paddle.fluid as fluid
log = logging.getLogger(__name__)
def prepare_logger(logger, debug=False, save_to_file=None):
formatter = logging.Formatter(fmt='[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s')
console_hdl = logging.StreamHandler()
console_hdl.setFormatter(formatter)
logger.addHandler(console_hdl)
if save_to_file is not None and not os.path.exists(save_to_file):
file_hdl = logging.FileHandler(save_to_file)
file_hdl.setFormatter(formatter)
logger.addHandler(file_hdl)
logger.setLevel(logging.DEBUG)
logger.propagate = False
def str2bool(v):
# argparse cannot parse strings like "True"/"False" into Python
# booleans directly, so map them explicitly
......@@ -33,10 +51,11 @@ class ArgumentGroup(object):
def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des)
def add_arg(self, name, type, default, help, positional_arg=False, **kwargs):
prefix = "" if positional_arg else "--"
type = str2bool if type == bool else type
self._group.add_argument(
"--" + name,
prefix + name,
default=default,
type=type,
help=help + ' Default: %(default)s.',
......@@ -44,10 +63,10 @@ class ArgumentGroup(object):
def print_arguments(args):
log.info('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
log.info('%s: %s' % (arg, value))
log.info('------------------------------------------------')
def check_cuda(use_cuda, err = \
......@@ -56,7 +75,7 @@ def check_cuda(use_cuda, err = \
):
try:
if use_cuda and not fluid.is_compiled_with_cuda():
log.error(err)
sys.exit(1)
except Exception as e:
pass
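With the `positional_arg` switch added to `add_arg` above, an argument group can also register positional arguments such as the trailing script name that `finetune_launch.py` forwards. A minimal usage sketch; the argument names here are illustrative, not the actual ERNIE argument set:

```python
# Hypothetical usage of ArgumentGroup with the new positional_arg switch.
import argparse

parser = argparse.ArgumentParser()
g = ArgumentGroup(parser, "launcher", "process launcher options")
g.add_arg("nproc_per_node", int, 8, "Workers per node.")
g.add_arg("training_script", str, None, "Script to launch plus its args.",
          positional_arg=True, nargs=argparse.REMAINDER)
args = parser.parse_args(
    ['--nproc_per_node', '4', 'run_classifier.py', '--epoch', '3'])
# args.training_script == ['run_classifier.py', '--epoch', '3']
```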
......@@ -11,7 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Evaluation script for CMRC 2018
version: v5
......@@ -6,22 +19,25 @@ Note:
v5 formatted output, add usage description
v4 fixed segmentation issues
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from collections import Counter, OrderedDict
import string
import re
import argparse
import json
import sys
import nltk
import pdb
# split Chinese with English
def mixed_segmentation(in_str, rm_punc=False):
in_str = in_str.lower().strip()
segs_out = []
temp_str = ""
sp_char = [
......@@ -32,7 +48,7 @@ def mixed_segmentation(in_str, rm_punc=False):
for char in in_str:
if rm_punc and char in sp_char:
continue
if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
if temp_str != "":
ss = nltk.word_tokenize(temp_str)
segs_out.extend(ss)
......@@ -51,7 +67,7 @@ def mixed_segmentation(in_str, rm_punc=False):
# remove punctuation
def remove_punctuation(in_str):
in_str = in_str.lower().strip()
sp_char = [
'-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
'?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
......@@ -102,7 +118,7 @@ def evaluate(ground_truth_file, prediction_file):
skip_count += 1
continue
prediction = prediction_file[query_id]
f1 += calc_f1_score(answers, prediction)
em += calc_em_score(answers, prediction)
......
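For intuition, `mixed_segmentation` above emits CJK characters one at a time while passing the Latin spans between them through `nltk.word_tokenize`. An illustrative call; the output shown is the expected behaviour, not a captured run:

```python
# Illustrative only: CJK chars become single-character tokens, the
# Latin prefix is tokenized by nltk after lowercasing.
print(mixed_segmentation(u'ERNIE模型', rm_punc=True))
# expected: ['ernie', '模', '型']
```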
......@@ -16,27 +16,20 @@ from __future__ import print_function
import paddle
import paddle.fluid as fluid
def append_cast_op(i, o, prog):
"""
Append a cast op in a given Program to cast input `i` to data type `o.dtype`.
Args:
i (Variable): The input Variable.
o (Variable): The output Variable.
prog (Program): The Program to append cast op.
"""
prog.global_block().append_op(
type="cast",
inputs={"X": i},
outputs={"Out": o},
attrs={"in_dtype": i.dtype,
"out_dtype": o.dtype})
def copy_to_master_param(p, block):
......@@ -59,32 +52,66 @@ def copy_to_master_param(p, block):
return new_p
def apply_dynamic_loss_scaling(loss_scaling, master_params_grads,
incr_every_n_steps, decr_every_n_nan_or_inf,
incr_ratio, decr_ratio):
_incr_every_n_steps = fluid.layers.fill_constant(
shape=[1], dtype='int32', value=incr_every_n_steps)
_decr_every_n_nan_or_inf = fluid.layers.fill_constant(
shape=[1], dtype='int32', value=decr_every_n_nan_or_inf)
_num_good_steps = fluid.layers.create_global_var(
name=fluid.unique_name.generate("num_good_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
_num_bad_steps = fluid.layers.create_global_var(
name=fluid.unique_name.generate("num_bad_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
grads = [fluid.layers.reduce_sum(g) for [_, g] in master_params_grads]
all_grads = fluid.layers.concat(grads)
all_grads_sum = fluid.layers.reduce_sum(all_grads)
is_overall_finite = fluid.layers.isfinite(all_grads_sum)
update_loss_scaling(is_overall_finite, loss_scaling, _num_good_steps,
_num_bad_steps, _incr_every_n_steps,
_decr_every_n_nan_or_inf, incr_ratio, decr_ratio)
# apply_gradient append all ops in global block, thus we shouldn't
# apply gradient in the switch branch.
with fluid.layers.Switch() as switch:
with switch.case(is_overall_finite):
pass
with switch.default():
for _, g in master_params_grads:
fluid.layers.assign(fluid.layers.zeros_like(g), g)
def create_master_params_grads(params_grads, main_prog, startup_prog,
loss_scaling):
master_params_grads = []
tmp_role = main_prog._current_role
OpRole = fluid.core.op_proto_and_checker_maker.OpRole
main_prog._current_role = OpRole.Backward
for p, g in params_grads:
# create master parameters
master_param = copy_to_master_param(p, main_prog.global_block())
startup_master_param = startup_prog.global_block()._clone_variable(
master_param)
startup_p = startup_prog.global_block().var(p.name)
append_cast_op(startup_p, startup_master_param, startup_prog)
# cast fp16 gradients to fp32 before applying gradients
if g.name.find("layer_norm") > -1:
scaled_g = g / loss_scaling
master_params_grads.append([p, scaled_g])
continue
master_grad = fluid.layers.cast(g, "float32")
master_grad = master_grad / loss_scaling
master_params_grads.append([master_param, master_grad])
main_prog._current_role = tmp_role
return master_params_grads
......@@ -94,4 +121,80 @@ def master_param_to_train_param(master_params_grads, params_grads, main_prog):
if train_p.name.find("layer_norm") > -1:
continue
with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]):
append_cast_op(m_p_g[0], train_p, main_prog)
def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps,
num_bad_steps, incr_every_n_steps,
decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
"""
Update loss scaling according to the overall gradients. If all gradients
are finite after incr_every_n_steps, loss scaling increases by incr_ratio.
Otherwise, loss scaling decreases by decr_ratio after
decr_every_n_nan_or_inf accumulated steps with nan or inf gradients.
Args:
is_overall_finite (Variable): A boolean variable indicates whether
all gradients are finite.
prev_loss_scaling (Variable): Previous loss scaling.
num_good_steps (Variable): A variable accumulates good steps in which
all gradients are finite.
num_bad_steps (Variable): A variable accumulates bad steps in which
some gradients are infinite.
incr_every_n_steps (Variable): A variable represents increasing loss
scaling every n consecutive steps with
finite gradients.
decr_every_n_nan_or_inf (Variable): A variable represents decreasing
loss scaling every n accumulated
steps with nan or inf gradients.
incr_ratio(float): The multiplier to use when increasing the loss
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
loss scaling.
"""
zero_steps = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
with fluid.layers.Switch() as switch:
with switch.case(is_overall_finite):
should_incr_loss_scaling = fluid.layers.less_than(
incr_every_n_steps, num_good_steps + 1)
with fluid.layers.Switch() as switch1:
with switch1.case(should_incr_loss_scaling):
new_loss_scaling = prev_loss_scaling * incr_ratio
loss_scaling_is_finite = fluid.layers.isfinite(
new_loss_scaling)
with fluid.layers.Switch() as switch2:
with switch2.case(loss_scaling_is_finite):
fluid.layers.assign(new_loss_scaling,
prev_loss_scaling)
with switch2.default():
pass
fluid.layers.assign(zero_steps, num_good_steps)
fluid.layers.assign(zero_steps, num_bad_steps)
with switch1.default():
fluid.layers.increment(num_good_steps)
fluid.layers.assign(zero_steps, num_bad_steps)
with switch.default():
should_decr_loss_scaling = fluid.layers.less_than(
decr_every_n_nan_or_inf, num_bad_steps + 1)
with fluid.layers.Switch() as switch3:
with switch3.case(should_decr_loss_scaling):
new_loss_scaling = prev_loss_scaling * decr_ratio
static_loss_scaling = \
fluid.layers.fill_constant(shape=[1],
dtype='float32',
value=1.0)
less_than_one = fluid.layers.less_than(new_loss_scaling,
static_loss_scaling)
with fluid.layers.Switch() as switch4:
with switch4.case(less_than_one):
fluid.layers.assign(static_loss_scaling,
prev_loss_scaling)
with switch4.default():
fluid.layers.assign(new_loss_scaling,
prev_loss_scaling)
fluid.layers.assign(zero_steps, num_good_steps)
fluid.layers.assign(zero_steps, num_bad_steps)
with switch3.default():
fluid.layers.assign(zero_steps, num_good_steps)
fluid.layers.increment(num_bad_steps)
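The nested `Switch` blocks above are easier to follow as plain Python; a scalar restatement of the same dynamic loss-scaling policy (a sketch for exposition, not the fluid implementation):

```python
# Plain-Python restatement of update_loss_scaling: grow the scale by
# incr_ratio after incr_every_n_steps consecutive finite steps; shrink it
# by decr_ratio (never below 1.0) after decr_every_n_nan_or_inf overflows.
import math

def update_scaling_py(is_finite, scale, good, bad,
                      incr_every_n_steps=1000, decr_every_n_nan_or_inf=2,
                      incr_ratio=2.0, decr_ratio=0.5):
    if is_finite:
        if good + 1 > incr_every_n_steps:
            new_scale = scale * incr_ratio
            if math.isfinite(new_scale):  # keep old scale on overflow
                scale = new_scale
            good, bad = 0, 0
        else:
            good, bad = good + 1, 0
    else:
        if bad + 1 > decr_every_n_nan_or_inf:
            scale = max(scale * decr_ratio, 1.0)
            good, bad = 0, 0
        else:
            good, bad = 0, bad + 1
    return scale, good, bad

scale, good, bad = 2.0 ** 15, 0, 0
for step_finite in [True] * 1000 + [False, False]:
    scale, good, bad = update_scaling_py(step_finite, scale, good, bad)
```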
......@@ -12,27 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import six
import ast
import copy
import logging
import numpy as np
import paddle.fluid as fluid
log = logging.getLogger(__name__)
def cast_fp32_to_fp16(exe, main_program):
print("Cast parameters to float16 data format.")
log.info("Cast parameters to float16 data format.")
for param in main_program.global_block().all_parameters():
if not param.name.endswith(".master"):
param_t = fluid.global_scope().find_var(param.name).get_tensor()
data = np.array(param_t)
if param.name.find("layer_norm") == -1:
if param.name.startswith("encoder_layer") \
and "layer_norm" not in param.name:
param_t.set(np.float16(data).view(np.uint16), exe.place)
# load fp32
master_param_var = fluid.global_scope().find_var(param.name +
".master")
if master_param_var is not None:
master_param_var.get_tensor().set(data, exe.place)
......@@ -40,7 +50,7 @@ def cast_fp32_to_fp16(exe, main_program):
def init_checkpoint(exe, init_checkpoint_path, main_program, use_fp16=False):
assert os.path.exists(
init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path
def existed_persistables(var):
if not fluid.io.is_persistable(var):
return False
......@@ -51,7 +61,7 @@ def init_checkpoint(exe, init_checkpoint_path, main_program, use_fp16=False):
init_checkpoint_path,
main_program=main_program,
predicate=existed_persistables)
print("Load model from {}".format(init_checkpoint_path))
log.info("Load model from {}".format(init_checkpoint_path))
if use_fp16:
cast_fp32_to_fp16(exe, main_program)
......@@ -74,7 +84,7 @@ def init_pretraining_params(exe,
pretraining_params_path,
main_program=main_program,
predicate=existed_params)
print("Load pretraining parameters from {}.".format(
log.info("Load pretraining parameters from {}.".format(
pretraining_params_path))
if use_fp16:
......