Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
stoneliu1981
pytorch-image-models
提交
1577c529
P
pytorch-image-models
项目概览
stoneliu1981
/
pytorch-image-models
与 Fork 源项目一致
从无法访问的项目Fork
通知
7
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
pytorch-image-models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
1577c529
编写于
3月 01, 2019
作者:
R
Ross Wightman
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Resnext added, changes to bring it and seresnet in line with rest of models
上级
e0cfeb7d
变更
6
隐藏空白更改
内联
并排
Showing
6 changed files
with
277 additions
and
43 deletions
+277
-43
models/model_factory.py
models/model_factory.py
+4
-2
models/resnext.py
models/resnext.py
+194
-0
models/senet.py
models/senet.py
+41
-29
scheduler/cosine_lr.py
scheduler/cosine_lr.py
+18
-3
scheduler/tanh_lr.py
scheduler/tanh_lr.py
+9
-0
train.py
train.py
+11
-9
未找到文件。
models/model_factory.py
浏览文件 @
1577c529
...
...
@@ -10,8 +10,8 @@ from .densenet import densenet161, densenet121, densenet169, densenet201
from
.resnet
import
resnet18
,
resnet34
,
resnet50
,
resnet101
,
resnet152
from
.fbresnet200
import
fbresnet200
from
.dpn
import
dpn68
,
dpn68b
,
dpn92
,
dpn98
,
dpn131
,
dpn107
from
.senet
import
seresnet18
,
seresnet34
,
seresnet50
,
seresnet101
,
seresnet152
,
\
seresnext50_32x4d
,
seresnext101_32x4d
from
.senet
import
seresnet18
,
seresnet34
,
seresnet50
,
seresnet101
,
seresnet152
,
\
seresnext
26_32x4d
,
seresnext
50_32x4d
,
seresnext101_32x4d
from
.resnext
import
resnext50
,
resnext101
,
resnext152
...
...
@@ -112,6 +112,8 @@ def create_model(
model
=
seresnet101
(
num_classes
=
num_classes
,
pretrained
=
pretrained
,
**
kwargs
)
elif
model_name
==
'seresnet152'
:
model
=
seresnet152
(
num_classes
=
num_classes
,
pretrained
=
pretrained
,
**
kwargs
)
elif
model_name
==
'seresnext26_32x4d'
:
model
=
seresnext26_32x4d
(
num_classes
=
num_classes
,
pretrained
=
pretrained
,
**
kwargs
)
elif
model_name
==
'seresnext50_32x4d'
:
model
=
seresnext50_32x4d
(
num_classes
=
num_classes
,
pretrained
=
pretrained
,
**
kwargs
)
elif
model_name
==
'seresnext101_32x4d'
:
...
...
models/resnext.py
0 → 100644
浏览文件 @
1577c529
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
math
import
torch.utils.model_zoo
as
model_zoo
from
models.adaptive_avgmax_pool
import
AdaptiveAvgMaxPool2d
__all__
=
[
'ResNeXt'
,
'resnext50'
,
'resnext101'
,
'resnext152'
]
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
class ResNeXtBottleneckC(nn.Module):
    """ResNeXt bottleneck: 1x1 reduce, grouped 3x3, 1x1 expand, then a residual add.

    Args:
        inplanes: channel count of the incoming tensor.
        planes: base channel count; the block emits ``planes * 4`` channels.
        stride: stride of the grouped 3x3 convolution.
        downsample: optional module applied to the block input to produce the
            residual branch; the raw input is used when it is ``None``.
        cardinality: number of groups in the 3x3 convolution.
        base_width: scales the grouped-conv width, which is
            ``floor(planes / 64 * cardinality * base_width)``.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=32, base_width=4):
        super(ResNeXtBottleneckC, self).__init__()

        group_width = math.floor(planes / 64 * cardinality * base_width)
        out_planes = planes * 4

        # 1x1 reduction into the grouped-conv width.
        self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(group_width)
        # Grouped 3x3; this is the only layer that applies the stride.
        self.conv2 = nn.Conv2d(
            group_width, group_width, kernel_size=3, stride=stride,
            padding=1, bias=False, groups=cardinality)
        self.bn2 = nn.BatchNorm2d(group_width)
        # 1x1 expansion to the block's output width.
        self.conv3 = nn.Conv2d(group_width, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv/BN stages, add the residual branch, apply the final ReLU."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        skip = x if self.downsample is None else self.downsample(x)
        y += skip
        return self.relu(y)
class ResNeXt(nn.Module):
    """ResNeXt classifier: 7x7 stem, max-pool, four residual stages, global pool, linear head.

    Args:
        block: residual block class; must expose an ``expansion`` attribute and
            accept ``(inplanes, planes, stride, downsample, cardinality, base_width)``.
        layers: number of blocks in each of the four stages.
        num_classes: size of the classifier output.
        cardinality: forwarded to every block.
        base_width: forwarded to every block.
        shortcut: ``'C'`` puts a 1x1 conv + BN projection on every residual
            connection, ``'B'`` only where the stride or channel count changes;
            any other value uses average pooling where the shape changes.
        drop_rate: dropout probability applied before the classifier (0 disables it).
        global_pool: pooling type passed to ``AdaptiveAvgMaxPool2d``.
    """

    def __init__(self, block, layers, num_classes=1000, cardinality=32, base_width=4, shortcut='C',
                 drop_rate=0., global_pool='avg'):
        # Initialise the Module machinery before assigning any attribute.
        super(ResNeXt, self).__init__()
        self.num_classes = num_classes
        self.inplanes = 64
        self.cardinality = cardinality
        self.base_width = base_width
        self.shortcut = shortcut
        self.drop_rate = drop_rate

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = AdaptiveAvgMaxPool2d(pool_type=global_pool)
        self.num_features = 512 * block.expansion
        self.fc = nn.Linear(self.num_features, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std = sqrt(2 / (kernel area * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    @staticmethod
    def _projection(in_planes, out_planes, stride):
        """Build a fresh 1x1 conv + BN projection for one residual connection."""
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_planes),
        )

    def _make_layer(self, block, planes, blocks, stride=1):
        """Assemble one stage of ``blocks`` residual blocks and advance ``self.inplanes``.

        Only the first block applies ``stride``; the rest use stride 1.
        """
        out_planes = planes * block.expansion
        reshape = stride != 1 or self.inplanes != out_planes
        use_conv = self.shortcut == 'C' or (self.shortcut == 'B' and reshape)

        downsample = None
        if use_conv:
            downsample = self._projection(self.inplanes, out_planes, stride)
        elif reshape:
            # NOTE(review): AvgPool2d(3) has no padding and keeps the input channel
            # count, so with ResNeXtBottleneckC its output does not appear to match
            # the main branch (padded 3x3 conv, planes * 4 channels). Confirm before
            # relying on this shortcut type.
            downsample = nn.AvgPool2d(3, stride=stride)

        layers = [block(self.inplanes, planes, stride, downsample, self.cardinality, self.base_width)]
        self.inplanes = out_planes

        for _ in range(1, blocks):
            # Build a separate projection per block: reusing one module instance
            # would tie the shortcut weights and BN statistics of all blocks.
            shortcut = None
            if self.shortcut == 'C':
                shortcut = self._projection(self.inplanes, out_planes, 1)
            layers.append(block(self.inplanes, planes, 1, shortcut, self.cardinality, self.base_width))

        return nn.Sequential(*layers)

    def get_classifier(self):
        """Return the classifier head (``None`` after ``reset_classifier(0)``)."""
        return self.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the global pooling layer and the classifier head.

        A falsy ``num_classes`` removes the head; ``forward`` then returns the
        pooled features.
        """
        self.avgpool = AdaptiveAvgMaxPool2d(pool_type=global_pool)
        self.num_classes = num_classes
        del self.fc
        self.fc = nn.Linear(self.num_features, num_classes) if num_classes else None

    def forward_features(self, x, pool=True):
        """Run the stem and the four stages; when ``pool`` is true, also pool and flatten."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        if pool:
            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
        return x

    def forward(self, x):
        """Return class logits, or pooled features when the classifier was removed."""
        x = self.forward_features(x)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        # reset_classifier(0) leaves fc as None; calling it would raise TypeError.
        if self.fc is not None:
            x = self.fc(x)
        return x
def resnext50(cardinality=32, base_width=4, shortcut='C', pretrained=False, **kwargs):
    """Build a ResNeXt-50: ``ResNeXtBottleneckC`` blocks with stage depths 3-4-6-3.

    Args:
        cardinality (int): Cardinality of the aggregated transform
        base_width (int): Base width of the grouped convolution
        shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection
        pretrained: accepted for signature compatibility; not used here.
        **kwargs: forwarded unchanged to the ``ResNeXt`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNeXt(
        ResNeXtBottleneckC, stage_depths,
        cardinality=cardinality, base_width=base_width, shortcut=shortcut, **kwargs)
def resnext101(cardinality=32, base_width=4, shortcut='C', pretrained=False, **kwargs):
    """Build a ResNeXt-101: ``ResNeXtBottleneckC`` blocks with stage depths 3-4-23-3.

    Args:
        cardinality (int): Cardinality of the aggregated transform
        base_width (int): Base width of the grouped convolution
        shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection
        pretrained: accepted for signature compatibility; not used here.
        **kwargs: forwarded unchanged to the ``ResNeXt`` constructor.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNeXt(
        ResNeXtBottleneckC, stage_depths,
        cardinality=cardinality, base_width=base_width, shortcut=shortcut, **kwargs)
def resnext152(cardinality=32, base_width=4, shortcut='C', pretrained=False, **kwargs):
    """Build a ResNeXt-152: ``ResNeXtBottleneckC`` blocks with stage depths 3-8-36-3.

    Args:
        cardinality (int): Cardinality of the aggregated transform
        base_width (int): Base width of the grouped convolution
        shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection
        pretrained: accepted for signature compatibility; not used here.
        **kwargs: forwarded unchanged to the ``ResNeXt`` constructor.
    """
    stage_depths = [3, 8, 36, 3]
    return ResNeXt(
        ResNeXtBottleneckC, stage_depths,
        cardinality=cardinality, base_width=base_width, shortcut=shortcut, **kwargs)
models/senet.py
浏览文件 @
1577c529
...
...
@@ -7,7 +7,9 @@ from collections import OrderedDict
import
math
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.utils
import
model_zoo
from
models.adaptive_avgmax_pool
import
AdaptiveAvgMaxPool2d
__all__
=
[
'SENet'
,
'senet154'
,
'seresnet50'
,
'seresnet101'
,
'seresnet152'
,
'seresnext50_32x4d'
,
'seresnext101_32x4d'
]
...
...
@@ -193,9 +195,9 @@ class SEResNetBlock(nn.Module):
class
SENet
(
nn
.
Module
):
def
__init__
(
self
,
block
,
layers
,
groups
,
reduction
,
drop
out_p
=
0.2
,
def
__init__
(
self
,
block
,
layers
,
groups
,
reduction
,
drop
_rate
=
0.2
,
inchans
=
3
,
inplanes
=
128
,
input_3x3
=
True
,
downsample_kernel_size
=
3
,
downsample_padding
=
1
,
num_classes
=
1000
):
downsample_padding
=
1
,
num_classes
=
1000
,
global_pool
=
'avg'
):
"""
Parameters
----------
...
...
@@ -304,8 +306,8 @@ class SENet(nn.Module):
downsample_kernel_size
=
downsample_kernel_size
,
downsample_padding
=
downsample_padding
)
self
.
avg_pool
=
nn
.
AdaptiveAvgPool2d
(
1
)
self
.
drop
out
=
nn
.
Dropout
(
dropout_p
)
if
dropout_p
is
not
None
else
Non
e
self
.
avg_pool
=
AdaptiveAvgMaxPool2d
(
pool_type
=
global_pool
)
self
.
drop
_rate
=
drop_rat
e
self
.
num_features
=
512
*
block
.
expansion
self
.
last_linear
=
nn
.
Linear
(
self
.
num_features
,
num_classes
)
...
...
@@ -354,8 +356,8 @@ class SENet(nn.Module):
return
x
def
logits
(
self
,
x
):
if
self
.
drop
out
is
not
None
:
x
=
self
.
dropout
(
x
)
if
self
.
drop
_rate
>
0.
:
x
=
F
.
dropout
(
x
,
p
=
self
.
drop_rate
,
training
=
self
.
training
)
x
=
self
.
last_linear
(
x
)
return
x
...
...
@@ -375,79 +377,89 @@ def _load_pretrained(model, url, inchans=3):
model
.
load_state_dict
(
state_dict
)
def
senet154
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
):
def
senet154
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
,
**
kwargs
):
model
=
SENet
(
SEBottleneck
,
[
3
,
8
,
36
,
3
],
groups
=
64
,
reduction
=
16
,
dropout_p
=
0.2
,
num_classes
=
num_classe
s
)
num_classes
=
num_classes
,
**
kwarg
s
)
if
pretrained
:
_load_pretrained
(
model
,
model_urls
[
'senet154'
],
inchans
)
return
model
def
seresnet18
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
):
def
seresnet18
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
,
**
kwargs
):
model
=
SENet
(
SEResNetBlock
,
[
2
,
2
,
2
,
2
],
groups
=
1
,
reduction
=
16
,
dropout_p
=
None
,
inplanes
=
64
,
input_3x3
=
False
,
inplanes
=
64
,
input_3x3
=
False
,
downsample_kernel_size
=
1
,
downsample_padding
=
0
,
num_classes
=
num_classes
)
num_classes
=
num_classes
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
model
,
model_urls
[
'seresnet18'
],
inchans
)
return
model
def
seresnet34
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
):
def
seresnet34
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
,
**
kwargs
):
model
=
SENet
(
SEResNetBlock
,
[
3
,
4
,
6
,
3
],
groups
=
1
,
reduction
=
16
,
dropout_p
=
None
,
inplanes
=
64
,
input_3x3
=
False
,
inplanes
=
64
,
input_3x3
=
False
,
downsample_kernel_size
=
1
,
downsample_padding
=
0
,
num_classes
=
num_classes
)
num_classes
=
num_classes
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
model
,
model_urls
[
'seresnet34'
],
inchans
)
return
model
def
seresnet50
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
):
def
seresnet50
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
,
**
kwargs
):
model
=
SENet
(
SEResNetBottleneck
,
[
3
,
4
,
6
,
3
],
groups
=
1
,
reduction
=
16
,
dropout_p
=
None
,
inplanes
=
64
,
input_3x3
=
False
,
inplanes
=
64
,
input_3x3
=
False
,
downsample_kernel_size
=
1
,
downsample_padding
=
0
,
num_classes
=
num_classes
)
num_classes
=
num_classes
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
model
,
model_urls
[
'seresnet50'
],
inchans
)
return
model
def
seresnet101
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
):
def
seresnet101
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
,
**
kwargs
):
model
=
SENet
(
SEResNetBottleneck
,
[
3
,
4
,
23
,
3
],
groups
=
1
,
reduction
=
16
,
dropout_p
=
None
,
inplanes
=
64
,
input_3x3
=
False
,
inplanes
=
64
,
input_3x3
=
False
,
downsample_kernel_size
=
1
,
downsample_padding
=
0
,
num_classes
=
num_classes
)
num_classes
=
num_classes
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
model
,
model_urls
[
'seresnet101'
],
inchans
)
return
model
def
seresnet152
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
):
def
seresnet152
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
,
**
kwargs
):
model
=
SENet
(
SEResNetBottleneck
,
[
3
,
8
,
36
,
3
],
groups
=
1
,
reduction
=
16
,
dropout_p
=
None
,
inplanes
=
64
,
input_3x3
=
False
,
inplanes
=
64
,
input_3x3
=
False
,
downsample_kernel_size
=
1
,
downsample_padding
=
0
,
num_classes
=
num_classes
)
num_classes
=
num_classes
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
model
,
model_urls
[
'seresnet152'
],
inchans
)
return
model
def
seresnext50_32x4d
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
):
def seresnext26_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
    """Build an SE-ResNeXt-26 (32 groups): ``SEResNeXtBottleneck`` with stage depths 2-2-2-2.

    Args:
        num_classes: size of the classifier output.
        inchans: number of input channels; forwarded both to ``SENet`` and to
            the pretrained-weight loader so the two agree.
        pretrained: any truthy value (the default is ``'imagenet'``) triggers
            ``_load_pretrained``.
        **kwargs: forwarded unchanged to the ``SENet`` constructor.
    """
    model = SENet(
        SEResNeXtBottleneck, [2, 2, 2, 2], groups=32, reduction=16,
        inchans=inchans, inplanes=64, input_3x3=False,
        downsample_kernel_size=1, downsample_padding=0,
        num_classes=num_classes, **kwargs)
    if pretrained:
        # NOTE(review): this key is spelled 'se_resnext…', unlike the sibling
        # factories ('seresnext50_32x4d', 'seresnext101_32x4d'). Verify that
        # model_urls defines it; with the truthy default this lookup always runs.
        _load_pretrained(model, model_urls['se_resnext26_32x4d'], inchans)
    return model
def
seresnext50_32x4d
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
,
**
kwargs
):
model
=
SENet
(
SEResNeXtBottleneck
,
[
3
,
4
,
6
,
3
],
groups
=
32
,
reduction
=
16
,
dropout_p
=
None
,
inplanes
=
64
,
input_3x3
=
False
,
inplanes
=
64
,
input_3x3
=
False
,
downsample_kernel_size
=
1
,
downsample_padding
=
0
,
num_classes
=
num_classes
)
num_classes
=
num_classes
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
model
,
model_urls
[
'seresnext50_32x4d'
],
inchans
)
return
model
def
seresnext101_32x4d
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
):
def
seresnext101_32x4d
(
num_classes
=
1000
,
inchans
=
3
,
pretrained
=
'imagenet'
,
**
kwargs
):
model
=
SENet
(
SEResNeXtBottleneck
,
[
3
,
4
,
23
,
3
],
groups
=
32
,
reduction
=
16
,
dropout_p
=
None
,
inplanes
=
64
,
input_3x3
=
False
,
inplanes
=
64
,
input_3x3
=
False
,
downsample_kernel_size
=
1
,
downsample_padding
=
0
,
num_classes
=
num_classes
)
num_classes
=
num_classes
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
model
,
model_urls
[
'seresnext101_32x4d'
],
inchans
)
return
model
scheduler/cosine_lr.py
浏览文件 @
1577c529
...
...
@@ -24,6 +24,7 @@ class CosineLRScheduler(Scheduler):
warmup_t
=
0
,
warmup_lr_init
=
0
,
warmup_prefix
=
False
,
cycle_limit
=
0
,
t_in_epochs
=
True
,
initialize
=
True
)
->
None
:
super
().
__init__
(
optimizer
,
param_group_field
=
"lr"
,
initialize
=
initialize
)
...
...
@@ -37,6 +38,7 @@ class CosineLRScheduler(Scheduler):
self
.
t_mul
=
t_mul
self
.
lr_min
=
lr_min
self
.
decay_rate
=
decay_rate
self
.
cycle_limit
=
cycle_limit
self
.
warmup_t
=
warmup_t
self
.
warmup_lr_init
=
warmup_lr_init
self
.
warmup_prefix
=
warmup_prefix
...
...
@@ -67,9 +69,13 @@ class CosineLRScheduler(Scheduler):
lr_min
=
self
.
lr_min
*
gamma
lr_max_values
=
[
v
*
gamma
for
v
in
self
.
base_values
]
lrs
=
[
lr_min
+
0.5
*
(
lr_max
-
lr_min
)
*
(
1
+
math
.
cos
(
math
.
pi
*
t_curr
/
t_i
))
for
lr_max
in
lr_max_values
]
if
self
.
cycle_limit
==
0
or
(
self
.
cycle_limit
>
0
and
i
<
self
.
cycle_limit
):
lrs
=
[
lr_min
+
0.5
*
(
lr_max
-
lr_min
)
*
(
1
+
math
.
cos
(
math
.
pi
*
t_curr
/
t_i
))
for
lr_max
in
lr_max_values
]
else
:
lrs
=
[
self
.
lr_min
for
_
in
self
.
base_values
]
return
lrs
def
get_epoch_values
(
self
,
epoch
:
int
):
...
...
@@ -83,3 +89,12 @@ class CosineLRScheduler(Scheduler):
return
self
.
_get_lr
(
num_updates
)
else
:
return
None
def get_cycle_length(self, cycles=0):
    """Return the combined length of the first ``cycles`` cycles.

    A falsy ``cycles`` falls back to ``self.cycle_limit``; the resulting count
    must be positive. With ``t_mul == 1.0`` the result is ``t_initial * cycles``;
    otherwise it is the floored closed-form sum of
    ``t_initial * t_mul ** i`` for ``i`` in ``range(cycles)``.
    """
    n_cycles = cycles if cycles else self.cycle_limit
    assert n_cycles > 0
    if self.t_mul != 1.0:
        total = -self.t_initial * (self.t_mul ** n_cycles - 1) / (1 - self.t_mul)
        return int(math.floor(total))
    return self.t_initial * n_cycles
scheduler/tanh_lr.py
浏览文件 @
1577c529
...
...
@@ -97,3 +97,12 @@ class TanhLRScheduler(Scheduler):
return
self
.
_get_lr
(
num_updates
)
else
:
return
None
def get_cycle_length(self, cycles=0):
    """Return the combined length of the first ``cycles`` cycles.

    A falsy ``cycles`` falls back to ``self.cycle_limit``; the resulting count
    must be positive. With ``t_mul == 1.0`` the result is ``t_initial * cycles``;
    otherwise it is the floored closed-form sum of
    ``t_initial * t_mul ** i`` for ``i`` in ``range(cycles)``.
    """
    n_cycles = cycles if cycles else self.cycle_limit
    assert n_cycles > 0
    if self.t_mul != 1.0:
        total = -self.t_initial * (self.t_mul ** n_cycles - 1) / (1 - self.t_mul)
        return int(math.floor(total))
    return self.t_initial * n_cycles
train.py
浏览文件 @
1577c529
...
...
@@ -91,7 +91,6 @@ def main():
output_dir
=
get_outdir
(
output_base
,
'train'
,
exp_name
)
batch_size
=
args
.
batch_size
num_epochs
=
args
.
epochs
torch
.
manual_seed
(
args
.
seed
)
dataset_train
=
Dataset
(
...
...
@@ -155,9 +154,7 @@ def main():
else
:
model
.
cuda
()
train_loss_fn
=
validate_loss_fn
=
torch
.
nn
.
CrossEntropyLoss
()
train_loss_fn
=
train_loss_fn
.
cuda
()
validate_loss_fn
=
validate_loss_fn
.
cuda
()
train_loss_fn
=
validate_loss_fn
=
torch
.
nn
.
CrossEntropyLoss
().
cuda
()
if
args
.
opt
.
lower
()
==
'sgd'
:
optimizer
=
optim
.
SGD
(
...
...
@@ -183,34 +180,39 @@ def main():
#if optimizer_state is not None:
# optimizer.load_state_dict(optimizer_state)
num_epochs
=
args
.
epochs
if
args
.
sched
==
'cosine'
:
lr_scheduler
=
scheduler
.
CosineLRScheduler
(
optimizer
,
t_initial
=
130
,
t_mul
=
1.
0
,
lr_min
=
0
,
t_initial
=
args
.
epochs
,
t_mul
=
1.
5
,
lr_min
=
1e-5
,
decay_rate
=
args
.
decay_rate
,
warmup_lr_init
=
1e-4
,
warmup_t
=
3
,
cycle_limit
=
3
,
t_in_epochs
=
True
,
)
num_epochs
=
lr_scheduler
.
get_cycle_length
()
+
10
elif
args
.
sched
==
'tanh'
:
lr_scheduler
=
scheduler
.
TanhLRScheduler
(
optimizer
,
t_initial
=
130
,
t_initial
=
args
.
epochs
,
t_mul
=
1.0
,
lr_min
=
1e-
6
,
lr_min
=
1e-
5
,
warmup_lr_init
=
.
001
,
warmup_t
=
3
,
cycle_limit
=
1
,
t_in_epochs
=
True
,
)
num_epochs
=
lr_scheduler
.
get_cycle_length
()
+
10
else
:
lr_scheduler
=
scheduler
.
StepLRScheduler
(
optimizer
,
decay_t
=
args
.
decay_epochs
,
decay_rate
=
args
.
decay_rate
,
)
print
(
num_epochs
)
saver
=
CheckpointSaver
(
checkpoint_dir
=
output_dir
)
best_loss
=
None
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录