未验证 提交 30f3cff6 编写于 作者: Y Yizhuang Zhou 提交者: GitHub

fix(segmentation): fix VOC category and add cityscapes (#18)

上级 475b6b9f
# Semantic Segmentation # Semantic Segmentation
本目录包含了采用MegEngine实现的经典[Deeplabv3plus](https://arxiv.org/abs/1802.02611.pdf)网络结构,同时提供了在PASCAL VOC数据集上的完整训练和测试代码。 本目录包含了采用MegEngine实现的经典[Deeplabv3plus](https://arxiv.org/abs/1802.02611.pdf)网络结构,同时提供了在PASCAL VOC和Cityscapes数据集上的完整训练和测试代码。
网络在PASCAL VOC2012验证集的性能和结果如下: 网络在PASCAL VOC2012验证集的性能和结果如下:
...@@ -38,20 +38,25 @@ ...@@ -38,20 +38,25 @@
3、开始训练: 3、开始训练:
`train.py`的命令行参数如下: `train.py`的命令行参数如下:
- `--config`,训练时采用的配置文件,VOC和Cityscapes各一份默认配置;
- `--dataset_dir`,训练时采用的训练集存放的目录; - `--dataset_dir`,训练时采用的训练集存放的目录;
- `--weight_file`,训练时采用的预训练权重; - `--weight_file`,训练时采用的预训练权重;
- `--batch-size`,训练时采用的batch size, 默认8;
- `--ngpus`, 训练时采用的gpu数量,默认8; 当设置为1时,表示单卡训练 - `--ngpus`, 训练时采用的gpu数量,默认8; 当设置为1时,表示单卡训练
- `--resume`, 是否从已训好的模型继续训练; - `--resume`, 是否从已训好的模型继续训练,默认`None`
- `--train_epochs`, 需要训练的epoch数量;
```bash ```bash
python3 train.py --dataset_dir /path/to/VOC2012 \ python3 train.py --config cfg_voc.py \
--dataset_dir /path/to/VOC2012 \
--weight_file /path/to/weights.pkl \ --weight_file /path/to/weights.pkl \
--batch_size 8 \ --ngpus 8
--ngpus 8 \ ```
--train_epochs 50 \
--resume /path/to/model 或在Cityscapes数据集上进行训练:
```bash
python3 train.py --config cfg_cityscapes.py \
--dataset_dir /path/to/Cityscapes \
--weight_file /path/to/weights.pkl \
--ngpus 8
``` ```
## 如何测试 ## 如何测试
...@@ -59,11 +64,13 @@ python3 train.py --dataset_dir /path/to/VOC2012 \ ...@@ -59,11 +64,13 @@ python3 train.py --dataset_dir /path/to/VOC2012 \
模型训练好之后,可以通过如下命令测试模型在VOC2012验证集的性能: 模型训练好之后,可以通过如下命令测试模型在VOC2012验证集的性能:
```bash ```bash
python3 test.py --dataset_dir /path/to/VOC2012 \ python3 test.py --config cfg_voc.py \
--dataset_dir /path/to/VOC2012 \
--model_path /path/to/model.pkl --model_path /path/to/model.pkl
``` ```
`test.py`的命令行参数如下: `test.py`的命令行参数如下:
- `--config`,训练时采用的配置文件,VOC和Cityscapes各一份默认配置;
- `--dataset_dir`,验证时采用的验证集目录; - `--dataset_dir`,验证时采用的验证集目录;
- `--model_path`,载入训练好的模型; - `--model_path`,载入训练好的模型;
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
import os
class Config:
DATASET = "Cityscapes"
BATCH_SIZE = 4
LEARNING_RATE = 0.0065
EPOCHS = 200
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname("__file__")))
MODEL_SAVE_DIR = os.path.join(ROOT_DIR, "log")
LOG_DIR = MODEL_SAVE_DIR
if not os.path.isdir(MODEL_SAVE_DIR):
os.makedirs(MODEL_SAVE_DIR)
DATA_WORKERS = 4
IGNORE_INDEX = 255
NUM_CLASSES = 19
IMG_HEIGHT = 800
IMG_WIDTH = 800
IMG_MEAN = [103.530, 116.280, 123.675]
IMG_STD = [57.375, 57.120, 58.395]
VAL_HEIGHT = 800
VAL_WIDTH = 800
VAL_BATCHES = 1
VAL_MULTISCALE = [1.0] # [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
VAL_FLIP = False
VAL_SLIP = True
VAL_SAVE = None
cfg = Config()
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
import os
class Config:
DATASET = "VOC2012"
BATCH_SIZE = 8
LEARNING_RATE = 0.002
EPOCHS = 100
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname("__file__")))
MODEL_SAVE_DIR = os.path.join(ROOT_DIR, "log")
LOG_DIR = MODEL_SAVE_DIR
if not os.path.isdir(MODEL_SAVE_DIR):
os.makedirs(MODEL_SAVE_DIR)
DATA_WORKERS = 4
DATA_TYPE = "trainaug"
IGNORE_INDEX = 255
NUM_CLASSES = 21
IMG_HEIGHT = 512
IMG_WIDTH = 512
IMG_MEAN = [103.530, 116.280, 123.675]
IMG_STD = [57.375, 57.120, 58.395]
VAL_HEIGHT = 512
VAL_WIDTH = 512
VAL_BATCHES = 1
VAL_MULTISCALE = [1.0] # [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
VAL_FLIP = False
VAL_SLIP = False
VAL_SAVE = None
cfg = Config()
...@@ -27,11 +27,35 @@ class Config: ...@@ -27,11 +27,35 @@ class Config:
cfg = Config() cfg = Config()
# pre-defined colors for at most 20 categories
class_colors = [
[0, 0, 0], # background
[0, 0, 128],
[0, 128, 0],
[0, 128, 128],
[128, 0, 0],
[128, 0, 128],
[128, 128, 0],
[128, 128, 128],
[0, 0, 64],
[0, 0, 192],
[0, 128, 64],
[0, 128, 192],
[128, 0, 64],
[128, 0, 192],
[128, 128, 64],
[128, 128, 192],
[0, 64, 0],
[0, 64, 128],
[0, 192, 0],
[0, 192, 128],
[128, 64, 0],
]
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--image_path", type=str, default=None, help="inference image") parser.add_argument("-i", "--image_path", type=str, default=None, help="inference image")
parser.add_argument("--model_path", type=str, default=None, help="inference model") parser.add_argument("-m", "--model_path", type=str, default=None, help="inference model")
args = parser.parse_args() args = parser.parse_args()
net = load_model(args.model_path) net = load_model(args.model_path)
...@@ -43,7 +67,6 @@ def main(): ...@@ -43,7 +67,6 @@ def main():
pred = inference(img, net) pred = inference(img, net)
cv2.imwrite("out.jpg", pred) cv2.imwrite("out.jpg", pred)
def load_model(model_path): def load_model(model_path):
model_dict = mge.load(model_path) model_dict = mge.load(model_path)
net = DeepLabV3Plus(class_num=cfg.NUM_CLASSES) net = DeepLabV3Plus(class_num=cfg.NUM_CLASSES)
...@@ -73,7 +96,6 @@ def inference(img, net): ...@@ -73,7 +96,6 @@ def inference(img, net):
pred.astype("uint8"), (oriw, orih), interpolation=cv2.INTER_NEAREST pred.astype("uint8"), (oriw, orih), interpolation=cv2.INTER_NEAREST
) )
class_colors = dataset.PascalVOC.class_colors
out = np.zeros((orih, oriw, 3)) out = np.zeros((orih, oriw, 3))
nids = np.unique(pred) nids = np.unique(pred)
for t in nids: for t in nids:
......
...@@ -20,28 +20,14 @@ import numpy as np ...@@ -20,28 +20,14 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from official.vision.segmentation.deeplabv3plus import DeepLabV3Plus from official.vision.segmentation.deeplabv3plus import DeepLabV3Plus
from official.vision.segmentation.utils import import_config_from_file
class Config:
DATA_WORKERS = 4
NUM_CLASSES = 21
IMG_SIZE = 512
IMG_MEAN = [103.530, 116.280, 123.675]
IMG_STD = [57.375, 57.120, 58.395]
VAL_BATCHES = 1
VAL_MULTISCALE = [1.0] # [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
VAL_FLIP = False
VAL_SLIP = False
VAL_SAVE = None
cfg = Config()
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
"-c", "--config", type=str, required=True, help="configuration file"
)
parser.add_argument( parser.add_argument(
"-d", "--dataset_dir", type=str, default="/data/datasets/VOC2012", "-d", "--dataset_dir", type=str, default="/data/datasets/VOC2012",
) )
...@@ -50,7 +36,9 @@ def main(): ...@@ -50,7 +36,9 @@ def main():
) )
args = parser.parse_args() args = parser.parse_args()
test_loader, test_size = build_dataloader(args.dataset_dir) cfg = import_config_from_file(args.config)
test_loader, test_size = build_dataloader(args.dataset_dir, cfg)
print("number of test images: %d" % (test_size)) print("number of test images: %d" % (test_size))
net = DeepLabV3Plus(class_num=cfg.NUM_CLASSES) net = DeepLabV3Plus(class_num=cfg.NUM_CLASSES)
model_dict = mge.load(args.model_path) model_dict = mge.load(args.model_path)
...@@ -63,13 +51,15 @@ def main(): ...@@ -63,13 +51,15 @@ def main():
for sample_batched in tqdm(test_loader): for sample_batched in tqdm(test_loader):
img = sample_batched[0].squeeze() img = sample_batched[0].squeeze()
label = sample_batched[1].squeeze() label = sample_batched[1].squeeze()
pred = evaluate(net, img) im_info = sample_batched[2]
result_list.append({"pred": pred, "gt": label}) pred = evaluate(net, img, cfg)
result_list.append({"pred": pred, "gt": label, "name":im_info[2]})
if cfg.VAL_SAVE: if cfg.VAL_SAVE:
save_results(result_list, cfg.VAL_SAVE) save_results(result_list, cfg.VAL_SAVE, cfg)
compute_metric(result_list) compute_metric(result_list, cfg)
## inference one image
def pad_image_to_shape(img, shape, border_mode, value): def pad_image_to_shape(img, shape, border_mode, value):
margin = np.zeros(4, np.uint32) margin = np.zeros(4, np.uint32)
pad_height = shape[0] - img.shape[0] if shape[0] - img.shape[0] > 0 else 0 pad_height = shape[0] - img.shape[0] if shape[0] - img.shape[0] > 0 else 0
...@@ -86,40 +76,39 @@ def pad_image_to_shape(img, shape, border_mode, value): ...@@ -86,40 +76,39 @@ def pad_image_to_shape(img, shape, border_mode, value):
def eval_single(net, img, is_flip): def eval_single(net, img, is_flip):
@jit.trace(symbolic=True, opt_level=2) @jit.trace(symbolic=True, opt_level=2)
def pred_fun(input_data, net=None): def pred_fun(data, net=None):
net.eval() net.eval()
pred = net(input_data) pred = net(data)
return pred return pred
input_data = mge.tensor() data = mge.tensor()
input_data.set_value(img.transpose(2, 0, 1)[np.newaxis]) data.set_value(img.transpose(2, 0, 1)[np.newaxis])
pred = pred_fun(input_data, net=net) pred = pred_fun(data, net=net)
if is_flip: if is_flip:
img_flip = img[:, ::-1, :] img_flip = img[:, ::-1, :]
input_data.set_value(img_flip.transpose(2, 0, 1)[np.newaxis]) data.set_value(img_flip.transpose(2, 0, 1)[np.newaxis])
pred_flip = pred_fun(input_data, net=net) pred_flip = pred_fun(data, net=net)
pred = (pred + pred_flip[:, :, :, ::-1]) / 2.0 pred = (pred + pred_flip[:, :, :, ::-1]) / 2.0
del pred_flip del pred_flip
pred = pred.numpy().squeeze().transpose(1, 2, 0) pred = pred.numpy().squeeze().transpose(1, 2, 0)
del input_data del data
return pred return pred
def evaluate(net, img): def evaluate(net, img, cfg):
ori_h, ori_w, _ = img.shape ori_h, ori_w, _ = img.shape
pred_all = np.zeros((ori_h, ori_w, cfg.NUM_CLASSES)) pred_all = np.zeros((ori_h, ori_w, cfg.NUM_CLASSES))
for rate in cfg.VAL_MULTISCALE: for rate in cfg.VAL_MULTISCALE:
if cfg.VAL_SLIP: if cfg.VAL_SLIP:
new_h, new_w = int(ori_h*rate), int(ori_w*rate)
val_size = (cfg.VAL_HEIGHT, cfg.VAL_WIDTH)
else:
new_h, new_w = int(cfg.VAL_HEIGHT*rate), int(cfg.VAL_WIDTH*rate)
val_size = (new_h, new_w)
img_scale = cv2.resize( img_scale = cv2.resize(
img, None, fx=rate, fy=rate, interpolation=cv2.INTER_LINEAR img, (new_w, new_h), interpolation=cv2.INTER_LINEAR
) )
val_size = (cfg.IMG_SIZE, cfg.IMG_SIZE)
else:
out_h, out_w = int(cfg.IMG_SIZE * rate), int(cfg.IMG_SIZE * rate)
img_scale = cv2.resize(img, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
val_size = (out_h, out_w)
new_h = img_scale.shape[0]
if (new_h <= val_size[0]) and (new_h <= val_size[1]): if (new_h <= val_size[0]) and (new_h <= val_size[1]):
img_pad, margin = pad_image_to_shape( img_pad, margin = pad_image_to_shape(
img_scale, val_size, cv2.BORDER_CONSTANT, value=0 img_scale, val_size, cv2.BORDER_CONSTANT, value=0
...@@ -133,7 +122,6 @@ def evaluate(net, img): ...@@ -133,7 +122,6 @@ def evaluate(net, img):
else: else:
stride_rate = 2 / 3 stride_rate = 2 / 3
stride = [int(np.ceil(i * stride_rate)) for i in val_size] stride = [int(np.ceil(i * stride_rate)) for i in val_size]
print(img_scale.shape, stride, val_size)
img_pad, margin = pad_image_to_shape( img_pad, margin = pad_image_to_shape(
img_scale, val_size, cv2.BORDER_CONSTANT, value=0 img_scale, val_size, cv2.BORDER_CONSTANT, value=0
) )
...@@ -154,19 +142,10 @@ def evaluate(net, img): ...@@ -154,19 +142,10 @@ def evaluate(net, img):
s_x = e_x - val_size[1] s_x = e_x - val_size[1]
s_y = e_y - val_size[0] s_y = e_y - val_size[0]
img_sub = img_pad[s_y:e_y, s_x:e_x, :] img_sub = img_pad[s_y:e_y, s_x:e_x, :]
timg_pad, tmargin = pad_image_to_shape( tpred = eval_single(net, img_sub, cfg.VAL_FLIP)
img_sub, val_size, cv2.BORDER_CONSTANT, value=0
)
print(tmargin, timg_pad.shape)
tpred = eval_single(net, timg_pad, cfg.VAL_FLIP)
tpred = tpred[
margin[0] : (tpred.shape[0] - margin[1]),
margin[2] : (tpred.shape[1] - margin[3]),
:,
]
count_scale[s_y:e_y, s_x:e_x, :] += 1 count_scale[s_y:e_y, s_x:e_x, :] += 1
pred_scale[s_y:e_y, s_x:e_x, :] += tpred pred_scale[s_y:e_y, s_x:e_x, :] += tpred
pred_scale = pred_scale / count_scale #pred_scale = pred_scale / count_scale
pred = pred_scale[ pred = pred_scale[
margin[0] : (pred_scale.shape[0] - margin[1]), margin[0] : (pred_scale.shape[0] - margin[1]),
margin[2] : (pred_scale.shape[1] - margin[3]), margin[2] : (pred_scale.shape[1] - margin[3]),
...@@ -176,77 +155,98 @@ def evaluate(net, img): ...@@ -176,77 +155,98 @@ def evaluate(net, img):
pred = cv2.resize(pred, (ori_w, ori_h), interpolation=cv2.INTER_LINEAR) pred = cv2.resize(pred, (ori_w, ori_h), interpolation=cv2.INTER_LINEAR)
pred_all = pred_all + pred pred_all = pred_all + pred
pred_all = pred_all / len(cfg.VAL_MULTISCALE) #pred_all = pred_all / len(cfg.VAL_MULTISCALE)
result = np.argmax(pred_all, axis=2).astype(np.uint8) result = np.argmax(pred_all, axis=2).astype(np.uint8)
return result return result
def save_results(result_list, save_dir): def save_results(result_list, save_dir, cfg):
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
for idx, sample in enumerate(result_list): for idx, sample in enumerate(result_list):
file_path = os.path.join(save_dir, "%d.png" % idx) if cfg.DATASET == "Cityscapes":
name = sample["name"].split('/')[-1][:-4]
else:
name = sample["name"]
file_path = os.path.join(save_dir, "%s.png"%name)
cv2.imwrite(file_path, sample["pred"]) cv2.imwrite(file_path, sample["pred"])
file_path = os.path.join(save_dir, "%d.gt.png" % idx) file_path = os.path.join(save_dir, "%s.gt.png"%name)
cv2.imwrite(file_path, sample["gt"]) cv2.imwrite(file_path, sample["gt"])
# voc cityscapes metric
def compute_metric(result_list, cfg):
class_num = cfg.NUM_CLASSES
hist = np.zeros((class_num, class_num))
correct = 0
labeled = 0
count = 0
for idx in range(len(result_list)):
pred = result_list[idx]['pred']
gt = result_list[idx]['gt']
assert(pred.shape == gt.shape)
k = (gt>=0) & (gt<class_num)
labeled += np.sum(k)
correct += np.sum((pred[k]==gt[k]))
hist += np.bincount(class_num * gt[k].astype(int) + pred[k].astype(int), minlength=class_num**2).reshape(class_num, class_num)
count += 1
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
mean_IU = np.nanmean(iu)
mean_IU_no_back = np.nanmean(iu[1:])
freq = hist.sum(1) / hist.sum()
freq_IU = (iu[freq > 0] * freq[freq >0]).sum()
mean_pixel_acc = correct / labeled
if cfg.DATASET == "VOC2012":
class_names = ("background", ) + dataset.PascalVOC.class_names
elif cfg.DATASET == "Cityscapes":
class_names = dataset.Cityscapes.class_names
else:
raise ValueError("Unsupported dataset {}".format(cfg.DATASET))
def compute_metric(result_list): n = iu.size
""" lines = []
modified from https://github.com/YudeWang/deeplabv3plus-pytorch for i in range(n):
""" if class_names is None:
# pylint: disable=redefined-outer-name cls = 'Class %d:' % (i+1)
TP, P, T = [], [], []
for i in range(cfg.NUM_CLASSES):
TP.append(mp.Value("i", 0, lock=True))
P.append(mp.Value("i", 0, lock=True))
T.append(mp.Value("i", 0, lock=True))
def compare(start, step, TP, P, T):
for idx in tqdm(range(start, len(result_list), step)):
pred = result_list[idx]["pred"]
gt = result_list[idx]["gt"]
cal = gt < 255
mask = (pred == gt) * cal
for i in range(cfg.NUM_CLASSES):
P[i].acquire()
P[i].value += np.sum((pred == i) * cal)
P[i].release()
T[i].acquire()
T[i].value += np.sum((gt == i) * cal)
T[i].release()
TP[i].acquire()
TP[i].value += np.sum((gt == i) * mask)
TP[i].release()
p_list = []
for i in range(8):
p = mp.Process(target=compare, args=(i, 8, TP, P, T))
p.start()
p_list.append(p)
for p in p_list:
p.join()
class_names = dataset.PascalVOC.class_names
IoU = []
for i in range(cfg.NUM_CLASSES):
IoU.append(TP[i].value / (T[i].value + P[i].value - TP[i].value + 1e-10))
for i in range(cfg.NUM_CLASSES):
if i == 0:
print("%11s:%7.3f%%" % ("backbound", IoU[i] * 100), end="\t")
else: else:
if i % 2 != 1: cls = '%d %s' % (i+1, class_names[i])
print("%11s:%7.3f%%" % (class_names[i - 1], IoU[i] * 100), end="\t") lines.append('%-8s\t%.3f%%' % (cls, iu[i] * 100))
lines.append('---------------------------- %-8s\t%.3f%%\t%-8s\t%.3f%%' % ('mean_IU', mean_IU * 100,'mean_pixel_ACC',mean_pixel_acc*100))
line = "\n".join(lines)
print(line)
return mean_IU
class EvalPascalVOC(dataset.PascalVOC):
def _trans_mask(self, mask):
label = np.ones(mask.shape[:2]) * 255
class_colors = self.class_colors.copy()
class_colors.insert(0, [0,0,0])
for i in range(len(class_colors)):
b, g, r = class_colors[i]
label[
(mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r)
] = i
return label.astype(np.uint8)
def build_dataloader(dataset_dir, cfg):
if cfg.DATASET == "VOC2012":
val_dataset = EvalPascalVOC(
dataset_dir,
"val",
order=["image", "mask", "info"]
)
elif cfg.DATASET == "Cityscapes":
val_dataset = dataset.Cityscapes(
dataset_dir,
"val",
mode='gtFine',
order=["image", "mask", "info"]
)
else: else:
print("%11s:%7.3f%%" % (class_names[i - 1], IoU[i] * 100)) raise ValueError("Unsupported dataset {}".format(cfg.DATASET))
miou = np.mean(np.array(IoU))
print("\n======================================================")
print("%11s:%7.3f%%" % ("mIoU", miou * 100))
return miou
def build_dataloader(dataset_dir):
val_dataset = dataset.PascalVOC(dataset_dir, "val", order=["image", "mask"])
val_sampler = data.SequentialSampler(val_dataset, cfg.VAL_BATCHES) val_sampler = data.SequentialSampler(val_dataset, cfg.VAL_BATCHES)
val_dataloader = data.DataLoader( val_dataloader = data.DataLoader(
val_dataset, val_dataset,
......
...@@ -23,32 +23,16 @@ from official.vision.segmentation.deeplabv3plus import ( ...@@ -23,32 +23,16 @@ from official.vision.segmentation.deeplabv3plus import (
DeepLabV3Plus, DeepLabV3Plus,
softmax_cross_entropy, softmax_cross_entropy,
) )
from official.vision.segmentation.utils import import_config_from_file
logger = mge.get_logger(__name__) logger = mge.get_logger(__name__)
class Config:
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname("__file__")))
MODEL_SAVE_DIR = os.path.join(ROOT_DIR, "log")
LOG_DIR = MODEL_SAVE_DIR
if not os.path.isdir(MODEL_SAVE_DIR):
os.makedirs(MODEL_SAVE_DIR)
DATA_WORKERS = 4
DATA_TYPE = "trainaug"
IGNORE_INDEX = 255
NUM_CLASSES = 21
IMG_SIZE = 512
IMG_MEAN = [103.530, 116.280, 123.675]
IMG_STD = [57.375, 57.120, 58.395]
cfg = Config()
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
"-c", "--config", type=str, required=True, help="configuration file"
)
parser.add_argument( parser.add_argument(
"-d", "--dataset_dir", type=str, default="/data/datasets/VOC2012", "-d", "--dataset_dir", type=str, default="/data/datasets/VOC2012",
) )
...@@ -58,19 +42,6 @@ def main(): ...@@ -58,19 +42,6 @@ def main():
parser.add_argument( parser.add_argument(
"-n", "--ngpus", type=int, default=8, help="batchsize for training" "-n", "--ngpus", type=int, default=8, help="batchsize for training"
) )
parser.add_argument(
"-b", "--batch_size", type=int, default=8, help="batchsize for training"
)
parser.add_argument(
"-lr",
"--base_lr",
type=float,
default=0.002,
help="base learning rate for training",
)
parser.add_argument(
"-e", "--train_epochs", type=int, default=100, help="epochs for training"
)
parser.add_argument( parser.add_argument(
"-r", "--resume", type=str, default=None, help="resume model file" "-r", "--resume", type=str, default=None, help="resume model file"
) )
...@@ -92,6 +63,8 @@ def main(): ...@@ -92,6 +63,8 @@ def main():
def worker(rank, world_size, args): def worker(rank, world_size, args):
cfg = import_config_from_file(args.config)
if world_size > 1: if world_size > 1:
dist.init_process_group( dist.init_process_group(
master_ip="localhost", master_ip="localhost",
...@@ -103,11 +76,11 @@ def worker(rank, world_size, args): ...@@ -103,11 +76,11 @@ def worker(rank, world_size, args):
logger.info("Init process group done") logger.info("Init process group done")
logger.info("Prepare dataset") logger.info("Prepare dataset")
train_loader, epoch_size = build_dataloader(args.batch_size, args.dataset_dir) train_loader, epoch_size = build_dataloader(cfg.BATCH_SIZE, args.dataset_dir, cfg)
batch_iter = epoch_size // (args.batch_size * world_size) batch_iter = epoch_size // (cfg.BATCH_SIZE * world_size)
net = DeepLabV3Plus(class_num=cfg.NUM_CLASSES, pretrained=args.weight_file) net = DeepLabV3Plus(class_num=cfg.NUM_CLASSES, pretrained=args.weight_file)
base_lr = args.base_lr * world_size base_lr = cfg.LEARNING_RATE * world_size
optimizer = optim.SGD( optimizer = optim.SGD(
net.parameters(requires_grad=True), net.parameters(requires_grad=True),
lr=base_lr, lr=base_lr,
...@@ -116,15 +89,15 @@ def worker(rank, world_size, args): ...@@ -116,15 +89,15 @@ def worker(rank, world_size, args):
) )
@jit.trace(symbolic=True, opt_level=2) @jit.trace(symbolic=True, opt_level=2)
def train_func(input_data, label, net=None, optimizer=None): def train_func(data, label, net=None, optimizer=None):
net.train() net.train()
pred = net(input_data) pred = net(data)
loss = softmax_cross_entropy(pred, label, ignore_index=cfg.IGNORE_INDEX) loss = softmax_cross_entropy(pred, label, ignore_index=cfg.IGNORE_INDEX)
optimizer.backward(loss) optimizer.backward(loss)
return pred, loss return pred, loss
begin_epoch = 0 begin_epoch = 0
end_epoch = args.train_epochs end_epoch = cfg.EPOCHS
if args.resume is not None: if args.resume is not None:
pretrained = mge.load(args.resume) pretrained = mge.load(args.resume)
begin_epoch = pretrained["epoch"] + 1 begin_epoch = pretrained["epoch"] + 1
...@@ -135,11 +108,11 @@ def worker(rank, world_size, args): ...@@ -135,11 +108,11 @@ def worker(rank, world_size, args):
max_itr = end_epoch * batch_iter max_itr = end_epoch * batch_iter
image = mge.tensor( image = mge.tensor(
np.zeros([args.batch_size, 3, cfg.IMG_SIZE, cfg.IMG_SIZE]).astype(np.float32), np.zeros([cfg.BATCH_SIZE, 3, cfg.IMG_HEIGHT, cfg.IMG_WIDTH]).astype(np.float32),
dtype="float32", dtype="float32",
) )
label = mge.tensor( label = mge.tensor(
np.zeros([args.batch_size, cfg.IMG_SIZE, cfg.IMG_SIZE]).astype(np.int32), np.zeros([cfg.BATCH_SIZE, cfg.IMG_HEIGHT, cfg.IMG_WIDTH]).astype(np.int32),
dtype="int32", dtype="int32",
) )
exp_name = os.path.abspath(os.path.dirname(__file__)).split("/")[-1] exp_name = os.path.abspath(os.path.dirname(__file__)).split("/")[-1]
...@@ -184,10 +157,22 @@ def worker(rank, world_size, args): ...@@ -184,10 +157,22 @@ def worker(rank, world_size, args):
logger.info("save epoch%d", epoch) logger.info("save epoch%d", epoch)
def build_dataloader(batch_size, dataset_dir): def build_dataloader(batch_size, dataset_dir, cfg):
if cfg.DATASET == "VOC2012":
train_dataset = dataset.PascalVOC( train_dataset = dataset.PascalVOC(
dataset_dir, cfg.DATA_TYPE, order=["image", "mask"] dataset_dir,
cfg.DATA_TYPE,
order=["image", "mask"]
) )
elif cfg.DATASET == "Cityscapes":
train_dataset = dataset.Cityscapes(
dataset_dir,
"train",
mode='gtFine',
order=["image", "mask"]
)
else:
raise ValueError("Unsupported dataset {}".format(cfg.DATASET))
train_sampler = data.RandomSampler(train_dataset, batch_size, drop_last=True) train_sampler = data.RandomSampler(train_dataset, batch_size, drop_last=True)
train_dataloader = data.DataLoader( train_dataloader = data.DataLoader(
train_dataset, train_dataset,
...@@ -197,7 +182,7 @@ def build_dataloader(batch_size, dataset_dir): ...@@ -197,7 +182,7 @@ def build_dataloader(batch_size, dataset_dir):
T.RandomHorizontalFlip(0.5), T.RandomHorizontalFlip(0.5),
T.RandomResize(scale_range=(0.5, 2)), T.RandomResize(scale_range=(0.5, 2)),
T.RandomCrop( T.RandomCrop(
output_size=(cfg.IMG_SIZE, cfg.IMG_SIZE), output_size=(cfg.IMG_HEIGHT, cfg.IMG_WIDTH),
padding_value=[0, 0, 0], padding_value=[0, 0, 0],
padding_maskvalue=255, padding_maskvalue=255,
), ),
......
import importlib.util
import os
def import_config_from_file(cfg_file):
assert os.path.exists(cfg_file), "config file {} not exists".format(cfg_file)
spec = importlib.util.spec_from_file_location("config", cfg_file)
cfg_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(cfg_module)
return cfg_module.cfg
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册