提交 adb76834 编写于 作者: S ShawnXuan

merge dev_python_py3

上级 54a99d60
......@@ -10,6 +10,10 @@ class StopWatch:
self.start_time = time.time()
self.last_split = self.start_time
def set_start(self, val):
self.start_time = val
self.last_split = self.start_time
def split(self):
now = time.time()
duration = now - self.last_split
......@@ -29,89 +33,56 @@ class CNNSpeedometer:
self.throughoutput_list = []
def speedometer_cb(
self, step, total_batch_size, warmup_num, iter_num, loss_print_every_n_iter
self,
step,
start_time,
total_batch_size,
skip_iter_num,
iter_num,
loss_print_every_n_iter,
):
def callback(train_loss):
if step < warmup_num:
print(
"Runing warm up for {}/{} iterations.".format(step + 1, warmup_num)
)
if (step + 1) >= warmup_num:
self.watch.start()
print("Start trainning.")
else:
train_step = step - warmup_num
assert skip_iter_num >= 0
if skip_iter_num == 0 and step == 0:
self.watch.set_start(start_time)
print("Start trainning without any skipping iteration.")
if (train_step + 1) % loss_print_every_n_iter == 0:
loss = train_loss.mean()
duration = self.watch.split() / loss_print_every_n_iter
images_per_sec = (
total_batch_size / duration
)
if step < skip_iter_num:
if step == 0:
print(
"iter {}, loss: {:.3f}, speed: {:.3f}(sec/batch), {:.3f}(images/sec)".format(
train_step, loss, duration, images_per_sec
"Skipping {} iterations for benchmark purpose.".format(
skip_iter_num
)
)
self.throughoutput_list.append(images_per_sec)
if (train_step + 1) == iter_num:
self.watch.stop()
print("-".ljust(66, "-"))
print("average speed: {:.3f}(images/sec)".format(np.mean(self.throughoutput_list)))
print("-".ljust(66, "-"))
return callback
class BERTSpeedometer:
def __init__(self):
self.watch = StopWatch()
def speedometer_cb(
self, step, total_batch_size, warmup_num, iter_num, loss_print_every_n_iter
):
def callback(train_loss):
if step < warmup_num:
print(
"Runing warm up for {}/{} iterations.".format(step + 1, warmup_num)
)
if (step + 1) == warmup_num:
if (step + 1) == skip_iter_num:
self.watch.start()
print("Start trainning.")
else:
train_step = step - warmup_num
train_step = step - skip_iter_num
if (train_step + 1) % loss_print_every_n_iter == 0:
total_loss = train_loss[0].mean()
mlm_loss = train_loss[1].mean()
nsp_loss = train_loss[2].mean()
loss = train_loss.mean()
duration = self.watch.split()
sentences_per_sec = (
total_batch_size * loss_print_every_n_iter / duration
avg_elapse_time_per_iter = (
self.watch.split() / loss_print_every_n_iter
)
samples_per_sec = total_batch_size / avg_elapse_time_per_iter
print(
"iter {}, total_loss: {:.3f}, mlm_loss: {:.3f}, nsp_loss: {:.3f}, speed: {:.3f}(sec/batch), {:.3f}(sentences/sec)".format(
train_step,
total_loss,
mlm_loss,
nsp_loss,
duration,
sentences_per_sec,
"iter {}, loss: {:.3f}, speed: {:.3f}(sec/batch), {:.3f}(images/sec)".format(
train_step, loss, avg_elapse_time_per_iter, samples_per_sec
)
)
self.throughoutput_list.append(samples_per_sec)
if (train_step + 1) == iter_num:
self.watch.stop()
totoal_duration = self.watch.duration()
avg_sentences_per_sec = (
total_batch_size * iter_num / totoal_duration
)
avg_samples_per_sec = total_batch_size * iter_num / totoal_duration
print("-".ljust(66, "-"))
print(
"average speed: {:.3f}(sentences/sec)".format(
avg_sentences_per_sec
"average speed: {:.3f}(images/sec), new_cal_method: {:.3f}(images/sec)".format(
avg_samples_per_sec, np.mean(self.throughoutput_list)
)
)
print("-".ljust(66, "-"))
......
......@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import os
import time
import argparse
from datetime import datetime
......@@ -48,11 +49,11 @@ parser.add_argument(
"--iter_num", type=int, default=10, required=False, help="total iterations to run"
)
parser.add_argument(
"--warmup_iter_num",
"--skip_iter_num",
type=int,
default=0,
required=False,
help="total iterations to run",
help="number of skipping iterations for benchmark purpose.",
)
parser.add_argument(
"--data_dir", type=str, default=None, required=False, help="dataset directory"
......@@ -137,23 +138,23 @@ optimizer_dict = {
# "warmup_conf": {"linear_conf": {"warmup_batches":10000, "start_multiplier":0}},
func_config = flow.FunctionConfig()
func_config.default_distribute_strategy(flow.distribute.consistent_strategy())
func_config.train.primary_lr(args.learning_rate)
func_config.default_data_type(flow.float)
func_config.train.model_update_conf(optimizer_dict[args.optimizer])
func_config.disable_all_reduce_sequence(True)
func_config.all_reduce_group_min_mbyte(8)
func_config.all_reduce_group_num(128)
@flow.function
def TrainNet():
flow.config.train.primary_lr(args.learning_rate)
flow.config.disable_all_reduce_sequence(True)
# flow.config.all_reduce_lazy_ratio(0)
if args.weight_l2:
func_config.train.weight_l2(args.weight_l2)
# flow.config.enable_nccl_hierarchical_all_reduce(True)
# flow.config.cudnn_buf_limit_mbyte(2048)
# flow.config.concurrency_width(2)
flow.config.all_reduce_group_num(128)
flow.config.all_reduce_group_min_mbyte(8)
flow.config.gpu_device_num(args.gpu_num_per_node)
flow.config.train.model_update_conf(optimizer_dict[args.optimizer])
if args.weight_l2:
flow.config.train.weight_l2(args.weight_l2)
@flow.function(func_config)
def TrainNet():
total_device_num = args.node_num * args.gpu_num_per_node
batch_size = total_device_num * args.batch_size_per_device
......@@ -189,12 +190,8 @@ def main():
print("{} = {}".format(arg, getattr(args, arg)))
print("-".ljust(66, "-"))
print("Time stamp: {}".format(str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))))
flow.config.default_data_type(flow.float)
flow.config.gpu_device_num(args.gpu_num_per_node)
flow.env.grpc_use_no_signal()
flow.env.log_dir(args.log_dir)
# flow.config.enable_inplace(False)
# flow.config.ctrl_port(12140)
if args.node_num > 1:
nodes = []
......@@ -218,12 +215,14 @@ def main():
args.node_num * args.gpu_num_per_node * args.batch_size_per_device
)
speedometer = benchmark_util.CNNSpeedometer()
start_time = time.time()
for step in range(args.warmup_iter_num + args.iter_num):
for step in range(args.skip_iter_num + args.iter_num):
cb = speedometer.speedometer_cb(
step,
start_time,
total_batch_size,
args.warmup_iter_num,
args.skip_iter_num,
args.iter_num,
args.loss_print_every_n_iter,
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册