未验证 提交 3fad507e 编写于 作者: W wanghuancoder 提交者: GitHub

revert PR4893 and use Xreki‘s Code (#4902)

* fix ptb_dy time print for benchmark, test=develop

* use yiqun(Xreki)'s PR, test=develop

* add empty line, test=develop

* add empty line, test=develop
上级 16c1da5f
......@@ -34,22 +34,22 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TimeCostAverage(object):
class TimeAverager(object):
def __init__(self):
self.reset()
def reset(self):
self.cnt = 0
self.total_time = 0
self._cnt = 0
self._total_time = 0
def record(self, usetime):
self.cnt += 1
self.total_time += usetime
self._cnt += 1
self._total_time += usetime
def get_average(self):
if self.cnt == 0:
if self._cnt == 0:
return 0
return self.total_time / self.cnt
return self._total_time / self._cnt
def build_program(is_train, main_prog, startup_prog, args):
......@@ -245,16 +245,15 @@ def train(args):
compiled_train_prog = best_strategy_compiled(args, train_prog,
train_fetch_vars[0], exe)
batch_cost_avg = TimeCostAverage()
reader_cost_avg = TimeCostAverage()
#NOTE: this for benchmark
total_batch_num = 0
batch_cost_averager = TimeAverager()
reader_cost_averager = TimeAverager()
for pass_id in range(args.num_epochs):
if num_trainers > 1 and not args.use_dali:
imagenet_reader.set_shuffle_seed(pass_id + (
args.random_seed if args.random_seed else 0))
train_batch_id = 0
train_batch_time_record = []
train_batch_metrics_record = []
......@@ -264,43 +263,47 @@ def train(args):
if args.validate:
test_iter = test_data_loader()
t1 = time.time()
batch_start = time.time()
for batch in train_iter:
#NOTE: this is for benchmark
if args.max_iter and total_batch_num == args.max_iter:
return
t2 = time.time()
reader_cost = t2 - t1
reader_cost_avg.record(reader_cost)
reader_cost_averager.record(time.time() - batch_start)
train_batch_metrics = exe.run(compiled_train_prog,
feed=batch,
fetch_list=train_fetch_list)
t3 = time.time()
train_batch_elapse = t3 - t1
train_batch_time_record.append(train_batch_elapse)
batch_cost_avg.record(train_batch_elapse)
train_batch_metrics_avg = np.mean(
np.array(train_batch_metrics), axis=1)
train_batch_metrics_record.append(train_batch_metrics_avg)
# Record the time for ce and benchmark
train_batch_elapse = time.time() - batch_start
train_batch_time_record.append(train_batch_elapse)
batch_cost_averager.record(train_batch_elapse)
if trainer_id == 0:
ips = float(args.batch_size) / batch_cost_averager.get_average()
print_info(
"batch",
train_batch_metrics_avg,
batch_cost_avg.get_average(),
batch_cost_averager.get_average(),
pass_id,
train_batch_id,
args.print_step,
reader_cost=reader_cost_avg.get_average(),
ips=args.batch_size / batch_cost_avg.get_average())
reader_cost=reader_cost_averager.get_average(),
ips=ips)
sys.stdout.flush()
if train_batch_id % args.print_step == 0:
reader_cost_avg.reset()
batch_cost_avg.reset()
batch_cost_averager.reset()
reader_cost_averager.reset()
train_batch_id += 1
t1 = time.time()
#NOTE: this for benchmark profiler
total_batch_num = total_batch_num + 1
batch_start = time.time()
#NOTE: this for benchmark profiler
if args.is_profiler and pass_id == 0 and train_batch_id == args.print_step:
profiler.start_profiler("All")
elif args.is_profiler and pass_id == 0 and train_batch_id == args.print_step + 5:
......
......@@ -421,8 +421,8 @@ def print_info(info_mode,
print_step=1,
device_num=1,
class_dim=5,
reader_cost=0.0,
ips=0.0):
reader_cost=None,
ips=None):
"""print function
Args:
......@@ -435,34 +435,35 @@ def print_info(info_mode,
"""
#XXX: Use specific name to choose pattern, not the length of metrics.
if info_mode == "batch":
time_info_str = "batch_cost %.5f sec" % time_info
if reader_cost:
time_info_str += ", reader_cost %.5f sec" % reader_cost
if ips:
time_info_str += ", ips %.5f images/sec" % ips
if batch_id % print_step == 0:
#if isinstance(metrics,np.ndarray):
# train and mixup output
if len(metrics) == 2:
loss, lr = metrics
logger.info(
"[Pass {0}, train batch {1}] \tloss {2}, lr {3}, reader_cost: {5}, batch_cost: {4}, ips: {6}".
"[Pass {0}, train batch {1}] \tloss {2}, lr {3}, {4}".
format(pass_id, batch_id, "%.5f" % loss, "%.5f" % lr,
"%2.4f sec" % time_info, "%.5f sec" % reader_cost,
"%.5f images/sec" % ips))
time_info_str))
# train and no mixup output
elif len(metrics) == 4:
loss, acc1, acc5, lr = metrics
logger.info(
"[Pass {0}, train batch {1}] \tloss {2}, acc1 {3}, acc{7} {4}, lr {5}, reader_cost: {8}, batch_cost: {6}, ips: {9}".
"[Pass {0}, train batch {1}] \tloss {2}, acc1 {3}, acc{7} {4}, lr {5}, {6}".
format(pass_id, batch_id, "%.5f" % loss, "%.5f" % acc1,
"%.5f" % acc5, "%.5f" % lr, "%2.4f sec" % time_info,
min(class_dim, 5), "%.5f sec" % reader_cost,
"%.5f images/sec" % ips))
"%.5f" % acc5, "%.5f" % lr, time_info_str,
min(class_dim, 5)))
# test output
elif len(metrics) == 3:
loss, acc1, acc5 = metrics
logger.info(
"[Pass {0}, test batch {1}] \tloss {2}, acc1 {3}, acc{6} {4}, reader_cost: {7}, batch_cost: {5}, ips: {8}".
"[Pass {0}, test batch {1}] \tloss {2}, acc1 {3}, acc{6} {4}, {5}".
format(pass_id, batch_id, "%.5f" % loss, "%.5f" % acc1,
"%.5f" % acc5, "%2.4f sec" % time_info,
min(class_dim, 5), "%.5f sec" % reader_cost,
"%.5f images/sec" % ips))
"%.5f" % acc5, time_info_str, min(class_dim, 5)))
else:
raise Exception(
"length of metrics {} is not implemented, It maybe caused by wrong format of build_program_output".
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册