diff --git a/tools/program.py b/tools/program.py index a1a7809048ab71b367ae489f5af61bd8a3022e93..511ee9dd1f12273eb773b6f2e29a3955940721ee 100755 --- a/tools/program.py +++ b/tools/program.py @@ -346,7 +346,10 @@ def train(config, lr_scheduler.step() # logger and visualdl - stats = {k: v.numpy().mean() for k, v in loss.items()} + stats = { + k: float(v) if v.shape == [] else v.numpy().mean() + for k, v in loss.items() + } stats['lr'] = lr train_stats.update(stats)