未验证 提交 bb9185c0 编写于 作者: F Feng Wang 提交者: GitHub

fix(utils,exp): logger compat issue and exp check (#1618)

fix(utils,exp): logger compat issue and exp check (#1618)
上级 618fd8c0
......@@ -11,7 +11,7 @@ import torch
import torch.backends.cudnn as cudnn
from yolox.core import launch
from yolox.exp import Exp, get_exp
from yolox.exp import Exp, check_exp_value, get_exp
from yolox.utils import configure_module, configure_nccl, configure_omp, get_num_devices
......@@ -123,6 +123,7 @@ if __name__ == "__main__":
args = make_parser().parse_args()
exp = get_exp(args.exp_file, args.name)
exp.merge(args.opts)
check_exp_value(exp)
if not args.experiment_name:
args.experiment_name = exp.exp_name
......
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
from .base_exp import BaseExp
from .build import get_exp
from .yolox_base import Exp
from .yolox_base import Exp, check_exp_value
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
import ast
......@@ -66,7 +65,7 @@ class BaseExp(metaclass=ABCMeta):
return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")
def merge(self, cfg_list):
assert len(cfg_list) % 2 == 0
assert len(cfg_list) % 2 == 0, f"length must be even, check value here: {cfg_list}"
for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
# only update value with same key
if hasattr(self, k):
......@@ -74,7 +73,7 @@ class BaseExp(metaclass=ABCMeta):
src_type = type(src_value)
# pre-process input if source type is list or tuple
if isinstance(src_value, List) or isinstance(src_value, Tuple):
if isinstance(src_value, (List, Tuple)):
v = v.strip("[]()")
v = [t.strip() for t in v.split(",")]
......
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
import os
......@@ -11,6 +10,8 @@ import torch.nn as nn
from .base_exp import BaseExp
__all__ = ["Exp", "check_exp_value"]
class Exp(BaseExp):
def __init__(self):
......@@ -350,3 +351,8 @@ class Exp(BaseExp):
def eval(self, model, evaluator, is_distributed, half=False, return_outputs=False):
return evaluator.evaluate(model, is_distributed, half, return_outputs=return_outputs)
def check_exp_value(exp: Exp):
h, w = exp.input_size
assert h % 32 == 0 and w % 32 == 0, "input size must be multiples of 32"
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
import inspect
......@@ -58,7 +57,8 @@ class StreamToLoguru:
sys.__stdout__.write(buf)
def flush(self):
pass
# flush is related with CPR(cursor position report) in terminal
return sys.__stdout__.flush()
def isatty(self):
# when using colab, jax is installed by default and issue like
......@@ -66,7 +66,11 @@ class StreamToLoguru:
# due to missing attribute like`isatty`.
# For more details, checked the following link:
# https://github.com/google/jax/blob/10720258ea7fb5bde997dfa2f3f71135ab7a6733/jax/_src/pretty_printer.py#L54 # noqa
return True
return sys.__stdout__.isatty()
def fileno(self):
# To solve the issue when using debug tools like pdb
return sys.__stdout__.fileno()
def redirect_sys_output(log_level="INFO"):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册