# train.py
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import multiprocessing as mp
import os
import shutil
import time
import json
import numpy as np
import cv2

import megengine as mge
import megengine.data as data
import megengine.data.transform as T
import megengine.distributed as dist
import megengine.functional as F
import megengine.jit as jit
import megengine.optimizer as optim

import official.vision.keypoints.models as M
from official.vision.keypoints.transforms import (
    RandomBoxAffine,
    RandomHorizontalFlip,
    HalfBodyTransform,
    ExtendBoxes,
)
from official.vision.keypoints.dataset import COCOJoints, HeatmapCollator
from official.vision.keypoints.config import Config as cfg

# Module-level logger; used by worker() and train() below.
logger = mge.get_logger(__name__)


def str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is True for
    every non-empty string -- including ``"False"`` and ``"0"``. This converter
    understands the usual spellings instead.

    Args:
        value: raw string from the command line. A ``bool`` is returned
            unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling
            (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value)
    )


def main():
    """Parse command-line options and launch keypoint training.

    Creates the output directory
    ``<save>/<arch>_<input_shape[0]>x<input_shape[1]>`` if needed and sends the
    MegEngine log to ``log.txt`` inside it. With more than one GPU, one
    ``multiprocessing.Process`` running :func:`worker` is started per rank, and
    the parent waits for all of them. Otherwise :func:`worker` runs in-process
    as rank 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-a",
        "--arch",
        default="simplebaseline_res50",
        type=str,
        choices=[
            "simplebaseline_res50",
            "simplebaseline_res101",
            "simplebaseline_res152",
            "mspn_4stage",
        ],
    )
    parser.add_argument("-s", "--save", default="/data/models", type=str)
    parser.add_argument("--resume", default=None, type=str)
    # type=bool would turn any non-empty string (even "False") into True.
    parser.add_argument("--multi_scale_supervision", default=True, type=str2bool)
    # NOTE(review): with default=8 the device-count fallback below can never
    # trigger; it only applies if this default is changed to None.
    parser.add_argument("-n", "--ngpus", default=8, type=int)
    parser.add_argument("-w", "--workers", default=8, type=int)
    parser.add_argument("--report-freq", default=10, type=int)

    args = parser.parse_args()

    model_name = "{}_{}x{}".format(args.arch, cfg.input_shape[0], cfg.input_shape[1])
    save_dir = os.path.join(args.save, model_name)
    # exist_ok avoids the check-then-create race of an explicit exists() test.
    os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    world_size = mge.get_device_count("gpu") if args.ngpus is None else args.ngpus

    if world_size > 1:
        # Scale the learning rate linearly with the number of GPUs.
        # NOTE: children see this mutation of the shared config only because
        # they inherit the parent's memory (the "fork" start method); under
        # "spawn" they would re-import cfg unscaled.
        cfg.initial_lr *= world_size

        # Start distributed training: one sub-process per rank.
        processes = []
        for rank in range(world_size):
            proc = mp.Process(target=worker, args=(rank, world_size, args))
            proc.start()
            processes.append(proc)

        for proc in processes:
            proc.join()
    else:
        worker(0, 1, args)


def _build_train_queue(args):
    """Build the COCO train2017 keypoint DataLoader used by :func:`worker`.

    Args:
        args: parsed command-line namespace; reads ``args.workers`` and
            ``args.multi_scale_supervision``.

    Returns:
        A ``data.DataLoader`` over ``COCOJoints`` whose batches are assembled by
        ``HeatmapCollator``.
    """
    ann_file = os.path.join(
        cfg.data_root, "annotations", "person_keypoints_train2017.json"
    )
    train_dataset = COCOJoints(
        cfg.data_root,
        ann_file,
        image_set="train2017",
        order=("image", "keypoints", "boxes", "info"),
    )
    # drop_last discards the final incomplete batch.
    train_sampler = data.RandomSampler(
        train_dataset, batch_size=cfg.batch_size, drop_last=True
    )

    # Augmentation pipeline, applied in list order: normalisation, optional
    # half-body / box-extension steps, horizontal flip, random rotation/scale
    # affine producing cfg.input_shape outputs, then T.ToMode().
    transforms = [T.Normalize(mean=cfg.img_mean, std=cfg.img_std)]
    if cfg.half_body_transform:
        transforms.append(
            HalfBodyTransform(
                cfg.upper_body_ids, cfg.lower_body_ids, cfg.prob_half_body
            )
        )
    if cfg.extend_boxes:
        transforms.append(
            ExtendBoxes(cfg.x_ext, cfg.y_ext, cfg.input_shape[1] / cfg.input_shape[0])
        )
    transforms.append(
        RandomHorizontalFlip(0.5, keypoint_flip_order=cfg.keypoint_flip_order)
    )
    transforms.append(
        RandomBoxAffine(
            degrees=cfg.rotate_range,
            scale=cfg.scale_range,
            output_shape=cfg.input_shape,
            rotate_prob=cfg.rotation_prob,
            scale_prob=cfg.scale_prob,
        )
    )
    transforms.append(T.ToMode())

    # Without multi-scale supervision only the last heatmap kernel is used.
    heat_kernel = (
        cfg.heat_kernel if args.multi_scale_supervision else cfg.heat_kernel[-1:]
    )
    return data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        num_workers=args.workers,
        transform=T.Compose(transforms=transforms, order=train_dataset.order),
        collator=HeatmapCollator(
            cfg.input_shape,
            cfg.output_shape,
            cfg.keypoint_num,
            cfg.heat_thr,
            heat_kernel,
            cfg.heat_range,
        ),
    )


def worker(rank, world_size, args):
    """Train the keypoint model in one process.

    Args:
        rank: index of this process; also passed as the device id when
            distributed.
        world_size: total number of training processes; > 1 initialises the
            distributed process group.
        args: parsed command-line namespace from :func:`main`.

    After every epoch, rank 0 saves ``{"epoch": next_epoch, "state_dict": ...}``
    to ``<save>/<model_name>/epoch_<epoch>.pkl``. ``--resume`` loads such a
    file and continues from its stored ``epoch``.
    """
    if world_size > 1:
        logger.info("init distributed process group {} / {}".format(rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    model_name = "{}_{}x{}".format(args.arch, cfg.input_shape[0], cfg.input_shape[1])
    save_dir = os.path.join(args.save, model_name)

    model = getattr(M, args.arch)()
    model.train()
    start_epoch = 0
    if args.resume is not None:
        # NOTE(review): only model weights and the epoch counter are restored;
        # the Adam optimizer state starts fresh on resume.
        checkpoint = mge.load(args.resume)
        model.load_state_dict(checkpoint["state_dict"])
        start_epoch = checkpoint["epoch"]

    optimizer = optim.Adam(
        model.parameters(requires_grad=True),
        lr=cfg.initial_lr,
        weight_decay=cfg.weight_decay,
    )

    logger.info("preparing dataset..")
    train_queue = _build_train_queue(args)

    for epoch in range(start_epoch, cfg.epochs):
        loss = train(model, train_queue, optimizer, args, epoch=epoch)
        logger.info("Epoch %d Train %.6f ", epoch, loss)

        if rank == 0:  # only one process writes checkpoints
            mge.save(
                # "epoch" stores the next epoch to run, matching the resume logic.
                {"epoch": epoch + 1, "state_dict": model.state_dict()},
                os.path.join(save_dir, "epoch_{}.pkl".format(epoch)),
            )


def get_lr_factor(step, steps_per_epoch, warm_epochs, total_epochs, lr_ratio):
    """Return the learning-rate multiplier for a global training step.

    For the first ``warm_epochs`` epochs the factor rises linearly from
    ``lr_ratio`` towards 1. After that it decays linearly from 1 towards 0 over
    the remaining ``total_epochs - warm_epochs`` epochs.

    Args:
        step: global step index, ``epoch * steps_per_epoch + step_in_epoch``.
        steps_per_epoch: number of iterations in one epoch.
        warm_epochs: length of the warm-up phase, in epochs.
        total_epochs: total number of training epochs.
        lr_ratio: starting multiplier at step 0 of warm-up.

    Returns:
        The multiplier to apply to the base learning rate.
    """
    if step < warm_epochs * steps_per_epoch:
        return lr_ratio + (1 - lr_ratio) * step / warm_epochs / steps_per_epoch
    return 1 - (step - steps_per_epoch * warm_epochs) / (
        steps_per_epoch * (total_epochs - warm_epochs)
    )


def train(model, data_queue, optimizer, args, epoch=0):
    """Run one training epoch and return the running-mean loss.

    The learning rate of every parameter group is set each step from
    :func:`get_lr_factor`. When distributed, the loss is averaged over all
    workers before being reported.

    Args:
        model: keypoint model exposing ``inputs`` and ``calc_loss()``.
        data_queue: iterable of mini-batch dicts with ``"data"``, ``"heatmap"``
            and ``"heat_valid"`` entries; must support ``len()``.
        optimizer: MegEngine optimizer for ``model``.
        args: parsed command-line namespace; reads ``args.report_freq``.
        epoch: index of this epoch, used for the LR schedule and logging.

    Returns:
        Mean of the per-step losses (0 if ``data_queue`` is empty).
    """

    @jit.trace(symbolic=True, opt_level=2)
    def train_func():
        loss = model.calc_loss()
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():
            # all-reduce then divide: the mean loss across workers
            loss = dist.all_reduce_sum(loss, "train_loss") / dist.get_world_size()
        return loss

    steps_per_epoch = len(data_queue)
    avg_loss = 0
    total_time = 0

    t = time.time()
    for step, mini_batch in enumerate(data_queue):
        # The schedule depends only on the global step, so compute it once per
        # step rather than once per parameter group.
        lr = cfg.initial_lr * get_lr_factor(
            epoch * steps_per_epoch + step,
            steps_per_epoch,
            cfg.warm_epochs,
            cfg.epochs,
            cfg.lr_ratio,
        )
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        model.inputs["image"].set_value(mini_batch["data"])
        model.inputs["heatmap"].set_value(mini_batch["heatmap"])
        model.inputs["heat_valid"].set_value(mini_batch["heat_valid"])

        optimizer.zero_grad()
        loss = train_func()
        optimizer.step()

        # Fetch the scalar once; .numpy() forces a device synchronisation.
        loss_value = loss.numpy().item()
        avg_loss = (avg_loss * step + loss_value) / (step + 1)
        total_time += time.time() - t
        t = time.time()

        if step % args.report_freq == 0 and dist.get_rank() == 0:
            logger.info(
                "Epoch {} Step {}, LR {:.6f} Loss {:.6f} Elapsed Time {:.3f}s".format(
                    epoch, step, lr, loss_value, total_time
                )
            )

    return avg_loss


if __name__ == "__main__":
    # Run the training entry point only when executed as a script, not on import.
    main()