提交 8abced85 编写于 作者: Eric.Lee2021's avatar Eric.Lee2021 🚴🏻

update

上级 8652bb60
......@@ -30,12 +30,16 @@
下载数据集并解压,然后运行脚本 prepropess_data.py,生成训练用的mask,注意脚本内相关参数配置。
## 预训练模型
* [预训练模型下载地址(百度网盘 Password: )]()
* [预训练模型下载地址(百度网盘 Password: ri6m )](https://pan.baidu.com/s/1I5fPAyXDfIh9M5POs80ovg)
## 项目使用方法
### 模型训练
### 步骤1:生成训练数据
* 目前支持样本分辨率 512,256,注意训练、推理脚本也要做相应的分辨率对应设置。
* 运行脚本:prepropess_data.py (注意脚本内相关参数配置 )
### 步骤2:模型训练
* 根目录下运行命令: python train.py (注意脚本内相关参数配置 )
### 模型推理
### 步骤3:模型推理
* 根目录下运行命令: python inference.py (注意脚本内相关参数配置 )
......@@ -83,7 +83,7 @@ def inference( img_size, image_path, model_path):
img_ = cv2.imread(image_path + f_)
img = Image.fromarray(cv2.cvtColor(img_,cv2.COLOR_BGR2RGB))
image = img.resize((img_size, img_size), Image.BILINEAR)
image = img.resize((img_size, img_size))
img = to_tensor(image)
img = torch.unsqueeze(img, 0)
img = img.cuda()
......@@ -93,6 +93,7 @@ def inference( img_size, image_path, model_path):
print('<{}> image : '.format(idx),np.unique(parsing_))
parsing_ = cv2.resize(parsing_,(img_.shape[1],img_.shape[0]),interpolation=cv2.INTER_NEAREST)
parsing_ = parsing_.astype(np.uint8)
vis_im = vis_parsing_maps(img_, parsing_, 0,0,stride=1)
......@@ -104,10 +105,11 @@ def inference( img_size, image_path, model_path):
cv2.namedWindow("vis_im",0)
cv2.imshow("vis_im",vis_im)
if cv2.waitKey(500) == 27:
break
if __name__ == "__main__":
img_size = 512
model_path = "./model_exp/2021-02-23_22-03-22/fp_latest.pth"
img_size = 512 # 推理分辨率设置
model_path = "./weights/fp_512.pth" # 模型路径
image_path = "./images/"
inference(img_size = img_size, image_path=image_path, model_path=model_path)
......@@ -29,8 +29,8 @@ def train(fintune_model,image_size,lr0,path_data,model_exp):
max_epoch = 1000
n_classes = 19
n_img_per_gpu = 16
n_workers = 8
cropsize = [int(image_size*0.85),int(image_size*0.85)]
n_workers = 12
cropsize = [int(image_size*0.8),int(image_size*0.8)]
# DataLoader 数据迭代器
ds = FaceMask(path_data,img_size = image_size, cropsize=cropsize, mode='train')
......@@ -139,7 +139,7 @@ def train(fintune_model,image_size,lr0,path_data,model_exp):
torch.save(state, model_exp+'fp_{}_epoch-{}.pth'.format(image_size,epoch))
if __name__ == "__main__":
image_size = 256
image_size = 512
lr0 = 1e-4
model_exp = './model_exp/'
path_data = './CelebAMask-HQ/'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册