提交 8abced85 编写于 作者: Eric.Lee2021's avatar Eric.Lee2021 🚴🏻

update

上级 8652bb60
...@@ -30,12 +30,16 @@ ...@@ -30,12 +30,16 @@
下载数据集并解压,然后运行脚本 prepropess_data.py,生成训练用的mask,注意脚本内相关参数配置。 下载数据集并解压,然后运行脚本 prepropess_data.py,生成训练用的mask,注意脚本内相关参数配置。
## 预训练模型 ## 预训练模型
* [预训练模型下载地址(百度网盘 Password: )]() * [预训练模型下载地址(百度网盘 Password: ri6m )](https://pan.baidu.com/s/1I5fPAyXDfIh9M5POs80ovg)
## 项目使用方法 ## 项目使用方法
### 模型训练 ### 步骤1:生成训练数据
* 目前支持样本分辨率 512,256,注意训练、推理脚本也要做相应的分辨率对应设置。
* 运行脚本:prepropess_data.py (注意脚本内相关参数配置 )
### 步骤2:模型训练
* 根目录下运行命令: python train.py (注意脚本内相关参数配置 ) * 根目录下运行命令: python train.py (注意脚本内相关参数配置 )
### 模型推理 ### 步骤3:模型推理
* 根目录下运行命令: python inference.py (注意脚本内相关参数配置 ) * 根目录下运行命令: python inference.py (注意脚本内相关参数配置 )
...@@ -83,7 +83,7 @@ def inference( img_size, image_path, model_path): ...@@ -83,7 +83,7 @@ def inference( img_size, image_path, model_path):
img_ = cv2.imread(image_path + f_) img_ = cv2.imread(image_path + f_)
img = Image.fromarray(cv2.cvtColor(img_,cv2.COLOR_BGR2RGB)) img = Image.fromarray(cv2.cvtColor(img_,cv2.COLOR_BGR2RGB))
image = img.resize((img_size, img_size), Image.BILINEAR) image = img.resize((img_size, img_size))
img = to_tensor(image) img = to_tensor(image)
img = torch.unsqueeze(img, 0) img = torch.unsqueeze(img, 0)
img = img.cuda() img = img.cuda()
...@@ -93,6 +93,7 @@ def inference( img_size, image_path, model_path): ...@@ -93,6 +93,7 @@ def inference( img_size, image_path, model_path):
print('<{}> image : '.format(idx),np.unique(parsing_)) print('<{}> image : '.format(idx),np.unique(parsing_))
parsing_ = cv2.resize(parsing_,(img_.shape[1],img_.shape[0]),interpolation=cv2.INTER_NEAREST) parsing_ = cv2.resize(parsing_,(img_.shape[1],img_.shape[0]),interpolation=cv2.INTER_NEAREST)
parsing_ = parsing_.astype(np.uint8) parsing_ = parsing_.astype(np.uint8)
vis_im = vis_parsing_maps(img_, parsing_, 0,0,stride=1) vis_im = vis_parsing_maps(img_, parsing_, 0,0,stride=1)
...@@ -104,10 +105,11 @@ def inference( img_size, image_path, model_path): ...@@ -104,10 +105,11 @@ def inference( img_size, image_path, model_path):
cv2.namedWindow("vis_im",0) cv2.namedWindow("vis_im",0)
cv2.imshow("vis_im",vis_im) cv2.imshow("vis_im",vis_im)
if cv2.waitKey(500) == 27: if cv2.waitKey(500) == 27:
break break
if __name__ == "__main__": if __name__ == "__main__":
img_size = 512 img_size = 512 # 推理分辨率设置
model_path = "./model_exp/2021-02-23_22-03-22/fp_latest.pth" model_path = "./weights/fp_512.pth" # 模型路径
image_path = "./images/" image_path = "./images/"
inference(img_size = img_size, image_path=image_path, model_path=model_path) inference(img_size = img_size, image_path=image_path, model_path=model_path)
...@@ -29,8 +29,8 @@ def train(fintune_model,image_size,lr0,path_data,model_exp): ...@@ -29,8 +29,8 @@ def train(fintune_model,image_size,lr0,path_data,model_exp):
max_epoch = 1000 max_epoch = 1000
n_classes = 19 n_classes = 19
n_img_per_gpu = 16 n_img_per_gpu = 16
n_workers = 8 n_workers = 12
cropsize = [int(image_size*0.85),int(image_size*0.85)] cropsize = [int(image_size*0.8),int(image_size*0.8)]
# DataLoader 数据迭代器 # DataLoader 数据迭代器
ds = FaceMask(path_data,img_size = image_size, cropsize=cropsize, mode='train') ds = FaceMask(path_data,img_size = image_size, cropsize=cropsize, mode='train')
...@@ -139,7 +139,7 @@ def train(fintune_model,image_size,lr0,path_data,model_exp): ...@@ -139,7 +139,7 @@ def train(fintune_model,image_size,lr0,path_data,model_exp):
torch.save(state, model_exp+'fp_{}_epoch-{}.pth'.format(image_size,epoch)) torch.save(state, model_exp+'fp_{}_epoch-{}.pth'.format(image_size,epoch))
if __name__ == "__main__": if __name__ == "__main__":
image_size = 256 image_size = 512
lr0 = 1e-4 lr0 = 1e-4
model_exp = './model_exp/' model_exp = './model_exp/'
path_data = './CelebAMask-HQ/' path_data = './CelebAMask-HQ/'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册