未验证 提交 5b895d74 编写于 作者: F Feng Wang 提交者: GitHub

feat(exp): get_trainer method, add pre-commit (#1263)

上级 68408b40
repos:
- repo: https://github.com/pycqa/flake8
rev: 3.8.3
hooks:
- id: flake8
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.1.0
hooks:
- id: check-added-large-files
- id: check-docstring-first
- id: check-executables-have-shebangs
- id: check-json
- id: check-yaml
args: ["--unsafe"]
- id: debug-statements
- id: end-of-file-fixer
- id: requirements-txt-fixer
- id: trailing-whitespace
- repo: https://github.com/jorisroovers/gitlint
rev: v0.15.1
hooks:
- id: gitlint
- repo: https://github.com/pycqa/isort
rev: 4.3.21
hooks:
- id: isort
- repo: https://github.com/PyCQA/autoflake
rev: v1.4
hooks:
- id: autoflake
name: Remove unused variables and imports
entry: autoflake
language: python
args:
[
"--in-place",
"--remove-all-unused-imports",
"--remove-unused-variables",
"--expand-star-imports",
"--ignore-init-module-imports",
]
files: \.py$
......@@ -10,8 +10,8 @@ from loguru import logger
import torch
import torch.backends.cudnn as cudnn
from yolox.core import Trainer, launch
from yolox.exp import get_exp
from yolox.core import launch
from yolox.exp import Exp, get_exp
from yolox.utils import configure_module, configure_nccl, configure_omp, get_num_devices
......@@ -97,7 +97,7 @@ def make_parser():
@logger.catch
def main(exp, args):
def main(exp: Exp, args):
if exp.seed is not None:
random.seed(exp.seed)
torch.manual_seed(exp.seed)
......@@ -113,7 +113,7 @@ def main(exp, args):
configure_omp()
cudnn.benchmark = True
trainer = Trainer(exp, args)
trainer = exp.get_trainer(args)
trainer.train()
......
......@@ -12,6 +12,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from yolox.data import DataPrefetcher
from yolox.exp import Exp
from yolox.utils import (
MeterBuffer,
ModelEMA,
......@@ -33,7 +34,7 @@ from yolox.utils import (
class Trainer:
def __init__(self, exp, args):
def __init__(self, exp: Exp, args):
# init function only defines some basic attr, other attrs like model, optimizer are built in
# before_train methods.
self.exp = exp
......
......@@ -127,9 +127,7 @@ class Exp(BaseExp):
self.model.train()
return self.model
def get_data_loader(
self, batch_size, is_distributed, no_aug=False, cache_img=False
):
def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False):
from yolox.data import (
COCODataset,
TrainTransform,
......@@ -314,5 +312,11 @@ class Exp(BaseExp):
)
return evaluator
def get_trainer(self, args):
from yolox.core import Trainer
trainer = Trainer(self, args)
# NOTE: trainer shouldn't be an attribute of exp object
return trainer
def eval(self, model, evaluator, is_distributed, half=False):
return evaluator.evaluate(model, is_distributed, half)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册