未验证 提交 aa7acdfb 编写于 作者: T tianlanshidai 提交者: GitHub

Update model_builder.py

上级 ae1a4aa1
...@@ -26,7 +26,7 @@ from loss import multi_dice_loss ...@@ -26,7 +26,7 @@ from loss import multi_dice_loss
from loss import multi_bce_loss from loss import multi_bce_loss
from lovasz_losses import lovasz_hinge from lovasz_losses import lovasz_hinge
from lovasz_losses import lovasz_softmax from lovasz_losses import lovasz_softmax
from models.modeling import deeplab, unet, icnet, pspnet, hrnet, fast_scnn from models.modeling import deeplab, unet, icnet, pspnet, hrnet, fast_scnn,ocnet
class ModelPhase(object): class ModelPhase(object):
...@@ -85,6 +85,8 @@ def seg_model(image, class_num): ...@@ -85,6 +85,8 @@ def seg_model(image, class_num):
logits = hrnet.hrnet(image, class_num) logits = hrnet.hrnet(image, class_num)
elif model_name == 'fast_scnn': elif model_name == 'fast_scnn':
logits = fast_scnn.fast_scnn(image, class_num) logits = fast_scnn.fast_scnn(image, class_num)
elif model_name == 'ocnet':
logits = ocnet.ocnet(image, class_num)
else: else:
raise Exception( raise Exception(
"unknow model name, only support unet, deeplabv3p, icnet, pspnet, hrnet, fast_scnn" "unknow model name, only support unet, deeplabv3p, icnet, pspnet, hrnet, fast_scnn"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册