Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
richardboo
face_parsing
提交
81f8a49d
face_parsing
项目概览
richardboo
/
face_parsing
与 Fork 源项目一致
Fork自
Eric.Lee2021 / face_parsing
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
face_parsing
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
81f8a49d
编写于
2月 24, 2021
作者:
Eric.Lee2021
🚴🏻
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update
上级
4126c94c
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
11 addition
and
12 deletion
+11
-12
train.py
train.py
+11
-12
未找到文件。
train.py
浏览文件 @
81f8a49d
...
...
@@ -25,14 +25,16 @@ def set_learning_rate(optimizer, lr):
def
train
(
fintune_model
,
image_size
,
lr0
,
path_data
,
model_exp
):
# dataset
# config 训练配置
max_epoch
=
1000
n_classes
=
19
n_img_per_gpu
=
16
n_workers
=
8
cropsize
=
[
int
(
image_size
*
0.85
),
int
(
image_size
*
0.85
)]
# DataLoader 数据迭代器
ds
=
FaceMask
(
path_data
,
img_size
=
image_size
,
cropsize
=
cropsize
,
mode
=
'train'
)
# sampler = torch.utils.data.distributed.DistributedSampler(ds)
dl
=
DataLoader
(
ds
,
batch_size
=
n_img_per_gpu
,
shuffle
=
True
,
...
...
@@ -42,19 +44,19 @@ def train(fintune_model,image_size,lr0,path_data,model_exp):
# model
ignore_idx
=
-
100
# 构建模型
use_cuda
=
torch
.
cuda
.
is_available
()
device
=
torch
.
device
(
"cuda:0"
if
use_cuda
else
"cpu"
)
net
=
BiSeNet
(
n_classes
=
n_classes
)
net
=
net
.
to
(
device
)
# 加载预训练模型
if
os
.
access
(
fintune_model
,
os
.
F_OK
)
and
(
fintune_model
is
not
None
):
# checkpoint
chkpt
=
torch
.
load
(
fintune_model
,
map_location
=
device
)
net
.
load_state_dict
(
chkpt
)
print
(
'load fintune model : {}'
.
format
(
fintune_model
))
else
:
print
(
'no fintune model'
)
# 构建损失函数
score_thres
=
0.7
n_min
=
n_img_per_gpu
*
cropsize
[
0
]
*
cropsize
[
1
]
//
16
LossP
=
OhemCELoss
(
thresh
=
score_thres
,
n_min
=
n_min
,
ignore_lb
=
ignore_idx
)
...
...
@@ -65,15 +67,14 @@ def train(fintune_model,image_size,lr0,path_data,model_exp):
momentum
=
0.9
weight_decay
=
5e-4
lr_start
=
lr0
max_epoch
=
1000
# 构建优化器
optim
=
Optimizer
.
SGD
(
net
.
parameters
(),
lr
=
lr_start
,
momentum
=
momentum
,
weight_decay
=
weight_decay
)
#
#
train loop
# train loop
msg_iter
=
50
loss_avg
=
[]
st
=
glob_st
=
time
.
time
()
...
...
@@ -85,7 +86,7 @@ def train(fintune_model,image_size,lr0,path_data,model_exp):
best_loss
=
np
.
inf
loss_mean
=
0.
# 损失均值
loss_idx
=
0.
# 损失计算计数器
# 训练
print
(
'start training ~'
)
it
=
0
for
epoch
in
range
(
max_epoch
):
...
...
@@ -126,12 +127,10 @@ def train(fintune_model,image_size,lr0,path_data,model_exp):
loss
.
backward
()
optim
.
step
()
if
it
%
msg_iter
==
0
:
print
(
'epoch <{}/{}> -->> <{}/{}> -> iter {} : loss {:.5f}, loss_mean :{:.5f}, best_loss :{:.5f},lr :{:.6f},batch_size : {}'
.
\
format
(
epoch
,
max_epoch
,
i
,
int
(
ds
.
__len__
()
/
n_img_per_gpu
),
it
,
loss
.
item
(),
loss_mean
/
loss_idx
,
best_loss
,
init_lr
,
n_img_per_gpu
))
# print(msg)
if
(
it
)
%
500
==
0
:
state
=
net
.
module
.
state_dict
()
if
hasattr
(
net
,
'module'
)
else
net
.
state_dict
()
...
...
@@ -140,7 +139,7 @@ def train(fintune_model,image_size,lr0,path_data,model_exp):
torch
.
save
(
state
,
model_exp
+
'fp_{}_epoch-{}.pth'
.
format
(
image_size
,
epoch
))
if
__name__
==
"__main__"
:
image_size
=
512
image_size
=
256
lr0
=
1e-4
model_exp
=
'./model_exp/'
path_data
=
'./CelebAMask-HQ/'
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录