From 25968ad146d70195875864b3e82f471e1f33eef4 Mon Sep 17 00:00:00 2001 From: "Eric.Lee2021" <305141918@qq.com> Date: Fri, 23 Apr 2021 11:48:03 +0800 Subject: [PATCH] add model arh --- applications/wyw2s_local_app.py | 2 +- .../face_multi_task/face_multi_task_component.py | 11 +++++++++-- lib/wyw2s_lib/cfg/wyw2s.cfg | 6 +++++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/applications/wyw2s_local_app.py b/applications/wyw2s_local_app.py index 91bd029..3649e19 100644 --- a/applications/wyw2s_local_app.py +++ b/applications/wyw2s_local_app.py @@ -60,7 +60,7 @@ def main_wyw2s(video_path,cfg_file): facebank_path = config["facebank_path"], threshold = float(config["face_verify_threshold"])) - face_multitask_model = FaceMuitiTask_Model(model_path = config["face_multitask_model_path"]) + face_multitask_model = FaceMuitiTask_Model(model_path = config["face_multitask_model_path"], model_arch = config["face_multitask_model_arch"]) face_euler_model = FaceAngle_Model(model_path = config["face_euler_model_path"]) diff --git a/components/face_multi_task/face_multi_task_component.py b/components/face_multi_task/face_multi_task_component.py index 0d2b655..141b1d1 100644 --- a/components/face_multi_task/face_multi_task_component.py +++ b/components/face_multi_task/face_multi_task_component.py @@ -8,7 +8,7 @@ import torch import cv2 import torch.nn.functional as F -from face_multi_task.network.resnet import resnet50 +from face_multi_task.network.resnet import resnet50,resnet34,resnet18 from face_multi_task.utils.common_utils import * import numpy as np @@ -17,6 +17,7 @@ class FaceMuitiTask_Model(object): model_path = './components/face_multi_task/weights_multask/resnet_50_imgsize-256-20210411.pth', img_size=256, num_classes = 196,# 人脸关键点,年龄,性别 + model_arch = "resnet50",# 模型结构 ): use_cuda = torch.cuda.is_available() @@ -24,7 +25,13 @@ class FaceMuitiTask_Model(object): self.device = torch.device("cuda:0" if use_cuda else "cpu") # 可选的设备类型及序号 self.img_size = img_size #----------------------------------------------------------------------- - face_multi_model = resnet50(landmarks_num=num_classes, img_size=img_size) + if model_arch == "resnet50": + face_multi_model = resnet50(landmarks_num=num_classes, img_size=img_size) + elif model_arch == "resnet34": + face_multi_model = resnet34(landmarks_num=num_classes, img_size=img_size) + elif model_arch == "resnet18": + face_multi_model = resnet18(landmarks_num=num_classes, img_size=img_size) + chkpt = torch.load(model_path, map_location=lambda storage, loc: storage) face_multi_model.load_state_dict(chkpt) face_multi_model.eval() diff --git a/lib/wyw2s_lib/cfg/wyw2s.cfg b/lib/wyw2s_lib/cfg/wyw2s.cfg index ce7b5ef..1615eaf 100644 --- a/lib/wyw2s_lib/cfg/wyw2s.cfg +++ b/lib/wyw2s_lib/cfg/wyw2s.cfg @@ -11,6 +11,10 @@ face_verify_backbone_path=./wyw2s_models/face_verify-model_ir_se-50.pth facebank_path=./wyw2s_models/facebank face_verify_threshold=1.2 -face_multitask_model_path=./wyw2s_models/face_multitask-resnet_50_imgsize-256-20210411.pth +#face_multitask_model_path=./wyw2s_models/face_multitask-resnet_50_imgsize-256-20210411.pth +#face_multitask_model_arch=resnet50 + +face_multitask_model_path=./wyw2s_models/face_multitask-resnet_34_imgsize-256-20210423.pth +face_multitask_model_arch=resnet34 face_euler_model_path=./wyw2s_models/euler_angle-resnet_18_imgsize_256.pth -- GitLab