Commit 968a114e authored by: M mir-of

fix bert_benchmark/run_pretraining.py for of-develop

Parent 81379888
@@ -3,8 +3,6 @@ from __future__ import division
from __future__ import print_function
import os
import time
import random
import argparse
from datetime import datetime
@@ -15,107 +13,121 @@ import benchmark_util
parser = argparse.ArgumentParser(description="flags for bert")
def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Unsupported value encountered.')
# resource
parser.add_argument("--gpu_num_per_node", type=int, default=1)
parser.add_argument("--node_num", type=int, default=1)
parser.add_argument("--node_list", type=str, default=None)
parser.add_argument(
    "--gpu_num_per_node", type=int, default=1)
parser.add_argument(
    "--node_num", type=int, default=1)
parser.add_argument(
    "--node_list", type=str, default=None)
# train
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
parser.add_argument(
    "--weight_l2", type=float, default=0.01, help="weight l2 decay parameter"
)
parser.add_argument("--batch_size_per_device", type=int, default=24)
parser.add_argument("--iter_num", type=int, default=10, help="total iterations to run")
    "--learning_rate", type=float, default=1e-4, help="Learning rate")
parser.add_argument(
    "--warmup_iter_num", type=int, default=10, help="warmup iterations to run"
)
    "--weight_l2", type=float, default=0.01, help="weight l2 decay parameter")
parser.add_argument(
    "--log_every_n_iter", type=int, default=1, help="print loss every n iteration"
)
    "--batch_size_per_device", type=int, default=24)
parser.add_argument(
    "--iter_num", type=int, default=10, help="total iterations to run")
parser.add_argument(
    "--warmup_iter_num", type=int, default=10, help="warmup iterations to run")
parser.add_argument(
    "--log_every_n_iter", type=int, default=1,
    help="print loss every n iteration")
parser.add_argument("--data_dir", type=str, default=None)
parser.add_argument(
"--data_part_num", type=int, default=32, help="data part number in dataset"
"--data_part_num", type=int, default=32,
help="data part number in dataset")
# parser.add_argument(
# "--enable_auto_mixed_precision", type=bool, default=False)
parser.add_argument(
    '--use_fp16',
    type=str2bool,
    nargs='?',
    const=True,
    help='Whether to use fp16'
)
parser.add_argument(
    '--use_boxing_v2',
    type=str2bool,
    nargs='?',
    const=True,
    help='Whether to use boxing v2'
)
parser.add_argument("--enable_auto_mixed_precision", type=bool, default=False)
# log and restore/save
parser.add_argument(
"--loss_print_every_n_iter",
type=int,
default=1,
required=False,
help="print loss every n iteration",
)
type=int, default=1, required=False, help="print loss every n iteration")
parser.add_argument(
"--model_save_every_n_iter",
type=int,
default=200,
required=False,
help="save model every n iteration",
)
"--model_save_every_n_iter", type=int, default=200, required=False,
help="save model every n iteration",)
parser.add_argument(
"--model_save_dir",
type=str,
default="./output/model_save-{}".format(
str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))
),
required=False,
help="model save directory",
)
"--model_save_dir", type=str, default="./output/model_save-{}".format(
str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))),
required=False, help="model save directory")
parser.add_argument(
"--save_last_snapshot",
type=bool,
default=False,
required=False,
help="save model snapshot for last iteration",
)
"--save_last_snapshot", type=bool, default=False, required=False,
help="save model snapshot for last iteration")
parser.add_argument(
"--model_load_dir",
type=str,
default=None,
required=False,
help="model load directory",
)
"--model_load_dir", type=str, default=None, required=False,
help="model load directory")
parser.add_argument(
"--log_dir",
type=str,
default="./output",
required=False,
help="log info save directory",
)
"--log_dir", type=str, default="./output", required=False,
help="log info save directory")
# bert
parser.add_argument("--seq_length", type=int, default=512)
parser.add_argument("--max_predictions_per_seq", type=int, default=80)
parser.add_argument("--num_hidden_layers", type=int, default=24)
parser.add_argument("--num_attention_heads", type=int, default=16)
parser.add_argument("--max_position_embeddings", type=int, default=512)
parser.add_argument("--type_vocab_size", type=int, default=2)
parser.add_argument("--vocab_size", type=int, default=30522)
parser.add_argument("--attention_probs_dropout_prob", type=float, default=0.1)
parser.add_argument("--hidden_dropout_prob", type=float, default=0.1)
parser.add_argument("--hidden_size_per_head", type=int, default=64)
parser.add_argument(
    "--seq_length", type=int, default=512)
parser.add_argument(
    "--max_predictions_per_seq", type=int, default=80)
parser.add_argument(
    "--num_hidden_layers", type=int, default=24)
parser.add_argument(
    "--num_attention_heads", type=int, default=16)
parser.add_argument(
    "--max_position_embeddings", type=int, default=512)
parser.add_argument(
    "--type_vocab_size", type=int, default=2)
parser.add_argument(
    "--vocab_size", type=int, default=30522)
parser.add_argument(
    "--attention_probs_dropout_prob", type=float, default=0.1)
parser.add_argument(
    "--hidden_dropout_prob", type=float, default=0.1)
parser.add_argument(
    "--hidden_size_per_head", type=int, default=64)
args = parser.parse_args()
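
A quick aside on how the new boolean flags behave (this snippet is the editor's illustration, not part of the commit): because --use_fp16 and --use_boxing_v2 are declared with type=str2bool, nargs='?' and const=True, they can be passed either bare or with an explicit yes/no value. A minimal, self-contained sketch:

import argparse

def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Unsupported value encountered.')

demo = argparse.ArgumentParser()
# default=False is added here for clarity; the script above leaves it unset (None)
demo.add_argument('--use_fp16', type=str2bool, nargs='?', const=True, default=False)

print(demo.parse_args([]).use_fp16)                    # False: flag absent, default applies
print(demo.parse_args(['--use_fp16']).use_fp16)        # True: bare flag, const=True is used
print(demo.parse_args(['--use_fp16', 'no']).use_fp16)  # False: value parsed by str2bool
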
def _blob_conf(name, shape, dtype=flow.int32):
    return flow.data.BlobConf(
        name=name, shape=shape, dtype=dtype, codec=flow.data.RawCodec()
    )
def BertDecoder(
    data_dir, batch_size, data_part_num, seq_length, max_predictions_per_seq
):
    def _blob_conf(name, shape, dtype=flow.int32):
        return flow.data.BlobConf(
            name=name, shape=shape, dtype=dtype, codec=flow.data.RawCodec()
        )
    blob_confs = []
    blob_confs.append(_blob_conf("input_ids", [seq_length]))
    blob_confs.append(_blob_conf("next_sentence_labels", [1]))
    blob_confs.append(_blob_conf("input_mask", [seq_length]))
    blob_confs.append(_blob_conf("segment_ids", [seq_length]))
    blob_confs.append(_blob_conf("masked_lm_ids", [max_predictions_per_seq]))
    blob_confs.append(_blob_conf("masked_lm_positions", [max_predictions_per_seq]))
    blob_confs.append(_blob_conf(
        "masked_lm_positions", [max_predictions_per_seq]))
    blob_confs.append(
        _blob_conf("masked_lm_weights", [max_predictions_per_seq], flow.float)
    )
@@ -145,7 +157,8 @@ def BuildPreTrainNet(
    intermediate_size = hidden_size * 4
    decoders = BertDecoder(
        args.data_dir, batch_size, data_part_num, seq_length, max_predictions_per_seq
        args.data_dir, batch_size, data_part_num, seq_length,
        max_predictions_per_seq
    )
    input_ids = decoders[0]
@@ -183,21 +196,29 @@ _BERT_MODEL_UPDATE_CONF = dict(
    learning_rate_decay=dict(
        polynomial_conf=dict(decay_batches=100000, end_learning_rate=0.0,)
    ),
    warmup_conf=dict(linear_conf=dict(warmup_batches=1000, start_multiplier=0,)),
    warmup_conf=dict(linear_conf=dict(
        warmup_batches=1000, start_multiplier=0,)),
    clip_conf=dict(clip_by_global_norm=dict(clip_norm=1.0,)),
    adam_conf=dict(epsilon=1e-6),
)
config = flow.function_config()
config.default_data_type(flow.float)
config.train.primary_lr(args.learning_rate)
config.train.model_update_conf(_BERT_MODEL_UPDATE_CONF)
# config.train.weight_l2(args.weight_l2) ??
if args.use_fp16:
    config.enable_auto_mixed_precision(True)
if args.use_boxing_v2:
    config.use_boxing_v2(True)
@flow.function
@flow.function(config)
def PretrainJob():
    total_device_num = args.node_num * args.gpu_num_per_node
    batch_size = total_device_num * args.batch_size_per_device
    flow.config.train.primary_lr(args.learning_rate)
    flow.config.train.model_update_conf(_BERT_MODEL_UPDATE_CONF)
    flow.config.train.weight_l2(args.weight_l2)
    total_loss, mlm_loss, nsp_loss = BuildPreTrainNet(
        batch_size,
        args.data_part_num,
@@ -226,13 +247,17 @@ def main():
    for arg in vars(args):
        print("{} = {}".format(arg, getattr(args, arg)))
    print("-".ljust(66, "-"))
    print("Time stamp: {}".format(str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))))
    print("Time stamp: {}".format(
        str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))))
    flow.config.gpu_device_num(args.gpu_num_per_node)
    flow.config.default_data_type(flow.float)
    flow.env.log_dir(args.log_dir)
    if args.enable_auto_mixed_precision:
        flow.config.enable_auto_mixed_precision()
    if args.use_boxing_v2:
        flow.config.collective_boxing.nccl_fusion_threshold_mb(8)
        flow.config.collective_boxing.nccl_fusion_all_reduce_use_buffer(False)
    # if args.enable_auto_mixed_precision:
    #     flow.config.enable_auto_mixed_precision()
    if args.node_num > 1:
        nodes = []
......
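
Beyond the argparse reformatting, the substance of this commit for of-develop is visible in the last hunks: training options that were previously set inside the job body through flow.config.train.* now live on a flow.function_config() object passed to the @flow.function decorator, and the new --use_fp16 / --use_boxing_v2 flags toggle mixed precision and boxing v2 on that config (with the extra collective-boxing settings applied in main()). The sketch below is the editor's restatement of that pattern, limited to calls that appear in this diff; it assumes an OneFlow build contemporary with this commit and elides the actual network:

import oneflow as flow
from argparse import Namespace

# stand-in for parser.parse_args(); the values are illustrative only
args = Namespace(learning_rate=1e-4, use_fp16=True, use_boxing_v2=False)

config = flow.function_config()
config.default_data_type(flow.float)
config.train.primary_lr(args.learning_rate)  # previously flow.config.train.primary_lr(...) inside the job
config.train.model_update_conf(dict(adam_conf=dict(epsilon=1e-6)))
if args.use_fp16:
    config.enable_auto_mixed_precision(True)
if args.use_boxing_v2:
    config.use_boxing_v2(True)

@flow.function(config)
def PretrainJob():
    # the real script builds the BERT pre-training net here and returns its losses
    ...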