diff --git a/README.md b/README.md index ab1622d7df3f36ef0c880048df64524616b5bca8..0c1b8753119c151f4bb43611254185742bbf8fa6 100644 --- a/README.md +++ b/README.md @@ -30,12 +30,16 @@ 下载数据集并解压,然后运行脚本 prepropess_data.py,生成训练用的mask,注意脚本内相关参数配置。 ## 预训练模型 -* [预训练模型下载地址(百度网盘 Password: )]() +* [预训练模型下载地址(百度网盘 Password: ri6m )](https://pan.baidu.com/s/1I5fPAyXDfIh9M5POs80ovg) ## 项目使用方法 -### 模型训练 +### 步骤1:生成训练数据 +* 目前支持样本分辨率 512,256,注意训练、推理脚本也要做相应的分辨率对应设置。 +* 运行脚本:prepropess_data.py (注意脚本内相关参数配置 ) + +### 步骤2:模型训练 * 根目录下运行命令: python train.py (注意脚本内相关参数配置 ) -### 模型推理 +### 步骤3:模型推理 * 根目录下运行命令: python inference.py (注意脚本内相关参数配置 ) diff --git a/inference.py b/inference.py index 302f35bac6fd53e32c83f852198258262d86af22..05980a2163ff5269f5ffe5c4772b273c95086ea8 100644 --- a/inference.py +++ b/inference.py @@ -83,7 +83,7 @@ def inference( img_size, image_path, model_path): img_ = cv2.imread(image_path + f_) img = Image.fromarray(cv2.cvtColor(img_,cv2.COLOR_BGR2RGB)) - image = img.resize((img_size, img_size), Image.BILINEAR) + image = img.resize((img_size, img_size)) img = to_tensor(image) img = torch.unsqueeze(img, 0) img = img.cuda() @@ -93,6 +93,7 @@ def inference( img_size, image_path, model_path): print('<{}> image : '.format(idx),np.unique(parsing_)) parsing_ = cv2.resize(parsing_,(img_.shape[1],img_.shape[0]),interpolation=cv2.INTER_NEAREST) + parsing_ = parsing_.astype(np.uint8) vis_im = vis_parsing_maps(img_, parsing_, 0,0,stride=1) @@ -104,10 +105,11 @@ def inference( img_size, image_path, model_path): cv2.namedWindow("vis_im",0) cv2.imshow("vis_im",vis_im) + if cv2.waitKey(500) == 27: break if __name__ == "__main__": - img_size = 512 - model_path = "./model_exp/2021-02-23_22-03-22/fp_latest.pth" + img_size = 512 # 推理分辨率设置 + model_path = "./weights/fp_512.pth" # 模型路径 image_path = "./images/" inference(img_size = img_size, image_path=image_path, model_path=model_path) diff --git a/train.py b/train.py index a619856321559ce8f199f7bb1c68fa59b32adbf0..e08cecef24cb95793316178e2fa8278916ce75c8 100644 --- a/train.py +++ b/train.py @@ -29,8 +29,8 @@ def train(fintune_model,image_size,lr0,path_data,model_exp): max_epoch = 1000 n_classes = 19 n_img_per_gpu = 16 - n_workers = 8 - cropsize = [int(image_size*0.85),int(image_size*0.85)] + n_workers = 12 + cropsize = [int(image_size*0.8),int(image_size*0.8)] # DataLoader 数据迭代器 ds = FaceMask(path_data,img_size = image_size, cropsize=cropsize, mode='train') @@ -139,7 +139,7 @@ def train(fintune_model,image_size,lr0,path_data,model_exp): torch.save(state, model_exp+'fp_{}_epoch-{}.pth'.format(image_size,epoch)) if __name__ == "__main__": - image_size = 256 + image_size = 512 lr0 = 1e-4 model_exp = './model_exp/' path_data = './CelebAMask-HQ/'