提交 ae0b221d 编写于 作者: C chenguowei01

add hrnet

上级 8949ec49
......@@ -24,7 +24,7 @@ import tqdm
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
from models import MODELS
import utils
import utils.logging as logging
from utils import get_environ_info
......@@ -37,7 +37,12 @@ def parse_args():
parser.add_argument(
'--model_name',
dest='model_name',
help="Model type for traing, which is one of ('UNet')",
help=
'Model type for testing, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
type=str,
default='UNet')
......@@ -146,8 +151,11 @@ def main(args):
test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
test_dataset = dataset(transforms=test_transforms, mode='test')
if args.model_name == 'UNet':
model = models.UNet(num_classes=test_dataset.num_classes)
if args.model_name not in MODELS:
raise Exception(
'--model_name is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=test_dataset.num_classes)
infer(
model,
......
......@@ -13,3 +13,28 @@
# limitations under the License.
from .unet import UNet
from .hrnet import *
MODELS = {
"UNet": UNet,
"HRNet_W18_Small_V1": HRNet_W18_Small_V1,
"HRNet_W18_Small_V2": HRNet_W18_Small_V2,
"HRNet_W18": HRNet_W18,
"HRNet_W30": HRNet_W30,
"HRNet_W32": HRNet_W32,
"HRNet_W40": HRNet_W40,
"HRNet_W44": HRNet_W44,
"HRNet_W48": HRNet_W48,
"HRNet_W60": HRNet_W48,
"HRNet_W64": HRNet_W64,
"SE_HRNet_W18_Small_V1": SE_HRNet_W18_Small_V1,
"SE_HRNet_W18_Small_V2": SE_HRNet_W18_Small_V2,
"SE_HRNet_W18": SE_HRNet_W18,
"SE_HRNet_W30": SE_HRNet_W30,
"SE_HRNet_W32": SE_HRNet_W30,
"SE_HRNet_W40": SE_HRNet_W40,
"SE_HRNet_W44": SE_HRNet_W44,
"SE_HRNet_W48": SE_HRNet_W48,
"SE_HRNet_W60": SE_HRNet_W60,
"SE_HRNet_W64": SE_HRNet_W64
}
此差异已折叠。
......@@ -22,7 +22,7 @@ from paddle.incubate.hapi.distributed import DistributedBatchSampler
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
from models import MODELS
import utils.logging as logging
from utils import get_environ_info
from utils import load_pretrained_model
......@@ -38,7 +38,12 @@ def parse_args():
parser.add_argument(
'--model_name',
dest='model_name',
help="Model type for traing, which is one of ('UNet')",
help=
'Model type for training, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
type=str,
default='UNet')
......@@ -181,7 +186,6 @@ def train(model,
total_steps = steps_per_epoch * (num_epochs - start_epoch)
num_steps = 0
best_mean_iou = -1.0
best_model_epoch = 1
for epoch in range(start_epoch, num_epochs):
for step, data in enumerate(loader):
images = data[0]
......@@ -286,9 +290,11 @@ def main(args):
T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='eval')
if args.model_name == 'UNet':
model = models.UNet(
num_classes=train_dataset.num_classes, ignore_index=255)
if args.model_name not in MODELS:
raise Exception(
'--model_name is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=train_dataset.num_classes)
# Creat optimizer
# todo, may less one than len(loader)
......
......@@ -25,7 +25,7 @@ from paddle.fluid.dataloader import BatchSampler
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
from models import MODELS
import utils.logging as logging
from utils import get_environ_info
from utils import ConfusionMatrix
......@@ -39,7 +39,12 @@ def parse_args():
parser.add_argument(
'--model_name',
dest='model_name',
help="Model type for evaluation, which is one of ('UNet')",
help=
'Model type for evaluation, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
type=str,
default='UNet')
......@@ -153,8 +158,11 @@ def main(args):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='eval')
if args.model_name == 'UNet':
model = models.UNet(num_classes=eval_dataset.num_classes)
if args.model_name not in MODELS:
raise Exception(
'--model_name is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=eval_dataset.num_classes)
evaluate(
model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册