Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
__zzh__
classification
提交
fbbb98d2
classification
项目概览
__zzh__
/
classification
与 Fork 源项目一致
Fork自
DataBall / classification
通知
12
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
classification
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
fbbb98d2
编写于
4月 23, 2021
作者:
DataBall
🚴🏻
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
correct random seed bug
上级
356e8010
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
12 addition
and
9 deletion
+12
-9
train.py
train.py
+12
-9
未找到文件。
train.py
浏览文件 @
fbbb98d2
...
@@ -21,7 +21,7 @@ import cv2
...
@@ -21,7 +21,7 @@ import cv2
import
time
import
time
import
json
import
json
from
datetime
import
datetime
from
datetime
import
datetime
import
random
def
tester
(
ops
,
epoch
,
model
,
criterion
,
def
tester
(
ops
,
epoch
,
model
,
criterion
,
train_split
,
train_split_label
,
val_split
,
val_split_label
,
train_split
,
train_split_label
,
val_split
,
val_split_label
,
use_cuda
):
use_cuda
):
...
@@ -190,7 +190,7 @@ def trainer(ops,f_log):
...
@@ -190,7 +190,7 @@ def trainer(ops,f_log):
else
:
else
:
flag_change_lr_cnt
+=
1
flag_change_lr_cnt
+=
1
if
flag_change_lr_cnt
>
5
:
if
flag_change_lr_cnt
>
10
:
init_lr
=
init_lr
*
ops
.
lr_decay
init_lr
=
init_lr
*
ops
.
lr_decay
set_learning_rate
(
optimizer
,
init_lr
)
set_learning_rate
(
optimizer
,
init_lr
)
flag_change_lr_cnt
=
0
flag_change_lr_cnt
=
0
...
@@ -226,7 +226,8 @@ def trainer(ops,f_log):
...
@@ -226,7 +226,8 @@ def trainer(ops,f_log):
step
+=
1
step
+=
1
# 一个 epoch 保存连词最新的 模型
# 一个 epoch 保存连词最新的 模型
if
i
%
(
int
(
dataset
.
__len__
()
/
ops
.
batch_size
/
2
-
1
))
==
0
and
i
>
0
:
# if i%(int(dataset.__len__()/ops.batch_size/2-1)) == 0 and i > 0:
if
i
%
(
1000
)
==
0
and
i
>
0
:
torch
.
save
(
model_
.
state_dict
(),
ops
.
model_exp
+
'latest.pth'
)
torch
.
save
(
model_
.
state_dict
(),
ops
.
model_exp
+
'latest.pth'
)
# 每间隔 5 个 epoch 进行模型保存
# 每间隔 5 个 epoch 进行模型保存
if
(
epoch
%
5
)
==
0
and
(
epoch
>
9
):
if
(
epoch
%
5
)
==
0
and
(
epoch
>
9
):
...
@@ -248,6 +249,8 @@ def trainer(ops,f_log):
...
@@ -248,6 +249,8 @@ def trainer(ops,f_log):
json
.
dump
(
epochs_loss_dict
,
f_loss
,
ensure_ascii
=
False
,
indent
=
1
,
cls
=
JSON_Encoder
)
json
.
dump
(
epochs_loss_dict
,
f_loss
,
ensure_ascii
=
False
,
indent
=
1
,
cls
=
JSON_Encoder
)
f_loss
.
close
()
f_loss
.
close
()
set_seed
(
random
.
randint
(
0
,
65535
))
except
Exception
as
e
:
except
Exception
as
e
:
print
(
'Exception : '
,
e
)
# 打印异常
print
(
'Exception : '
,
e
)
# 打印异常
print
(
'Exception file : '
,
e
.
__traceback__
.
tb_frame
.
f_globals
[
'__file__'
])
# 发生异常所在的文件
print
(
'Exception file : '
,
e
.
__traceback__
.
tb_frame
.
f_globals
[
'__file__'
])
# 发生异常所在的文件
...
@@ -260,16 +263,16 @@ if __name__ == "__main__":
...
@@ -260,16 +263,16 @@ if __name__ == "__main__":
help
=
'seed'
)
# 设置随机种子
help
=
'seed'
)
# 设置随机种子
parser
.
add_argument
(
'--model_exp'
,
type
=
str
,
default
=
'./model_exp'
,
parser
.
add_argument
(
'--model_exp'
,
type
=
str
,
default
=
'./model_exp'
,
help
=
'model_exp'
)
# 模型输出文件夹
help
=
'model_exp'
)
# 模型输出文件夹
parser
.
add_argument
(
'--model'
,
type
=
str
,
default
=
'resnet_
34
'
,
parser
.
add_argument
(
'--model'
,
type
=
str
,
default
=
'resnet_
50
'
,
help
=
'model : resnet_18,resnet_34,resnet_50,resnet_101,resnet_152'
)
# 模型类型
help
=
'model : resnet_18,resnet_34,resnet_50,resnet_101,resnet_152'
)
# 模型类型
'''
'''
注意以下3个参数与具体分类任务数据集,息息相关
注意以下3个参数与具体分类任务数据集,息息相关
'''
'''
#---------------------------------------------------------------------------------
#---------------------------------------------------------------------------------
parser
.
add_argument
(
'--train_path'
,
type
=
str
,
default
=
'./
handpose_x_gesture_v1
/'
,
parser
.
add_argument
(
'--train_path'
,
type
=
str
,
default
=
'./
animals10
/'
,
help
=
'train_path'
)
# 训练集路径
help
=
'train_path'
)
# 训练集路径
parser
.
add_argument
(
'--num_classes'
,
type
=
int
,
default
=
1
4
,
parser
.
add_argument
(
'--num_classes'
,
type
=
int
,
default
=
1
0
,
help
=
'num_classes'
)
# 分类类别个数,gesture 配置为 14 , Stanford Dogs 配置为 120
help
=
'num_classes'
)
# 分类类别个数,gesture 配置为 14 , Stanford Dogs 配置为 120
parser
.
add_argument
(
'--have_label_file'
,
type
=
bool
,
default
=
False
,
parser
.
add_argument
(
'--have_label_file'
,
type
=
bool
,
default
=
False
,
help
=
'have_label_file'
)
# 是否有配套的标注文件解析才能生成分类样本,gesture 配置为 False , Stanford Dogs 配置为 True
help
=
'have_label_file'
)
# 是否有配套的标注文件解析才能生成分类样本,gesture 配置为 False , Stanford Dogs 配置为 True
...
@@ -293,15 +296,15 @@ if __name__ == "__main__":
...
@@ -293,15 +296,15 @@ if __name__ == "__main__":
help
=
'learningRate_decay'
)
# 学习率权重衰减率
help
=
'learningRate_decay'
)
# 学习率权重衰减率
parser
.
add_argument
(
'--weight_decay'
,
type
=
float
,
default
=
1e-6
,
parser
.
add_argument
(
'--weight_decay'
,
type
=
float
,
default
=
1e-6
,
help
=
'weight_decay'
)
# 优化器正则损失权重
help
=
'weight_decay'
)
# 优化器正则损失权重
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
32
,
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
48
,
help
=
'batch_size'
)
# 训练每批次图像数量
help
=
'batch_size'
)
# 训练每批次图像数量
parser
.
add_argument
(
'--dropout'
,
type
=
float
,
default
=
0.5
,
parser
.
add_argument
(
'--dropout'
,
type
=
float
,
default
=
0.5
,
help
=
'dropout'
)
# dropout
help
=
'dropout'
)
# dropout
parser
.
add_argument
(
'--epochs'
,
type
=
int
,
default
=
1000
,
parser
.
add_argument
(
'--epochs'
,
type
=
int
,
default
=
1000
,
help
=
'epochs'
)
# 训练周期
help
=
'epochs'
)
# 训练周期
parser
.
add_argument
(
'--num_workers'
,
type
=
int
,
default
=
1
,
parser
.
add_argument
(
'--num_workers'
,
type
=
int
,
default
=
6
,
help
=
'num_workers'
)
# 训练数据生成器线程数
help
=
'num_workers'
)
# 训练数据生成器线程数
parser
.
add_argument
(
'--img_size'
,
type
=
tuple
,
default
=
(
192
,
192
),
parser
.
add_argument
(
'--img_size'
,
type
=
tuple
,
default
=
(
256
,
256
),
help
=
'img_size'
)
# 输入模型图片尺寸
help
=
'img_size'
)
# 输入模型图片尺寸
parser
.
add_argument
(
'--flag_agu'
,
type
=
bool
,
default
=
True
,
parser
.
add_argument
(
'--flag_agu'
,
type
=
bool
,
default
=
True
,
help
=
'data_augmentation'
)
# 训练数据生成器是否进行数据扩增
help
=
'data_augmentation'
)
# 训练数据生成器是否进行数据扩增
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录