未验证 提交 996694b5 编写于 作者: M Minato 提交者: GitHub

add ips and fps metric to wandb callback (#7542)

上级 f8e21c8f
...@@ -339,6 +339,7 @@ class WandbCallback(Callback): ...@@ -339,6 +339,7 @@ class WandbCallback(Callback):
self.run.define_metric("eval/*", step_metric="epoch") self.run.define_metric("eval/*", step_metric="epoch")
self.best_ap = -1000. self.best_ap = -1000.
self.fps = []
@property @property
def run(self): def run(self):
...@@ -360,6 +361,7 @@ class WandbCallback(Callback): ...@@ -360,6 +361,7 @@ class WandbCallback(Callback):
last_epoch, last_epoch,
ema_model=None, ema_model=None,
ap=None, ap=None,
fps=None,
tags=None): tags=None):
if dist.get_world_size() < 2 or dist.get_rank() == 0: if dist.get_world_size() < 2 or dist.get_rank() == 0:
model_path = os.path.join(save_dir, save_name) model_path = os.path.join(save_dir, save_name)
...@@ -367,6 +369,10 @@ class WandbCallback(Callback): ...@@ -367,6 +369,10 @@ class WandbCallback(Callback):
metadata["last_epoch"] = last_epoch metadata["last_epoch"] = last_epoch
if ap: if ap:
metadata["ap"] = ap metadata["ap"] = ap
if fps:
metadata["fps"] = fps
if ema_model is None: if ema_model is None:
ema_artifact = self.wandb.Artifact( ema_artifact = self.wandb.Artifact(
name="ema_model-{}".format(self.run.id), name="ema_model-{}".format(self.run.id),
...@@ -398,7 +404,25 @@ class WandbCallback(Callback): ...@@ -398,7 +404,25 @@ class WandbCallback(Callback):
training_status = status['training_staus'].get() training_status = status['training_staus'].get()
for k, v in training_status.items(): for k, v in training_status.items():
training_status[k] = float(v) training_status[k] = float(v)
# calculate ips, data_cost, batch_cost
batch_time = status['batch_time']
data_time = status['data_time']
batch_size = self.model.cfg['{}Reader'.format(mode.capitalize(
))]['batch_size']
ips = float(batch_size) / float(batch_time.avg)
data_cost = float(data_time.avg)
batch_cost = f
metrics = {"train/" + k: v for k, v in training_status.items()} metrics = {"train/" + k: v for k, v in training_status.items()}
metrics["train/ips"] = ips
metrics["train/data_cost"] = data_cost
metrics["train/batch_cost"] = batch_cost
self.fps.append(ips)
self.run.log(metrics) self.run.log(metrics)
def on_epoch_end(self, status): def on_epoch_end(self, status):
...@@ -407,6 +431,9 @@ class WandbCallback(Callback): ...@@ -407,6 +431,9 @@ class WandbCallback(Callback):
save_name = None save_name = None
if dist.get_world_size() < 2 or dist.get_rank() == 0: if dist.get_world_size() < 2 or dist.get_rank() == 0:
if mode == 'train': if mode == 'train':
fps = sum(self.fps) / len(self.fps)
self.fps = []
end_epoch = self.model.cfg.epoch end_epoch = self.model.cfg.epoch
if ( if (
epoch_id + 1 epoch_id + 1
...@@ -420,13 +447,21 @@ class WandbCallback(Callback): ...@@ -420,13 +447,21 @@ class WandbCallback(Callback):
save_name, save_name,
epoch_id + 1, epoch_id + 1,
self.model.use_ema, self.model.use_ema,
fps=fps,
tags=tags) tags=tags)
if mode == 'eval': if mode == 'eval':
sample_num = status['sample_num']
cost_time = status['cost_time']
fps = sample_num / cost_time
merged_dict = {} merged_dict = {}
for metric in self.model._metrics: for metric in self.model._metrics:
for key, map_value in metric.get_results().items(): for key, map_value in metric.get_results().items():
merged_dict["eval/{}-mAP".format(key)] = map_value[0] merged_dict["eval/{}-mAP".format(key)] = map_value[0]
merged_dict["epoch"] = status["epoch_id"] merged_dict["epoch"] = status["epoch_id"]
merged_dict["eval/fps"] = sample_num / cost_time
self.run.log(merged_dict) self.run.log(merged_dict)
if 'save_best_model' in status and status['save_best_model']: if 'save_best_model' in status and status['save_best_model']:
...@@ -457,6 +492,7 @@ class WandbCallback(Callback): ...@@ -457,6 +492,7 @@ class WandbCallback(Callback):
last_epoch=epoch_id + 1, last_epoch=epoch_id + 1,
ema_model=self.model.use_ema, ema_model=self.model.use_ema,
ap=abs(self.best_ap), ap=abs(self.best_ap),
fps=fps,
tags=tags) tags=tags)
def on_train_end(self, status): def on_train_end(self, status):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册