提交 4592da8e 编写于 作者: S ShawnXuan

support save summary

上级 57fc81e4
......@@ -12,7 +12,7 @@ parser = configs.get_parser()
args = parser.parse_args()
configs.print_args(args)
from util import Snapshot, Summary, InitNodes, StopWatch, Metric
from util import Snapshot, Summary, InitNodes, Metric
#from dali_util import get_rec_iter
import ofrecord_util
from job_function_util import get_train_config, get_val_config
......@@ -29,8 +29,6 @@ val_batch_size = total_device_num * args.val_batch_size_per_device
epoch_size = math.ceil(args.num_examples / train_batch_size)
num_val_steps = int(args.num_val_examples / val_batch_size)
#summary = Summary(args.log_dir, args)
timer = StopWatch()
model_dict = {
"resnet50": resnet_model.resnet50,
......@@ -83,10 +81,12 @@ def main():
flow.env.grpc_use_no_signal()
flow.env.log_dir(args.log_dir)
summary = Summary(args.log_dir, args)
snapshot = Snapshot(args.model_save_dir, args.model_load_dir)
for epoch in range(args.num_epochs):
for epoch in range(3):#args.num_epochs):
metric = Metric(desc='train', calculate_batches=args.loss_print_every_n_iter,
summary=summary, save_summary_steps=epoch_size,
batch_size=train_batch_size, loss_key='loss')
for i in range(epoch_size):
TrainNet().async_get(metric.metric_cb(epoch, i))
......@@ -94,8 +94,8 @@ def main():
# break
#break
if args.val_data_dir:
metric = Metric(desc='validataion', calculate_batches=num_val_steps,
batch_size=val_batch_size)
metric = Metric(desc='validataion', calculate_batches=num_val_steps, summary=summary,
save_summary_steps=num_val_steps, batch_size=val_batch_size)
for i in range(num_val_steps):
InferenceNet().async_get(metric.metric_cb(epoch, i))
......
......@@ -22,7 +22,7 @@ def InitNodes(args):
flow.env.machine(nodes)
class Snapshot:
class Snapshot(object):
def __init__(self, model_save_dir, model_load_dir):
self._model_save_dir = model_save_dir
self._check_point = flow.train.CheckPoint()
......@@ -42,25 +42,24 @@ class Snapshot:
self._check_point.save(snapshot_save_path)
class Summary():
class Summary(object):
def __init__(self, log_dir, config):
self._log_dir = log_dir
self._metrics = pd.DataFrame({"iter": 0, "legend": "cfg", "note": str(config)}, index=[0])
self._metrics = pd.DataFrame({"epoch":0, "iter": 0, "legend": "cfg", "note": str(config)}, index=[0])
def scalar(self, legend, value, step=-1):
def scalar(self, legend, value, epoch, step=-1):
# TODO: support rank(which device/gpu)
df = pd.DataFrame(
{"iter": step, "legend": legend, "value": value, "rank": 0, "time": time.time()},
{"epoch": epoch, "iter": step, "legend": legend, "value": value, "rank": 0},
index=[0])
self._metrics = pd.concat([self._metrics, df], axis=0, sort=False)
def save(self):
save_path = os.path.join(self._log_dir, "summary.csv")
self._metrics.to_csv(save_path, index=False)
print("saved: {}".format(save_path))
class StopWatch:
class StopWatch(object):
def __init__(self):
pass
......@@ -85,12 +84,16 @@ def match_top_k(predictions, labels, top_k=1):
max_k_preds = predictions.argsort(axis=1)[:, -top_k:][:, ::-1]
match_array = np.logical_or.reduce(max_k_preds==labels.reshape((-1, 1)), axis=1)
num_matched = match_array.sum()
#topk_acc_score = match_array.sum().astype(float) / match_array.shape[0]
return num_matched, match_array.shape[0]
class Metric():
def __init__(self, desc='train', calculate_batches=-1, batch_size=256, top_k=5,
prediction_key='predictions', label_key='labels', loss_key=None):
class Metric(object):
def __init__(self, summary=None, save_summary_steps=-1, desc='train', calculate_batches=-1,
batch_size=256, top_k=5, prediction_key='predictions', label_key='labels',
loss_key=None):
self.summary = summary
self.save_summary = isinstance(self.summary, Summary)
self.save_summary_steps = save_summary_steps
self.desc = desc
self.calculate_batches = calculate_batches
self.top_k = top_k
......@@ -105,7 +108,7 @@ class Metric():
self.timer = StopWatch()
self.timer.start()
self._clear()
def _clear(self):
self.top_1_num_matched = 0
self.top_k_num_matched = 0
......@@ -122,21 +125,31 @@ class Metric():
outputs[self.label_key], self.top_k)
self.top_k_num_matched += num_matched
if (step+1) % self.calculate_batches == 0:
if (step + 1) % self.calculate_batches == 0:
throughput = self.num_samples / self.timer.split()
top_1_accuracy = self.top_1_num_matched / self.num_samples
top_k_accuracy = self.top_k_num_matched / self.num_samples
if self.loss_key:
loss = outputs[self.loss_key].mean()
print(self.fmt.format(self.desc, epoch, step, loss, top_1_accuracy,
print(self.fmt.format(self.desc, epoch, step + 1, loss, top_1_accuracy,
top_k_accuracy, throughput))
#summary.scalar('loss', loss, step)
if self.save_summary:
self.summary.scalar(self.desc+"_" + self.loss_key, loss, epoch, step)
else:
print(self.fmt.format(self.desc, epoch, step, top_1_accuracy, top_k_accuracy,
throughput))
print(self.fmt.format(self.desc, epoch, step + 1, top_1_accuracy,
top_k_accuracy, throughput))
#summary.scalar('train_accuracy', accuracy, step)
self._clear()
if self.save_summary:
self.summary.scalar(self.desc + "_throughput", throughput, epoch, step)
self.summary.scalar(self.desc + "_top_1", top_1_accuracy, epoch, step)
self.summary.scalar(self.desc + "_top_{}".format(self.top_k), top_k_accuracy,
epoch, step)
if self.save_summary:
if (step + 1) % self.save_summary_steps == 0:
self.summary.save()
return callback
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册