diff --git a/inference.py b/inference.py index 88bfe01d917fbce1eff3bc2bb9d830e129a3bc24..93428d4421508f9a505e3ca94133191055a0627a 100644 --- a/inference.py +++ b/inference.py @@ -17,21 +17,26 @@ from datetime import datetime import cv2 import torch.nn.functional as F -from models.resnet import resnet50, resnet34 +from models.resnet import resnet18,resnet34,resnet50,resnet101 from models.squeezenet import squeezenet1_1,squeezenet1_0 +from models.shufflenetv2 import ShuffleNetV2 +from models.shufflenet import ShuffleNet +from models.mobilenetv2 import MobileNetV2 +from torchvision.models import shufflenet_v2_x1_5 ,shufflenet_v2_x1_0 , shufflenet_v2_x2_0 +from models.rexnetv1 import ReXNetV1 + from utils.common_utils import * import copy from hand_data_iter.datasets import draw_bd_handpose - if __name__ == "__main__": parser = argparse.ArgumentParser(description=' Project Hand Pose Inference') - - parser.add_argument('--model_path', type=str, default = './model_exp/2021-02-21_23-25-14/model_epoch-2.pth', + parser.add_argument('--model_path', type=str, default = './weights/ReXNetV1-size-256-wingloss102-0.122.pth', help = 'model_path') # 模型路径 - parser.add_argument('--model', type=str, default = 'resnet_50', - help = 'model : resnet_x,squeezenet_x') # 模型类型 + parser.add_argument('--model', type=str, default = 'ReXNetV1', + help = '''model : resnet_34,resnet_50,resnet_101,squeezenet1_0,squeezenet1_1,shufflenetv2,shufflenet,mobilenetv2 + shufflenet_v2_x1_5 ,shufflenet_v2_x1_0 , shufflenet_v2_x2_0,ReXNetV1''') # 模型类型 parser.add_argument('--num_classes', type=int , default = 42, help = 'num_classes') # 手部21关键点, (x,y)*2 = 42 parser.add_argument('--GPUS', type=str, default = '0', @@ -63,12 +68,30 @@ if __name__ == "__main__": if ops.model == 'resnet_50': model_ = resnet50(num_classes = ops.num_classes,img_size=ops.img_size[0]) + elif ops.model == 'resnet_18': + model_ = resnet18(num_classes = ops.num_classes,img_size=ops.img_size[0]) elif ops.model == 'resnet_34': model_ = resnet34(num_classes = ops.num_classes,img_size=ops.img_size[0]) + elif ops.model == 'resnet_101': + model_ = resnet101(num_classes = ops.num_classes,img_size=ops.img_size[0]) elif ops.model == "squeezenet1_0": model_ = squeezenet1_0(num_classes=ops.num_classes) elif ops.model == "squeezenet1_1": model_ = squeezenet1_1(num_classes=ops.num_classes) + elif ops.model == "shufflenetv2": + model_ = ShuffleNetV2(ratio=1., num_classes=ops.num_classes) + elif ops.model == "shufflenet_v2_x1_5": + model_ = shufflenet_v2_x1_5(pretrained=False,num_classes=ops.num_classes) + elif ops.model == "shufflenet_v2_x1_0": + model_ = shufflenet_v2_x1_0(pretrained=False,num_classes=ops.num_classes) + elif ops.model == "shufflenet_v2_x2_0": + model_ = shufflenet_v2_x2_0(pretrained=False,num_classes=ops.num_classes) + elif ops.model == "shufflenet": + model_ = ShuffleNet(num_blocks = [2,4,2], num_classes=ops.num_classes, groups=3) + elif ops.model == "mobilenetv2": + model_ = MobileNetV2(num_classes=ops.num_classes) + elif ops.model == "ReXNetV1": + model_ = ReXNetV1(num_classes=ops.num_classes) use_cuda = torch.cuda.is_available()