train.py 1.9 KB
Newer Older
Z
zhoujun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
import os
import sys
import pathlib
# Absolute path of this script file itself (the file, not its directory).
__dir__ = pathlib.Path(os.path.abspath(__file__))
# NOTE(review): this appends the script's *file* path to sys.path, not a
# directory; presumably the containing directory was intended -- confirm
# before changing, since it may affect how the imports below resolve.
sys.path.append(str(__dir__))
# Also add the parent of the directory that contains this file; presumably
# that is where the project packages imported below live -- TODO confirm.
sys.path.append(str(__dir__.parent.parent))

import paddle
import paddle.distributed as dist
from utils import Config, ArgsParser


def init_args():
    """Parse the command-line arguments with the project's ``ArgsParser``.

    Returns:
        Whatever ``ArgsParser().parse_args()`` returns. The script entry
        point reads ``config_file``, ``opt`` and ``profiler_options`` from it.
    """
    return ArgsParser().parse_args()


def main(config, profiler_options):
    """Build every training component from ``config`` and run the trainer.

    Args:
        config: Mapping with the ``dataset``, ``loss``, ``arch``,
            ``post_processing`` and ``metric`` sections read below. It is
            modified in place: ``config['distributed']`` and
            ``config['arch']['backbone']['in_channels']`` are written.
        profiler_options: Forwarded unchanged to ``Trainer``.
    """
    # Project modules are imported here, at call time, not at module import.
    from models import build_model, build_loss
    from data_loader import get_dataloader
    from trainer import Trainer
    from post_processing import get_post_processing
    from utils import get_metric

    # More than one CUDA device reported -> initialise the parallel
    # environment and record that the run is distributed.
    multi_gpu = paddle.device.cuda.device_count() > 1
    if multi_gpu:
        dist.init_parallel_env()
    config['distributed'] = multi_gpu

    dataset_cfg = config['dataset']
    train_loader = get_dataloader(dataset_cfg['train'], config['distributed'])
    assert train_loader is not None

    # The validation loader is optional and is always requested with the
    # distributed flag set to False.
    validate_loader = None
    if 'validate' in dataset_cfg:
        validate_loader = get_dataloader(dataset_cfg['validate'], False)

    criterion = build_loss(config['loss'])

    # The backbone gets 1 input channel for 'GRAY' images and 3 otherwise.
    img_mode = dataset_cfg['train']['dataset']['args']['img_mode']
    if img_mode == 'GRAY':
        in_channels = 1
    else:
        in_channels = 3
    config['arch']['backbone']['in_channels'] = in_channels

    model = build_model(config['arch'])
    # NOTE: the original source marks this spot with "set @to_static for
    # benchmark, skip this by default."; nothing is applied to the model here.
    post_p = get_post_processing(config['post_processing'])
    metric = get_metric(config['metric'])

    Trainer(
        config=config,
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        post_process=post_p,
        metric_cls=metric,
        validate_loader=validate_loader,
        profiler_options=profiler_options,
    ).train()


if __name__ == '__main__':
    # Script entry point: parse the CLI arguments, load the config file,
    # merge ``args.opt`` into it, then hand the result to ``main``.
    args = init_args()
    # Validate the path explicitly instead of with ``assert``: assertions are
    # removed when Python runs with -O, and a bare assert gives no hint about
    # which path was missing.
    if not os.path.exists(args.config_file):
        raise FileNotFoundError(
            'config file not found: {}'.format(args.config_file))
    config = Config(args.config_file)
    config.merge_dict(args.opt)
    main(config.cfg, args.profiler_options)