提交 0eec024f 编写于 作者: Y Yao Chi

Deprecated calls have been replaced by new APIs

上级 f60ce103
...@@ -66,53 +66,45 @@ class CNNSpeedometer: ...@@ -66,53 +66,45 @@ class CNNSpeedometer:
class BERTSpeedometer: class BERTSpeedometer:
def __init__(self): def __init__(self):
self.watch = StopWatch() self.watch = StopWatch()
self.watch.start()
def speedometer_cb( def speedometer_cb(
self, step, total_batch_size, warmup_num, iter_num, loss_print_every_n_iter self, step, total_batch_size, iter_num, loss_print_every_n_iter
): ):
def callback(train_loss): def callback(train_loss):
if step < warmup_num: train_step = step
print( if (train_step + 1) % loss_print_every_n_iter == 0:
"Runing warm up for {}/{} iterations.".format(step + 1, warmup_num) total_loss = train_loss[0].mean()
mlm_loss = train_loss[1].mean()
nsp_loss = train_loss[2].mean()
duration = self.watch.split()
sentences_per_sec = (
total_batch_size * loss_print_every_n_iter / duration
) )
if (step + 1) == warmup_num: print(
self.watch.start() "iter {}, total_loss: {:.3f}, mlm_loss: {:.3f}, nsp_loss: {:.3f}, speed: {:.3f}(sec/batch), {:.3f}(sentences/sec)".format(
print("Start trainning.") train_step,
else: total_loss,
train_step = step - warmup_num mlm_loss,
nsp_loss,
if (train_step + 1) % loss_print_every_n_iter == 0: duration,
total_loss = train_loss[0].mean() sentences_per_sec,
mlm_loss = train_loss[1].mean()
nsp_loss = train_loss[2].mean()
duration = self.watch.split()
sentences_per_sec = (
total_batch_size * loss_print_every_n_iter / duration
)
print(
"iter {}, total_loss: {:.3f}, mlm_loss: {:.3f}, nsp_loss: {:.3f}, speed: {:.3f}(sec/batch), {:.3f}(sentences/sec)".format(
train_step,
total_loss,
mlm_loss,
nsp_loss,
duration,
sentences_per_sec,
)
) )
)
if (train_step + 1) == iter_num: if (train_step + 1) == iter_num:
self.watch.stop() self.watch.stop()
totoal_duration = self.watch.duration() totoal_duration = self.watch.duration()
avg_sentences_per_sec = ( avg_sentences_per_sec = (
total_batch_size * iter_num / totoal_duration total_batch_size * iter_num / totoal_duration
) )
print("-".ljust(66, "-")) print("-".ljust(66, "-"))
print( print(
"average speed: {:.3f}(sentences/sec)".format( "average speed: {:.3f}(sentences/sec)".format(
avg_sentences_per_sec avg_sentences_per_sec
)
) )
print("-".ljust(66, "-")) )
print("-".ljust(66, "-"))
return callback return callback
...@@ -22,8 +22,8 @@ class BertBackbone(object): ...@@ -22,8 +22,8 @@ class BertBackbone(object):
type_vocab_size=16, type_vocab_size=16,
initializer_range=0.02): initializer_range=0.02):
with flow.deprecated.variable_scope("bert"): with flow.name_scope("bert"):
with flow.deprecated.variable_scope("embeddings"): with flow.name_scope("embeddings"):
(self.embedding_output_, self.embedding_table_) = _EmbeddingLookup( (self.embedding_output_, self.embedding_table_) = _EmbeddingLookup(
input_ids_blob=input_ids_blob, input_ids_blob=input_ids_blob,
vocab_size=vocab_size, vocab_size=vocab_size,
...@@ -43,7 +43,7 @@ class BertBackbone(object): ...@@ -43,7 +43,7 @@ class BertBackbone(object):
initializer_range=initializer_range, initializer_range=initializer_range,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
dropout_prob=hidden_dropout_prob) dropout_prob=hidden_dropout_prob)
with flow.deprecated.variable_scope("encoder"): with flow.name_scope("encoder"):
attention_mask_blob = _CreateAttentionMaskFromInputMask( attention_mask_blob = _CreateAttentionMaskFromInputMask(
input_mask_blob, from_seq_length=seq_length, to_seq_length=seq_length) input_mask_blob, from_seq_length=seq_length, to_seq_length=seq_length)
self.all_encoder_layers_ = _TransformerModel( self.all_encoder_layers_ = _TransformerModel(
...@@ -91,10 +91,10 @@ def _TransformerModel(input_blob, ...@@ -91,10 +91,10 @@ def _TransformerModel(input_blob,
prev_output_blob = flow.reshape(input_blob, (-1, input_width)) prev_output_blob = flow.reshape(input_blob, (-1, input_width))
all_layer_output_blobs = [] all_layer_output_blobs = []
for layer_idx in range(num_hidden_layers): for layer_idx in range(num_hidden_layers):
with flow.deprecated.variable_scope("layer_%d"%layer_idx): with flow.name_scope("layer_%d"%layer_idx):
layer_input_blob = prev_output_blob layer_input_blob = prev_output_blob
with flow.deprecated.variable_scope("attention"): with flow.name_scope("attention"):
with flow.deprecated.variable_scope("self"): with flow.name_scope("self"):
attention_output_blob = _AttentionLayer( attention_output_blob = _AttentionLayer(
from_blob=layer_input_blob, from_blob=layer_input_blob,
to_blob=layer_input_blob, to_blob=layer_input_blob,
...@@ -106,7 +106,7 @@ def _TransformerModel(input_blob, ...@@ -106,7 +106,7 @@ def _TransformerModel(input_blob,
do_return_2d_tensor=True, do_return_2d_tensor=True,
from_seq_length=seq_length, from_seq_length=seq_length,
to_seq_length=seq_length) to_seq_length=seq_length)
with flow.deprecated.variable_scope("output"): with flow.name_scope("output"):
attention_output_blob = _FullyConnected( attention_output_blob = _FullyConnected(
attention_output_blob, attention_output_blob,
input_size=num_attention_heads * attention_head_size, input_size=num_attention_heads * attention_head_size,
...@@ -116,7 +116,7 @@ def _TransformerModel(input_blob, ...@@ -116,7 +116,7 @@ def _TransformerModel(input_blob,
attention_output_blob = _Dropout(attention_output_blob, hidden_dropout_prob) attention_output_blob = _Dropout(attention_output_blob, hidden_dropout_prob)
attention_output_blob = attention_output_blob + layer_input_blob attention_output_blob = attention_output_blob + layer_input_blob
attention_output_blob = _LayerNorm(attention_output_blob, hidden_size) attention_output_blob = _LayerNorm(attention_output_blob, hidden_size)
with flow.deprecated.variable_scope("intermediate"): with flow.name_scope("intermediate"):
if callable(intermediate_act_fn): if callable(intermediate_act_fn):
act_fn = op_conf_util.kNone act_fn = op_conf_util.kNone
else: else:
...@@ -130,7 +130,7 @@ def _TransformerModel(input_blob, ...@@ -130,7 +130,7 @@ def _TransformerModel(input_blob,
name='dense') name='dense')
if callable(intermediate_act_fn): if callable(intermediate_act_fn):
intermediate_output_blob = intermediate_act_fn(intermediate_output_blob) intermediate_output_blob = intermediate_act_fn(intermediate_output_blob)
with flow.deprecated.variable_scope("output"): with flow.name_scope("output"):
layer_output_blob = _FullyConnected( layer_output_blob = _FullyConnected(
intermediate_output_blob, intermediate_output_blob,
input_size=intermediate_size, input_size=intermediate_size,
......
...@@ -65,13 +65,13 @@ def PreTrain( ...@@ -65,13 +65,13 @@ def PreTrain(
hidden_size=hidden_size, hidden_size=hidden_size,
initializer_range=initializer_range, initializer_range=initializer_range,
) )
with flow.deprecated.variable_scope("cls-loss"): with flow.name_scope("cls-loss"):
total_loss = lm_loss + ns_loss total_loss = lm_loss + ns_loss
return total_loss, lm_loss, ns_loss return total_loss, lm_loss, ns_loss
def PooledOutput(sequence_output, hidden_size, initializer_range): def PooledOutput(sequence_output, hidden_size, initializer_range):
with flow.deprecated.variable_scope("bert-pooler"): with flow.name_scope("bert-pooler"):
first_token_tensor = flow.slice(sequence_output, [None, 0, 0], [None, 1, -1]) first_token_tensor = flow.slice(sequence_output, [None, 0, 0], [None, 1, -1])
first_token_tensor = flow.reshape(first_token_tensor, [-1, hidden_size]) first_token_tensor = flow.reshape(first_token_tensor, [-1, hidden_size])
pooled_output = bert_util._FullyConnected( pooled_output = bert_util._FullyConnected(
...@@ -98,15 +98,15 @@ def _AddMaskedLanguageModelLoss( ...@@ -98,15 +98,15 @@ def _AddMaskedLanguageModelLoss(
hidden_act, hidden_act,
initializer_range, initializer_range,
): ):
with flow.deprecated.variable_scope("other"): with flow.name_scope("other"):
sum_label_weight_blob = flow.math.reduce_sum(label_weight_blob, axis=[-1]) sum_label_weight_blob = flow.math.reduce_sum(label_weight_blob, axis=[-1])
ones = sum_label_weight_blob * 0.0 + 1.0 ones = sum_label_weight_blob * 0.0 + 1.0
sum_label_weight_blob = flow.math.reduce_sum(sum_label_weight_blob) sum_label_weight_blob = flow.math.reduce_sum(sum_label_weight_blob)
batch_size = flow.math.reduce_sum(ones) batch_size = flow.math.reduce_sum(ones)
sum_label_weight_blob = sum_label_weight_blob / batch_size sum_label_weight_blob = sum_label_weight_blob / batch_size
with flow.deprecated.variable_scope("cls-predictions"): with flow.name_scope("cls-predictions"):
input_blob = _GatherIndexes(input_blob, positions_blob, seq_length, hidden_size) input_blob = _GatherIndexes(input_blob, positions_blob, seq_length, hidden_size)
with flow.deprecated.variable_scope("transform"): with flow.name_scope("transform"):
if callable(hidden_act): if callable(hidden_act):
act_fn = op_conf_util.kNone act_fn = op_conf_util.kNone
else: else:
...@@ -136,7 +136,7 @@ def _AddMaskedLanguageModelLoss( ...@@ -136,7 +136,7 @@ def _AddMaskedLanguageModelLoss(
) )
pre_example_loss = flow.reshape(pre_example_loss, [-1, max_predictions_per_seq]) pre_example_loss = flow.reshape(pre_example_loss, [-1, max_predictions_per_seq])
numerator = pre_example_loss * label_weight_blob numerator = pre_example_loss * label_weight_blob
with flow.deprecated.variable_scope("loss"): with flow.name_scope("loss"):
numerator = flow.math.reduce_sum(numerator, axis=[-1]) numerator = flow.math.reduce_sum(numerator, axis=[-1])
denominator = sum_label_weight_blob + 1e-5 denominator = sum_label_weight_blob + 1e-5
loss = numerator / denominator loss = numerator / denominator
...@@ -152,7 +152,7 @@ def _GatherIndexes(sequence_blob, positions_blob, seq_length, hidden_size): ...@@ -152,7 +152,7 @@ def _GatherIndexes(sequence_blob, positions_blob, seq_length, hidden_size):
def _AddNextSentenceOutput(input_blob, label_blob, hidden_size, initializer_range): def _AddNextSentenceOutput(input_blob, label_blob, hidden_size, initializer_range):
with flow.deprecated.variable_scope("cls-seq_relationship"): with flow.name_scope("cls-seq_relationship"):
output_weight_blob = flow.get_variable( output_weight_blob = flow.get_variable(
name="output_weights", name="output_weights",
shape=[2, hidden_size], shape=[2, hidden_size],
......
...@@ -13,7 +13,6 @@ import benchmark_util ...@@ -13,7 +13,6 @@ import benchmark_util
parser = argparse.ArgumentParser(description="flags for bert") parser = argparse.ArgumentParser(description="flags for bert")
def str2bool(v): def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'): if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True return True
...@@ -24,120 +23,69 @@ def str2bool(v): ...@@ -24,120 +23,69 @@ def str2bool(v):
# resouce # resouce
parser.add_argument( parser.add_argument("--gpu_num_per_node", type=int, default=1)
"--gpu_num_per_node", type=int, default=1) parser.add_argument("--node_num", type=int, default=1)
parser.add_argument( parser.add_argument("--node_list", type=str, default=None)
"--node_num", type=int, default=1)
parser.add_argument(
"--node_list", type=str, default=None)
# train # train
parser.add_argument( parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
"--learning_rate", type=float, default=1e-4, help="Learning rate") parser.add_argument("--weight_decay_rate", type=float, default=0.01, help="weight decay rate")
parser.add_argument( parser.add_argument("--batch_size_per_device", type=int, default=64)
"--weight_l2", type=float, default=0.01, help="weight l2 decay parameter") parser.add_argument("--iter_num", type=int, default=1144000, help="total iterations to run")
parser.add_argument( parser.add_argument("--warmup_batches", type=int, default=10000)
"--batch_size_per_device", type=int, default=24) parser.add_argument("--log_every_n_iter", type=int, default=1, help="print loss every n iteration")
parser.add_argument(
"--iter_num", type=int, default=10, help="total iterations to run")
parser.add_argument(
"--warmup_iter_num", type=int, default=10, help="total iterations to run")
parser.add_argument(
"--log_every_n_iter", type=int, default=1,
help="print loss every n iteration")
parser.add_argument("--data_dir", type=str, default=None) parser.add_argument("--data_dir", type=str, default=None)
parser.add_argument( parser.add_argument("--data_part_num", type=int, default=32, help="data part number in dataset")
"--data_part_num", type=int, default=32, parser.add_argument('--use_fp16', type=str2bool, nargs='?', const=True, help='use use fp16 or not')
help="data part number in dataset") parser.add_argument('--use_boxing_v2', type=str2bool, nargs='?', const=True,
# parser.add_argument( help='use boxing v2 or not')
# "--enable_auto_mixed_precision", type=bool, default=False)
parser.add_argument(
'--use_fp16',
type=str2bool,
nargs='?',
const=True,
help='Whether to use use fp16'
)
parser.add_argument(
'--use_boxing_v2',
type=str2bool,
nargs='?',
const=True,
help='Whether to use boxing v2'
)
# log and resore/save # log and resore/save
parser.add_argument( parser.add_argument("--loss_print_every_n_iter", type=int, default=10, required=False,
"--loss_print_every_n_iter", help="print loss every n iteration")
type=int, default=1, required=False, help="print loss every n iteration") parser.add_argument("--model_save_every_n_iter", type=int, default=10000, required=False,
parser.add_argument(
"--model_save_every_n_iter", type=int, default=200, required=False,
help="save model every n iteration",) help="save model every n iteration",)
parser.add_argument( parser.add_argument("--model_save_dir", type=str,
"--model_save_dir", type=str, default="./output/model_save-{}".format( default="./output/model_save-{}".format(str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))),
str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))),
required=False, help="model save directory") required=False, help="model save directory")
parser.add_argument( parser.add_argument("--save_last_snapshot", type=bool, default=False, required=False,
"--save_last_snapshot", type=bool, default=False, required=False,
help="save model snapshot for last iteration") help="save model snapshot for last iteration")
parser.add_argument( parser.add_argument("--model_load_dir", type=str, default=None, help="model load directory")
"--model_load_dir", type=str, default=None, required=False, parser.add_argument("--log_dir", type=str, default="./output", help="log info save directory")
help="model load directory")
parser.add_argument(
"--log_dir", type=str, default="./output", required=False,
help="log info save directory")
# bert # bert
parser.add_argument( parser.add_argument("--seq_length", type=int, default=512)
"--seq_length", type=int, default=512) parser.add_argument("--max_predictions_per_seq", type=int, default=80)
parser.add_argument( parser.add_argument("--num_hidden_layers", type=int, default=24)
"--max_predictions_per_seq", type=int, default=80) parser.add_argument("--num_attention_heads", type=int, default=16)
parser.add_argument( parser.add_argument("--max_position_embeddings", type=int, default=512)
"--num_hidden_layers", type=int, default=24) parser.add_argument("--type_vocab_size", type=int, default=2)
parser.add_argument( parser.add_argument("--vocab_size", type=int, default=30522)
"--num_attention_heads", type=int, default=16) parser.add_argument("--attention_probs_dropout_prob", type=float, default=0.1)
parser.add_argument( parser.add_argument("--hidden_dropout_prob", type=float, default=0.1)
"--max_position_embeddings", type=int, default=512) parser.add_argument("--hidden_size_per_head", type=int, default=64)
parser.add_argument(
"--type_vocab_size", type=int, default=2)
parser.add_argument(
"--vocab_size", type=int, default=30522)
parser.add_argument(
"--attention_probs_dropout_prob", type=float, default=0.1)
parser.add_argument(
"--hidden_dropout_prob", type=float, default=0.1)
parser.add_argument(
"--hidden_size_per_head", type=int, default=64)
args = parser.parse_args() args = parser.parse_args()
def BertDecoder( def BertDecoder(data_dir, batch_size, data_part_num, seq_length, max_predictions_per_seq):
data_dir, batch_size, data_part_num, seq_length, max_predictions_per_seq ofrecord = flow.data.ofrecord_reader(data_dir,
): batch_size=batch_size,
data_part_num=data_part_num,
random_shuffle = True,
shuffle_after_epoch=True)
blob_confs = {}
def _blob_conf(name, shape, dtype=flow.int32): def _blob_conf(name, shape, dtype=flow.int32):
blob_confs[name] = flow.data.OFRecordRawDecoder(ofrecord, name, shape=shape, dtype=dtype)
return flow.data.BlobConf( _blob_conf("input_ids", [seq_length])
name=name, shape=shape, dtype=dtype, codec=flow.data.RawCodec() _blob_conf("next_sentence_labels", [1])
) _blob_conf("input_mask", [seq_length])
_blob_conf("segment_ids", [seq_length])
blob_confs = [] _blob_conf("masked_lm_ids", [max_predictions_per_seq])
blob_confs.append(_blob_conf("input_ids", [seq_length])) _blob_conf("masked_lm_positions", [max_predictions_per_seq])
blob_confs.append(_blob_conf("next_sentence_labels", [1])) _blob_conf("masked_lm_weights", [max_predictions_per_seq], flow.float)
blob_confs.append(_blob_conf("input_mask", [seq_length])) return blob_confs
blob_confs.append(_blob_conf("segment_ids", [seq_length]))
blob_confs.append(_blob_conf("masked_lm_ids", [max_predictions_per_seq]))
blob_confs.append(_blob_conf(
"masked_lm_positions", [max_predictions_per_seq]))
blob_confs.append(
_blob_conf("masked_lm_weights", [max_predictions_per_seq], flow.float)
)
return flow.data.decode_ofrecord(
data_dir,
blob_confs,
batch_size=batch_size,
name="decode",
data_part_num=data_part_num,
)
def BuildPreTrainNet( def BuildPreTrainNet(
...@@ -156,18 +104,16 @@ def BuildPreTrainNet( ...@@ -156,18 +104,16 @@ def BuildPreTrainNet(
hidden_size = 64 * num_attention_heads # , H = 64, size per head hidden_size = 64 * num_attention_heads # , H = 64, size per head
intermediate_size = hidden_size * 4 intermediate_size = hidden_size * 4
decoders = BertDecoder( decoders = BertDecoder(args.data_dir, batch_size, data_part_num, seq_length,
args.data_dir, batch_size, data_part_num, seq_length, max_predictions_per_seq)
max_predictions_per_seq
)
input_ids = decoders[0] input_ids = decoders["input_ids"]
next_sentence_labels = decoders[1] next_sentence_labels = decoders["next_sentence_labels"]
token_type_ids = decoders[2] input_mask = decoders["input_mask"]
input_mask = decoders[3] token_type_ids = decoders["segment_ids"]
masked_lm_ids = decoders[4] masked_lm_ids = decoders["masked_lm_ids"]
masked_lm_positions = decoders[5] masked_lm_positions = decoders["masked_lm_positions"]
masked_lm_weights = decoders[6] masked_lm_weights = decoders["masked_lm_weights"]
return PreTrain( return PreTrain(
input_ids, input_ids,
input_mask, input_mask,
...@@ -194,19 +140,27 @@ def BuildPreTrainNet( ...@@ -194,19 +140,27 @@ def BuildPreTrainNet(
_BERT_MODEL_UPDATE_CONF = dict( _BERT_MODEL_UPDATE_CONF = dict(
learning_rate_decay=dict( learning_rate_decay=dict(
polynomial_conf=dict(decay_batches=100000, end_learning_rate=0.0,) polynomial_conf=dict(
decay_batches=args.iter_num,
end_learning_rate=0.0,
)
),
warmup_conf=dict(
linear_conf=dict(warmup_batches=args.warmup_batches, start_multiplier=0,)
), ),
warmup_conf=dict(linear_conf=dict(
warmup_batches=1000, start_multiplier=0,)),
clip_conf=dict(clip_by_global_norm=dict(clip_norm=1.0,)), clip_conf=dict(clip_by_global_norm=dict(clip_norm=1.0,)),
adam_conf=dict(epsilon=1e-6), adam_conf=dict(epsilon=1e-6),
weight_decay_conf=dict(
weight_decay_rate=args.weight_decay_rate,
excludes=dict(pattern=["bias", "LayerNorm", "layer_norm"]),
),
) )
config = flow.function_config() config = flow.function_config()
config.default_data_type(flow.float) config.default_data_type(flow.float)
config.default_distribute_strategy(flow.distribute.consistent_strategy())
config.train.primary_lr(args.learning_rate) config.train.primary_lr(args.learning_rate)
config.train.model_update_conf(_BERT_MODEL_UPDATE_CONF) config.train.model_update_conf(_BERT_MODEL_UPDATE_CONF)
# config.train.weight_l2(args.weight_l2) ??
if args.use_fp16: if args.use_fp16:
config.enable_auto_mixed_precision(True) config.enable_auto_mixed_precision(True)
...@@ -214,7 +168,7 @@ if args.use_boxing_v2: ...@@ -214,7 +168,7 @@ if args.use_boxing_v2:
config.use_boxing_v2(True) config.use_boxing_v2(True)
@flow.function(config) @flow.global_function(config)
def PretrainJob(): def PretrainJob():
total_device_num = args.node_num * args.gpu_num_per_node total_device_num = args.node_num * args.gpu_num_per_node
batch_size = total_device_num * args.batch_size_per_device batch_size = total_device_num * args.batch_size_per_device
...@@ -256,8 +210,6 @@ def main(): ...@@ -256,8 +210,6 @@ def main():
flow.config.collective_boxing.nccl_fusion_threshold_mb(8) flow.config.collective_boxing.nccl_fusion_threshold_mb(8)
flow.config.collective_boxing.nccl_fusion_all_reduce_use_buffer(False) flow.config.collective_boxing.nccl_fusion_all_reduce_use_buffer(False)
# if args.enable_auto_mixed_precision:
# flow.config.enable_auto_mixed_precision()
if args.node_num > 1: if args.node_num > 1:
nodes = [] nodes = []
...@@ -282,11 +234,10 @@ def main(): ...@@ -282,11 +234,10 @@ def main():
) )
speedometer = benchmark_util.BERTSpeedometer() speedometer = benchmark_util.BERTSpeedometer()
for step in range(args.warmup_iter_num + args.iter_num): for step in range(args.iter_num):
cb = speedometer.speedometer_cb( cb = speedometer.speedometer_cb(
step, step,
total_batch_size, total_batch_size,
args.warmup_iter_num,
args.iter_num, args.iter_num,
args.loss_print_every_n_iter, args.loss_print_every_n_iter,
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册