diff --git a/pdseg/models/model_builder.py b/pdseg/models/model_builder.py index 67b78c1a01cef4f5741058cf567682d8f6f3768d..f15c82e6694b0e58de662ce567aa80e9ea6a897f 100644 --- a/pdseg/models/model_builder.py +++ b/pdseg/models/model_builder.py @@ -26,7 +26,7 @@ from loss import multi_dice_loss from loss import multi_bce_loss from lovasz_losses import lovasz_hinge from lovasz_losses import lovasz_softmax -from models.modeling import deeplab, unet, icnet, pspnet, hrnet, fast_scnn +from models.modeling import deeplab, unet, icnet, pspnet, hrnet, fast_scnn,ocnet class ModelPhase(object): @@ -85,6 +85,8 @@ def seg_model(image, class_num): logits = hrnet.hrnet(image, class_num) elif model_name == 'fast_scnn': logits = fast_scnn.fast_scnn(image, class_num) + elif model_name == 'ocnet': + logits = ocnet.ocnet(image, class_num) else: raise Exception( "unknow model name, only support unet, deeplabv3p, icnet, pspnet, hrnet, fast_scnn"