Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
DataBall
face_parsing
提交
8abced85
face_parsing
项目概览
DataBall
/
face_parsing
通知
625
Star
32
Fork
15
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
face_parsing
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8abced85
编写于
2月 25, 2021
作者:
DataBall
🚴🏻
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update
上级
8652bb60
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
15 addition
and
9 deletion
+15
-9
README.md
README.md
+7
-3
inference.py
inference.py
+5
-3
train.py
train.py
+3
-3
未找到文件。
README.md
浏览文件 @
8abced85
...
@@ -30,12 +30,16 @@
...
@@ -30,12 +30,16 @@
下载数据集并解压,然后运行脚本 prepropess_data.py,生成训练用的mask,注意脚本内相关参数配置。
下载数据集并解压,然后运行脚本 prepropess_data.py,生成训练用的mask,注意脚本内相关参数配置。
## 预训练模型
## 预训练模型
*
[
预训练模型下载地址(百度网盘 Password:
)
](
)
*
[
预训练模型下载地址(百度网盘 Password:
ri6m )
](
https://pan.baidu.com/s/1I5fPAyXDfIh9M5POs80ovg
)
## 项目使用方法
## 项目使用方法
### 模型训练
### 步骤1:生成训练数据
*
目前支持样本分辨率 512,256,注意训练、推理脚本也要做相应的分辨率对应设置。
*
运行脚本:prepropess_data.py (注意脚本内相关参数配置 )
### 步骤2:模型训练
*
根目录下运行命令: python train.py (注意脚本内相关参数配置 )
*
根目录下运行命令: python train.py (注意脚本内相关参数配置 )
### 模型推理
###
步骤3:
模型推理
*
根目录下运行命令: python inference.py (注意脚本内相关参数配置 )
*
根目录下运行命令: python inference.py (注意脚本内相关参数配置 )
inference.py
浏览文件 @
8abced85
...
@@ -83,7 +83,7 @@ def inference( img_size, image_path, model_path):
...
@@ -83,7 +83,7 @@ def inference( img_size, image_path, model_path):
img_
=
cv2
.
imread
(
image_path
+
f_
)
img_
=
cv2
.
imread
(
image_path
+
f_
)
img
=
Image
.
fromarray
(
cv2
.
cvtColor
(
img_
,
cv2
.
COLOR_BGR2RGB
))
img
=
Image
.
fromarray
(
cv2
.
cvtColor
(
img_
,
cv2
.
COLOR_BGR2RGB
))
image
=
img
.
resize
((
img_size
,
img_size
)
,
Image
.
BILINEAR
)
image
=
img
.
resize
((
img_size
,
img_size
))
img
=
to_tensor
(
image
)
img
=
to_tensor
(
image
)
img
=
torch
.
unsqueeze
(
img
,
0
)
img
=
torch
.
unsqueeze
(
img
,
0
)
img
=
img
.
cuda
()
img
=
img
.
cuda
()
...
@@ -93,6 +93,7 @@ def inference( img_size, image_path, model_path):
...
@@ -93,6 +93,7 @@ def inference( img_size, image_path, model_path):
print
(
'<{}> image : '
.
format
(
idx
),
np
.
unique
(
parsing_
))
print
(
'<{}> image : '
.
format
(
idx
),
np
.
unique
(
parsing_
))
parsing_
=
cv2
.
resize
(
parsing_
,(
img_
.
shape
[
1
],
img_
.
shape
[
0
]),
interpolation
=
cv2
.
INTER_NEAREST
)
parsing_
=
cv2
.
resize
(
parsing_
,(
img_
.
shape
[
1
],
img_
.
shape
[
0
]),
interpolation
=
cv2
.
INTER_NEAREST
)
parsing_
=
parsing_
.
astype
(
np
.
uint8
)
parsing_
=
parsing_
.
astype
(
np
.
uint8
)
vis_im
=
vis_parsing_maps
(
img_
,
parsing_
,
0
,
0
,
stride
=
1
)
vis_im
=
vis_parsing_maps
(
img_
,
parsing_
,
0
,
0
,
stride
=
1
)
...
@@ -104,10 +105,11 @@ def inference( img_size, image_path, model_path):
...
@@ -104,10 +105,11 @@ def inference( img_size, image_path, model_path):
cv2
.
namedWindow
(
"vis_im"
,
0
)
cv2
.
namedWindow
(
"vis_im"
,
0
)
cv2
.
imshow
(
"vis_im"
,
vis_im
)
cv2
.
imshow
(
"vis_im"
,
vis_im
)
if
cv2
.
waitKey
(
500
)
==
27
:
if
cv2
.
waitKey
(
500
)
==
27
:
break
break
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
img_size
=
512
img_size
=
512
# 推理分辨率设置
model_path
=
"./
model_exp/2021-02-23_22-03-22/fp_latest.pth"
model_path
=
"./
weights/fp_512.pth"
# 模型路径
image_path
=
"./images/"
image_path
=
"./images/"
inference
(
img_size
=
img_size
,
image_path
=
image_path
,
model_path
=
model_path
)
inference
(
img_size
=
img_size
,
image_path
=
image_path
,
model_path
=
model_path
)
train.py
浏览文件 @
8abced85
...
@@ -29,8 +29,8 @@ def train(fintune_model,image_size,lr0,path_data,model_exp):
...
@@ -29,8 +29,8 @@ def train(fintune_model,image_size,lr0,path_data,model_exp):
max_epoch
=
1000
max_epoch
=
1000
n_classes
=
19
n_classes
=
19
n_img_per_gpu
=
16
n_img_per_gpu
=
16
n_workers
=
8
n_workers
=
12
cropsize
=
[
int
(
image_size
*
0.8
5
),
int
(
image_size
*
0.85
)]
cropsize
=
[
int
(
image_size
*
0.8
),
int
(
image_size
*
0.8
)]
# DataLoader 数据迭代器
# DataLoader 数据迭代器
ds
=
FaceMask
(
path_data
,
img_size
=
image_size
,
cropsize
=
cropsize
,
mode
=
'train'
)
ds
=
FaceMask
(
path_data
,
img_size
=
image_size
,
cropsize
=
cropsize
,
mode
=
'train'
)
...
@@ -139,7 +139,7 @@ def train(fintune_model,image_size,lr0,path_data,model_exp):
...
@@ -139,7 +139,7 @@ def train(fintune_model,image_size,lr0,path_data,model_exp):
torch
.
save
(
state
,
model_exp
+
'fp_{}_epoch-{}.pth'
.
format
(
image_size
,
epoch
))
torch
.
save
(
state
,
model_exp
+
'fp_{}_epoch-{}.pth'
.
format
(
image_size
,
epoch
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
image_size
=
256
image_size
=
512
lr0
=
1e-4
lr0
=
1e-4
model_exp
=
'./model_exp/'
model_exp
=
'./model_exp/'
path_data
=
'./CelebAMask-HQ/'
path_data
=
'./CelebAMask-HQ/'
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录