提交 a7b89447 编写于 作者: Eric.Lee2021's avatar Eric.Lee2021 🚴🏻

add models inference

上级 038ff8f2
......@@ -17,21 +17,26 @@ from datetime import datetime
import cv2
import torch.nn.functional as F
from models.resnet import resnet50, resnet34
from models.resnet import resnet18,resnet34,resnet50,resnet101
from models.squeezenet import squeezenet1_1,squeezenet1_0
from models.shufflenetv2 import ShuffleNetV2
from models.shufflenet import ShuffleNet
from models.mobilenetv2 import MobileNetV2
from torchvision.models import shufflenet_v2_x1_5 ,shufflenet_v2_x1_0 , shufflenet_v2_x2_0
from models.rexnetv1 import ReXNetV1
from utils.common_utils import *
import copy
from hand_data_iter.datasets import draw_bd_handpose
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=' Project Hand Pose Inference')
parser.add_argument('--model_path', type=str, default = './model_exp/2021-02-21_23-25-14/model_epoch-2.pth',
parser.add_argument('--model_path', type=str, default = './weights/ReXNetV1-size-256-wingloss102-0.122.pth',
help = 'model_path') # 模型路径
parser.add_argument('--model', type=str, default = 'resnet_50',
help = 'model : resnet_x,squeezenet_x') # 模型类型
parser.add_argument('--model', type=str, default = 'ReXNetV1',
help = '''model : resnet_34,resnet_50,resnet_101,squeezenet1_0,squeezenet1_1,shufflenetv2,shufflenet,mobilenetv2
shufflenet_v2_x1_5 ,shufflenet_v2_x1_0 , shufflenet_v2_x2_0,ReXNetV1''') # 模型类型
parser.add_argument('--num_classes', type=int , default = 42,
help = 'num_classes') # 手部21关键点, (x,y)*2 = 42
parser.add_argument('--GPUS', type=str, default = '0',
......@@ -63,12 +68,30 @@ if __name__ == "__main__":
if ops.model == 'resnet_50':
model_ = resnet50(num_classes = ops.num_classes,img_size=ops.img_size[0])
elif ops.model == 'resnet_18':
model_ = resnet18(num_classes = ops.num_classes,img_size=ops.img_size[0])
elif ops.model == 'resnet_34':
model_ = resnet34(num_classes = ops.num_classes,img_size=ops.img_size[0])
elif ops.model == 'resnet_101':
model_ = resnet101(num_classes = ops.num_classes,img_size=ops.img_size[0])
elif ops.model == "squeezenet1_0":
model_ = squeezenet1_0(num_classes=ops.num_classes)
elif ops.model == "squeezenet1_1":
model_ = squeezenet1_1(num_classes=ops.num_classes)
elif ops.model == "shufflenetv2":
model_ = ShuffleNetV2(ratio=1., num_classes=ops.num_classes)
elif ops.model == "shufflenet_v2_x1_5":
model_ = shufflenet_v2_x1_5(pretrained=False,num_classes=ops.num_classes)
elif ops.model == "shufflenet_v2_x1_0":
model_ = shufflenet_v2_x1_0(pretrained=False,num_classes=ops.num_classes)
elif ops.model == "shufflenet_v2_x2_0":
model_ = shufflenet_v2_x2_0(pretrained=False,num_classes=ops.num_classes)
elif ops.model == "shufflenet":
model_ = ShuffleNet(num_blocks = [2,4,2], num_classes=ops.num_classes, groups=3)
elif ops.model == "mobilenetv2":
model_ = MobileNetV2(num_classes=ops.num_classes)
elif ops.model == "ReXNetV1":
model_ = ReXNetV1(num_classes=ops.num_classes)
use_cuda = torch.cuda.is_available()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册