infer.py 2.7 KB
Newer Older
W
wuyefeilin 已提交
1 2 3 4 5 6 7
import argparse
import os
import os.path as osp
import cv2
import numpy as np
import tqdm

8 9 10
import utils
import models
import transforms
W
wuyefeilin 已提交
11 12 13 14 15 16


def parse_args():
    """Parse the command-line options for HumanSeg inference.

    ``--model_dir``, ``--data_dir`` and ``--test_list`` are required: ``infer``
    uses each of them unconditionally, so leaving one out would otherwise only
    fail later with a less informative error on ``None``. argparse now reports
    the missing option immediately and exits.

    Returns:
        argparse.Namespace with attributes ``model_dir``, ``data_dir``,
        ``test_list``, ``save_dir`` (default ``'./output/result'``) and
        ``image_shape`` (a list of two ints, default ``[192, 192]``).
    """
    parser = argparse.ArgumentParser(
        description='HumanSeg inference and visualization')
    parser.add_argument(
        '--model_dir',
        dest='model_dir',
        help='Model path for inference',
        required=True,
        type=str)
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        help='The root directory of dataset',
        required=True,
        type=str)
    parser.add_argument(
        '--test_list',
        dest='test_list',
        help='Test list file of dataset',
        required=True,
        type=str)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the inference results',
        type=str,
        default='./output/result')
    parser.add_argument(
        "--image_shape",
        dest="image_shape",
        help="The image shape for net inputs.",
        nargs=2,
        default=[192, 192],
        type=int)
    return parser.parse_args()


def mkdir(path):
    """Make sure the parent directory of the file path ``path`` exists.

    Only ``osp.dirname(path)`` is created (recursively); the last component
    is treated as a file name and is not created. A bare file name has an
    empty dirname (the current directory), so nothing needs to be created.

    Raises:
        OSError: if the directory cannot be created (e.g. permissions, or a
            non-directory file already occupies that path).
    """
    sub_dir = osp.dirname(path)
    if not sub_dir:
        # os.makedirs('') would raise; the current directory already exists.
        return
    try:
        os.makedirs(sub_dir)
    except OSError:
        # Already created (by an earlier call or a concurrent writer) is fine;
        # anything else is a genuine failure and is re-raised.
        if not osp.isdir(sub_dir):
            raise


53 54
def infer(args):
    """Run HumanSeg inference on every image in a test list and save results.

    Each non-blank line of ``args.test_list`` is an image path relative to
    ``args.data_dir``. For each image, three outputs are written under
    ``args.save_dir``, mirroring that relative path:

    * ``added/<path>``    -- ``utils.visualize`` overlay of the result
      (weight 0.6);
    * ``scoremap/<path>`` -- channel 1 of ``result['score_map']`` multiplied
      by 255 and cast to uint8 (presumably the foreground probability —
      confirm against the model);
    * ``mat/<stem>.png``  -- the input image with that score map appended as
      one extra channel; always written with a ``.png`` extension.

    Args:
        args: namespace produced by ``parse_args``.

    Raises:
        IOError: if an image listed in ``test_list`` cannot be read.
    """
    test_transforms = transforms.Compose(
        [transforms.Resize(args.image_shape),
         transforms.Normalize()])
    model = models.load_model(args.model_dir)
    added_saved_path = osp.join(args.save_dir, 'added')
    mat_saved_path = osp.join(args.save_dir, 'mat')
    scoremap_saved_path = osp.join(args.save_dir, 'scoremap')

    with open(args.test_list, 'r') as f:
        # Skip blank lines: an empty name would make imread target data_dir.
        files = [line.strip() for line in f if line.strip()]

    for rel_path in tqdm.tqdm(files):
        im_file = osp.join(args.data_dir, rel_path)
        im = cv2.imread(im_file)
        if im is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise IOError('Cannot read image: {}'.format(im_file))
        result = model.predict(im, transforms=test_transforms)

        # save added image
        added_image = utils.visualize(im_file, result, weight=0.6)
        added_image_file = osp.join(added_saved_path, rel_path)
        mkdir(added_image_file)
        cv2.imwrite(added_image_file, added_image)

        # save score map
        score_map = result['score_map'][:, :, 1]
        score_map = (score_map * 255).astype(np.uint8)
        score_map_file = osp.join(scoremap_saved_path, rel_path)
        mkdir(score_map_file)
        cv2.imwrite(score_map_file, score_map)

        # save mat image: input image plus the score map as an extra channel
        score_map = np.expand_dims(score_map, axis=-1)
        mat_image = np.concatenate([im, score_map], axis=2)
        # Swap only the trailing extension. str.replace(ext, '.png') would
        # also rewrite matching substrings elsewhere in the path, and with an
        # empty ext it inserts '.png' between every character.
        mat_file = osp.splitext(osp.join(mat_saved_path, rel_path))[0] + '.png'
        mkdir(mat_file)
        cv2.imwrite(mat_file, mat_image)


if __name__ == '__main__':
    # Script entry point: read CLI options and run inference over the list.
    infer(parse_args())