diff --git a/applications/wyw2s_local_app.py b/applications/wyw2s_local_app.py index 91bd029a20b6a5f3d5f409af9b35c2543b17340a..3649e19f721bd00be1283fae3055ae0f06ebce61 100644 --- a/applications/wyw2s_local_app.py +++ b/applications/wyw2s_local_app.py @@ -60,7 +60,7 @@ def main_wyw2s(video_path,cfg_file): facebank_path = config["facebank_path"], threshold = float(config["face_verify_threshold"])) - face_multitask_model = FaceMuitiTask_Model(model_path = config["face_multitask_model_path"]) + face_multitask_model = FaceMuitiTask_Model(model_path = config["face_multitask_model_path"], model_arch = config["face_multitask_model_arch"]) face_euler_model = FaceAngle_Model(model_path = config["face_euler_model_path"]) diff --git a/components/face_multi_task/face_multi_task_component.py b/components/face_multi_task/face_multi_task_component.py index 0d2b655d389128f41a9537b8fd4e024468ae60cd..141b1d187a5be2d8aef9c874f891d0a45cb2d930 100644 --- a/components/face_multi_task/face_multi_task_component.py +++ b/components/face_multi_task/face_multi_task_component.py @@ -8,7 +8,7 @@ import torch import cv2 import torch.nn.functional as F -from face_multi_task.network.resnet import resnet50 +from face_multi_task.network.resnet import resnet50,resnet34,resnet18 from face_multi_task.utils.common_utils import * import numpy as np @@ -17,6 +17,7 @@ class FaceMuitiTask_Model(object): model_path = './components/face_multi_task/weights_multask/resnet_50_imgsize-256-20210411.pth', img_size=256, num_classes = 196,# 人脸关键点,年龄,性别 + model_arch = "resnet50",# 模型结构 ): use_cuda = torch.cuda.is_available() @@ -24,7 +25,13 @@ class FaceMuitiTask_Model(object): self.device = torch.device("cuda:0" if use_cuda else "cpu") # 可选的设备类型及序号 self.img_size = img_size #----------------------------------------------------------------------- - face_multi_model = resnet50(landmarks_num=num_classes, img_size=img_size) + if model_arch == "resnet50": + face_multi_model = resnet50(landmarks_num=num_classes, img_size=img_size) + elif model_arch == "resnet34": + face_multi_model = resnet34(landmarks_num=num_classes, img_size=img_size) + elif model_arch == "resnet18": + face_multi_model = resnet18(landmarks_num=num_classes, img_size=img_size) + chkpt = torch.load(model_path, map_location=lambda storage, loc: storage) face_multi_model.load_state_dict(chkpt) face_multi_model.eval() diff --git a/lib/wyw2s_lib/cfg/wyw2s.cfg b/lib/wyw2s_lib/cfg/wyw2s.cfg index ce7b5ef6c7c91959cc0055d698607f0f1ff6f326..1615eaf21d172043f8cde441e7d6b967c5aa67b9 100644 --- a/lib/wyw2s_lib/cfg/wyw2s.cfg +++ b/lib/wyw2s_lib/cfg/wyw2s.cfg @@ -11,6 +11,10 @@ face_verify_backbone_path=./wyw2s_models/face_verify-model_ir_se-50.pth facebank_path=./wyw2s_models/facebank face_verify_threshold=1.2 -face_multitask_model_path=./wyw2s_models/face_multitask-resnet_50_imgsize-256-20210411.pth +#face_multitask_model_path=./wyw2s_models/face_multitask-resnet_50_imgsize-256-20210411.pth +#face_multitask_model_arch=resnet50 + +face_multitask_model_path=./wyw2s_models/face_multitask-resnet_34_imgsize-256-20210423.pth +face_multitask_model_arch=resnet34 face_euler_model_path=./wyw2s_models/euler_angle-resnet_18_imgsize_256.pth