提交 6ad919f5 编写于 作者: S ShawnXuan

support top k

上级 d49b924b
......@@ -12,7 +12,7 @@ parser = configs.get_parser()
args = parser.parse_args()
configs.print_args(args)
from util import Snapshot, Summary, InitNodes, StopWatch
from util import Snapshot, Summary, InitNodes, StopWatch, Metric
#from dali_util import get_rec_iter
import ofrecord_util
from job_function_util import get_train_config, get_val_config
......@@ -27,9 +27,9 @@ train_batch_size = total_device_num * args.batch_size_per_device
val_batch_size = total_device_num * args.val_batch_size_per_device
(C, H, W) = args.image_shape
epoch_size = math.ceil(args.num_examples / train_batch_size)
num_val_steps = args.num_val_examples / val_batch_size
num_val_steps = int(args.num_val_examples / val_batch_size)
summary = Summary(args.log_dir, args)
#summary = Summary(args.log_dir, args)
timer = StopWatch()
model_dict = {
......@@ -56,8 +56,8 @@ def TrainNet():
loss = flow.nn.sparse_softmax_cross_entropy_with_logits(labels, logits, name="softmax_loss")
#loss = flow.math.reduce_mean(loss)
flow.losses.add_loss(loss)
softmax = flow.nn.softmax(logits)
outputs = {"loss": loss, "softmax":softmax, "labels": labels}
predictions = flow.nn.softmax(logits)
outputs = {"loss": loss, "predictions":predictions, "labels": labels}
return outputs
......@@ -72,56 +72,11 @@ def InferenceNet():
(labels, images) = ofrecord_util.load_synthetic(args)
logits = model_dict[args.model](images)
softmax = flow.nn.softmax(logits)
outputs = {"softmax":softmax, "labels": labels}
predictions = flow.nn.softmax(logits)
outputs = {"predictions":predictions, "labels": labels}
return outputs#(softmax, labels)
def acc_acc(step, predictions):
    """Accumulate running top-1 accuracy counters (stored as attributes on `main`).

    Args:
        step: 0-based batch index within the current accumulation window;
            step 0 resets the counters before accumulating.
        predictions: dict with 'softmax' (a blob exposing .ndarray(), shape
            (N, C) — presumably per-class probabilities; confirm against the
            job function) and 'labels' (array-like of N integer class ids).
    """
    if step == 0:
        # Start a fresh accumulation window.
        main.correct = 0.0
        main.total = 0.0
    classifications = np.argmax(predictions['softmax'].ndarray(), axis=1)
    labels = predictions['labels'].reshape(-1)
    # Bug fix: the original accumulated only in an `else` branch, so the batch
    # processed at step == 0 was silently dropped from the statistics.
    main.correct += np.sum(classifications == labels)
    main.total += len(labels)
def train_callback(epoch, step):
    """Build the async-get callback that logs loss/accuracy for one train step."""
    def _on_outputs(train_outputs):
        # Fold this batch into the running accuracy counters on `main`.
        acc_acc(step, train_outputs)
        loss = train_outputs['loss'].mean()
        summary.scalar('loss', loss, step)
        # Report only once per `loss_print_every_n_iter` window.
        if (step - 1) % args.loss_print_every_n_iter != 0:
            return
        samples_per_sec = args.loss_print_every_n_iter * train_batch_size / timer.split()
        accuracy = main.correct / main.total
        print("epoch {}, iter {}, loss: {:.6f}, accuracy: {:.6f}, samples/s: {:.3f}".format(
            epoch, step - 1, loss, accuracy, samples_per_sec))
        summary.scalar('train_accuracy', accuracy, step)
        # Reset the window counters after reporting.
        main.correct = 0.0
        main.total = 0.0
    return _on_outputs
def do_predictions(epoch, predict_step, predictions):
    """Accumulate validation accuracy; report once the epoch's last step lands."""
    acc_acc(predict_step, predictions)
    if predict_step + 1 != num_val_steps:
        return
    assert main.total > 0
    top1_accuracy = main.correct / main.total
    summary.scalar('top1_accuracy', top1_accuracy, epoch)
    print("epoch {}, top 1 accuracy: {:.6f}, time: {:.2f}".format(epoch,
        top1_accuracy, timer.split()))
def predict_callback(epoch, predict_step):
    """Return an async-get callback bound to this (epoch, step) pair."""
    return lambda predictions: do_predictions(epoch, predict_step, predictions)
def main():
InitNodes(args)
......@@ -130,23 +85,22 @@ def main():
snapshot = Snapshot(args.model_save_dir, args.model_load_dir)
timer.start()
for epoch in range(args.num_epochs):
tic = time.time()
print('Starting epoch {} at {:.2f}'.format(epoch, tic))
metric = Metric(desc='train', calculate_batches=args.loss_print_every_n_iter,
batch_size=train_batch_size, loss_key='loss')
for i in range(epoch_size):
TrainNet().async_get(train_callback(epoch, i))
# if i > 30:#debug
# break
TrainNet().async_get(metric.metric_cb(epoch, i))
if i > 40:#debug
break
#break
print('epoch {} training time: {:.2f}'.format(epoch, time.time() - tic))
if args.val_data_dir:
tic = time.time()
metric = Metric(desc='validataion', calculate_batches=num_val_steps,
batch_size=val_batch_size)
for i in range(num_val_steps):
InferenceNet().async_get(predict_callback(epoch, i))
InferenceNet().async_get(metric.metric_cb(epoch, i))
summary.save()
snapshot.save('epoch_{}'.format(epoch+1))
#summary.save()
#snapshot.save('epoch_{}'.format(epoch+1))
if __name__ == "__main__":
......
......@@ -4,6 +4,7 @@ from __future__ import print_function
import os
import time
import numpy as np
import pandas as pd
from datetime import datetime
import oneflow as flow
......@@ -79,3 +80,63 @@ class StopWatch:
    def duration(self):
        # Elapsed seconds between the timestamps recorded by the enclosing
        # StopWatch; both start_time and stop_time must already be set.
        return self.stop_time - self.start_time
def match_top_k(predictions, labels, top_k=1):
    """Count rows whose true label appears among the `top_k` highest scores.

    Args:
        predictions: (N, C) array of per-class scores.
        labels: length-N array of integer class ids.
        top_k: how many of the best-scoring classes count as a hit.

    Returns:
        (num_matched, num_rows) — hit count and batch size.
    """
    order = np.argsort(predictions, axis=1)
    # Best `top_k` class indices per row, highest score first.
    topk_indices = order[:, :-top_k - 1:-1]
    hits = (topk_indices == labels.reshape((-1, 1))).any(axis=1)
    return hits.sum(), hits.shape[0]
class Metric():
    """Accumulates top-1 / top-k accuracy (and optionally loss) across async
    callbacks, printing a summary line every `calculate_batches` steps."""

    def __init__(self, desc='train', calculate_batches=-1, batch_size=256, top_k=5,
                 prediction_key='predictions', label_key='labels', loss_key=None):
        """
        Args:
            desc: tag printed at the start of each report line (e.g. 'train').
            calculate_batches: report every this many steps; with the -1
                default the `(step+1) % -1 == 0` test is always true in
                Python, so a report is printed every step.
            batch_size: accepted but not stored — apparently unused here;
                NOTE(review): confirm whether callers rely on it.
            prediction_key: key of the (N, C) score array in the outputs dict.
            label_key: key of the length-N integer label array.
            loss_key: if set, the outputs dict also carries a loss blob and
                the report line includes its mean.
        """
        self.desc = desc
        self.calculate_batches = calculate_batches
        self.top_k = top_k
        self.prediction_key = prediction_key
        self.label_key = label_key
        self.loss_key = loss_key
        # Two report formats: with or without the loss column.
        if loss_key:
            self.fmt = "{}: epoch {}, iter {}, loss: {:.6f}, top_1: {:.6f}, top_k: {:.6f}, samples/s: {:.3f}"
        else:
            self.fmt = "{}: epoch {}, iter {}, top_1: {:.6f}, top_k: {:.6f}, samples/s: {:.3f}"
        self.timer = StopWatch()
        self.timer.start()
        self._clear()

    def _clear(self):
        # Reset the counters for a fresh accumulation window.
        self.top_1_num_matched = 0
        self.top_k_num_matched = 0
        # Float so the accuracy divisions below are float divisions on py2 too.
        self.num_samples = 0.0

    def metric_cb(self, epoch, step):
        """Return the async-get callback for one (epoch, step) pair.

        The callback accumulates batch statistics and, at the end of each
        `calculate_batches` window, prints accuracy/throughput and resets.
        """
        def callback(outputs):
            # A new epoch restarts the accumulation window.
            if step == 0: self._clear()
            # Top-1 matches plus the actual batch size for this step.
            num_matched, num_samples = match_top_k(outputs[self.prediction_key],
                                                   outputs[self.label_key])
            self.top_1_num_matched += num_matched
            self.num_samples += num_samples
            # Top-k matches over the same batch (sample count already added).
            num_matched, _ = match_top_k(outputs[self.prediction_key],
                                         outputs[self.label_key], self.top_k)
            self.top_k_num_matched += num_matched
            if (step+1) % self.calculate_batches == 0:
                # Window complete: report and reset. split() returns the time
                # since the previous split, i.e. this window's duration.
                throughput = self.num_samples / self.timer.split()
                top_1_accuracy = self.top_1_num_matched / self.num_samples
                top_k_accuracy = self.top_k_num_matched / self.num_samples
                if self.loss_key:
                    loss = outputs[self.loss_key].mean()
                    print(self.fmt.format(self.desc, epoch, step, loss, top_1_accuracy,
                                          top_k_accuracy, throughput))
                else:
                    print(self.fmt.format(self.desc, epoch, step, top_1_accuracy, top_k_accuracy,
                                          throughput))
                self._clear()
        return callback
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册