提交 ae0b221d 编写于 作者: C chenguowei01

add hrnet

上级 8949ec49
...@@ -24,7 +24,7 @@ import tqdm ...@@ -24,7 +24,7 @@ import tqdm
from datasets import OpticDiscSeg, Cityscapes from datasets import OpticDiscSeg, Cityscapes
import transforms as T import transforms as T
import models from models import MODELS
import utils import utils
import utils.logging as logging import utils.logging as logging
from utils import get_environ_info from utils import get_environ_info
...@@ -37,7 +37,12 @@ def parse_args(): ...@@ -37,7 +37,12 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--model_name', '--model_name',
dest='model_name', dest='model_name',
help="Model type for traing, which is one of ('UNet')", help=
'Model type for testing, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
type=str, type=str,
default='UNet') default='UNet')
...@@ -146,8 +151,11 @@ def main(args): ...@@ -146,8 +151,11 @@ def main(args):
test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
test_dataset = dataset(transforms=test_transforms, mode='test') test_dataset = dataset(transforms=test_transforms, mode='test')
if args.model_name == 'UNet': if args.model_name not in MODELS:
model = models.UNet(num_classes=test_dataset.num_classes) raise Exception(
'--model_name is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=test_dataset.num_classes)
infer( infer(
model, model,
......
...@@ -13,3 +13,28 @@ ...@@ -13,3 +13,28 @@
# limitations under the License. # limitations under the License.
from .unet import UNet from .unet import UNet
from .hrnet import *
MODELS = {
"UNet": UNet,
"HRNet_W18_Small_V1": HRNet_W18_Small_V1,
"HRNet_W18_Small_V2": HRNet_W18_Small_V2,
"HRNet_W18": HRNet_W18,
"HRNet_W30": HRNet_W30,
"HRNet_W32": HRNet_W32,
"HRNet_W40": HRNet_W40,
"HRNet_W44": HRNet_W44,
"HRNet_W48": HRNet_W48,
"HRNet_W60": HRNet_W48,
"HRNet_W64": HRNet_W64,
"SE_HRNet_W18_Small_V1": SE_HRNet_W18_Small_V1,
"SE_HRNet_W18_Small_V2": SE_HRNet_W18_Small_V2,
"SE_HRNet_W18": SE_HRNet_W18,
"SE_HRNet_W30": SE_HRNet_W30,
"SE_HRNet_W32": SE_HRNet_W30,
"SE_HRNet_W40": SE_HRNet_W40,
"SE_HRNet_W44": SE_HRNet_W44,
"SE_HRNet_W48": SE_HRNet_W48,
"SE_HRNet_W60": SE_HRNet_W60,
"SE_HRNet_W64": SE_HRNet_W64
}
此差异已折叠。
...@@ -22,7 +22,7 @@ from paddle.incubate.hapi.distributed import DistributedBatchSampler ...@@ -22,7 +22,7 @@ from paddle.incubate.hapi.distributed import DistributedBatchSampler
from datasets import OpticDiscSeg, Cityscapes from datasets import OpticDiscSeg, Cityscapes
import transforms as T import transforms as T
import models from models import MODELS
import utils.logging as logging import utils.logging as logging
from utils import get_environ_info from utils import get_environ_info
from utils import load_pretrained_model from utils import load_pretrained_model
...@@ -38,7 +38,12 @@ def parse_args(): ...@@ -38,7 +38,12 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--model_name', '--model_name',
dest='model_name', dest='model_name',
help="Model type for traing, which is one of ('UNet')", help=
'Model type for training, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
type=str, type=str,
default='UNet') default='UNet')
...@@ -181,7 +186,6 @@ def train(model, ...@@ -181,7 +186,6 @@ def train(model,
total_steps = steps_per_epoch * (num_epochs - start_epoch) total_steps = steps_per_epoch * (num_epochs - start_epoch)
num_steps = 0 num_steps = 0
best_mean_iou = -1.0 best_mean_iou = -1.0
best_model_epoch = 1
for epoch in range(start_epoch, num_epochs): for epoch in range(start_epoch, num_epochs):
for step, data in enumerate(loader): for step, data in enumerate(loader):
images = data[0] images = data[0]
...@@ -286,9 +290,11 @@ def main(args): ...@@ -286,9 +290,11 @@ def main(args):
T.Normalize()]) T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='eval') eval_dataset = dataset(transforms=eval_transforms, mode='eval')
if args.model_name == 'UNet': if args.model_name not in MODELS:
model = models.UNet( raise Exception(
num_classes=train_dataset.num_classes, ignore_index=255) '--model_name is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=train_dataset.num_classes)
# Creat optimizer # Creat optimizer
# todo, may less one than len(loader) # todo, may less one than len(loader)
......
...@@ -25,7 +25,7 @@ from paddle.fluid.dataloader import BatchSampler ...@@ -25,7 +25,7 @@ from paddle.fluid.dataloader import BatchSampler
from datasets import OpticDiscSeg, Cityscapes from datasets import OpticDiscSeg, Cityscapes
import transforms as T import transforms as T
import models from models import MODELS
import utils.logging as logging import utils.logging as logging
from utils import get_environ_info from utils import get_environ_info
from utils import ConfusionMatrix from utils import ConfusionMatrix
...@@ -39,7 +39,12 @@ def parse_args(): ...@@ -39,7 +39,12 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--model_name', '--model_name',
dest='model_name', dest='model_name',
help="Model type for evaluation, which is one of ('UNet')", help=
'Model type for evaluation, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
type=str, type=str,
default='UNet') default='UNet')
...@@ -153,8 +158,11 @@ def main(args): ...@@ -153,8 +158,11 @@ def main(args):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='eval') eval_dataset = dataset(transforms=eval_transforms, mode='eval')
if args.model_name == 'UNet': if args.model_name not in MODELS:
model = models.UNet(num_classes=eval_dataset.num_classes) raise Exception(
'--model_name is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=eval_dataset.num_classes)
evaluate( evaluate(
model, model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册