stoneliu1981 / pytorch-image-models / Commit 3eb4a96e
Commit 3eb4a96e, authored Jan 11, 2020 by Ross Wightman
Update AugMix, JSD, etc comments and references
Parent: 833066b5

Showing 5 changed files with 52 additions and 8 deletions (+52 / -8)
timm/data/auto_augment.py        +21 / -4
timm/data/transforms_factory.py  +10 / -1
timm/loss/jsd.py                 +5  / -0
timm/models/split_batchnorm.py   +13 / -1
train.py                         +3  / -2
timm/data/auto_augment.py

```python
""" AutoAugment, RandAugment, and AugMix for PyTorch

This code implements the searched ImageNet policies with various tweaks and improvements and
does not include any of the search code.

AA and RA Implementation adapted from:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py

AugMix adapted from:
    https://github.com/google-research/augmix

Papers:
    AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
    Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
    RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
    AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781

Hacked together by Ross Wightman
"""
```
...
...
@@ -691,12 +703,17 @@ def augmix_ops(magnitude=10, hparams=None, transforms=None):

```python
class AugMixAugment:
    """ AugMix Transform
    Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
    https://arxiv.org/abs/1912.02781
    """
    def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
        self.ops = ops
        self.alpha = alpha
        self.width = width
        self.depth = depth
        self.blended = blended  # blended mode is faster but not well tested

    def _calc_blended_weights(self, ws, m):
        ws = ws * m
```
...
...
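As context for the `AugMixAugment` class above, here is a minimal NumPy sketch of the AugMix mixing scheme the docstring references: `width` augmentation chains are mixed with Dirichlet-sampled weights, then blended back into the original image with a Beta-sampled weight. The function name, the `ops` list of callables, and the chain-depth choice are illustrative assumptions, not timm's actual API:

```python
import numpy as np

def augmix_sketch(img, ops, alpha=1.0, width=3, depth=-1, rng=None):
    """Hypothetical sketch of AugMix mixing (paper: arxiv.org/abs/1912.02781).

    img: float numpy array; ops: list of callables taking/returning an array.
    """
    rng = rng or np.random.default_rng(0)
    ws = rng.dirichlet([alpha] * width)  # per-chain mixing weights
    m = rng.beta(alpha, alpha)           # original-vs-mixed blend weight
    mixed = np.zeros_like(img, dtype=np.float32)
    for w in ws:
        # each chain applies a random sequence of ops; depth<=0 means sample 1..3
        d = depth if depth > 0 else int(rng.integers(1, 4))
        img_aug = img.copy()
        for op in rng.choice(ops, d):
            img_aug = op(img_aug)
        mixed += w * img_aug.astype(np.float32)
    return (1 - m) * img.astype(np.float32) + m * mixed
```

With identity ops the output equals the input, since the Dirichlet weights sum to 1 and the Beta blend interpolates between two copies of the same image.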
timm/data/transforms_factory.py

```python
""" Transforms Factory
Factory methods for building image transforms for use with TIMM (PyTorch Image Models)
"""
import math

import torch
```
...
...
@@ -24,7 +27,13 @@ def transforms_imagenet_train(

```python
        re_num_splits=0,
        separate=False,
):
    """
    If separate==True, the transforms are returned as a tuple of 3 separate transforms
    for use in a mixing dataset that passes
     * all data through the first (primary) transform, called the 'clean' data
     * a portion of the data through the secondary transform
     * normalizes and converts the branches above with the third, final transform
    """
    primary_tfl = [
        RandomResizedCropAndInterpolation(img_size, scale=scale, interpolation=interpolation),
```
...
...
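The `separate==True` contract above can be illustrated with a toy dataset wrapper that consumes the 3-tuple of transforms: clean data goes through primary + final, augmented branches additionally pass through the secondary transform. The class name and structure here are illustrative assumptions, not timm's own mixing dataset:

```python
class MixingDatasetSketch:
    """Hypothetical consumer of the (primary, secondary, final) transform
    tuple returned when separate=True. Branch 0 is the 'clean' data."""

    def __init__(self, dataset, primary, secondary, final, num_splits=2):
        self.dataset = dataset
        self.primary, self.secondary, self.final = primary, secondary, final
        self.num_splits = num_splits

    def __getitem__(self, i):
        x, y = self.dataset[i]
        base = self.primary(x)                    # shared crop/flip etc.
        branches = [self.final(base)]             # clean branch: no secondary aug
        for _ in range(self.num_splits - 1):
            branches.append(self.final(self.secondary(base)))
        return tuple(branches), y
```

With simple numeric stand-ins for the transforms, `primary=lambda x: x*2`, `secondary=lambda x: x+1`, `final=lambda x: x-0.5`, an input of `(1.0, 0)` with `num_splits=3` yields branches `(1.5, 2.5, 2.5)`.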
timm/loss/jsd.py
...
...
@@ -8,6 +8,11 @@ from .cross_entropy import LabelSmoothingCrossEntropy

```python
class JsdCrossEntropy(nn.Module):
    """ Jensen-Shannon Divergence + Cross-Entropy Loss
    Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
    https://arxiv.org/abs/1912.02781
    Hacked together by Ross Wightman
    """
    def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
        super().__init__()
```
...
...
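The loss the class computes can be sketched in plain NumPy: cross-entropy on the clean (first) split plus `alpha` times the Jensen-Shannon divergence across all splits' predicted distributions, as in the AugMix paper. This is an illustrative re-derivation, not timm's exact tensor code:

```python
import numpy as np

def jsd_cross_entropy_sketch(logits_splits, target, alpha=12.0):
    """Sketch of CE(clean) + alpha * JS(p_1, ..., p_k).

    logits_splits: list of [N, C] arrays, clean split first.
    target: [N] integer class indices.
    """
    def softmax(z):
        e = np.exp(z - z.max(axis=1, keepdims=True))
        return e / e.sum(axis=1, keepdims=True)

    probs = [softmax(l) for l in logits_splits]
    rows = np.arange(target.shape[0])
    ce = -np.log(probs[0][rows, target]).mean()          # cross-entropy on clean split
    mean_p = np.clip(np.mean(probs, axis=0), 1e-7, 1.0)  # mixture distribution M
    # JS divergence = mean_i KL(p_i || M)
    kls = [np.sum(p * (np.log(np.clip(p, 1e-7, 1.0)) - np.log(mean_p)), axis=1).mean()
           for p in probs]
    return ce + alpha * float(np.mean(kls))
```

When all splits produce identical logits the JS term vanishes and the loss reduces to plain cross-entropy, which is a quick sanity check for any implementation.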
timm/models/split_batchnorm.py

```python
""" Split BatchNorm

A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
a separate BN layer. The first split is passed through the parent BN layers with weight/bias
keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
namespace.

This allows easily removing the auxiliary BN layers after training to efficiently
achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
'Disentangled Learning via An Auxiliary BN'

Hacked together by Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class SplitBatchNorm2d(torch.nn.BatchNorm2d):
```
...
...
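The splitting behavior the docstring describes can be sketched functionally in NumPy: divide the batch into N equal chunks and normalize each with its own per-channel statistics, as if it had passed through an independent BN layer. This is a minimal assumption-laden sketch; the real module subclasses `BatchNorm2d` and keeps learnable affine parameters and running stats per split:

```python
import numpy as np

def split_batchnorm_sketch(x, num_splits=2, eps=1e-5):
    """Sketch of split-BN: normalize each batch chunk independently.

    x: [N, C, H, W] array with N divisible by num_splits.
    """
    out = []
    for chunk in np.split(x, num_splits, axis=0):
        # per-channel statistics computed over this chunk only
        mean = chunk.mean(axis=(0, 2, 3), keepdims=True)
        var = chunk.var(axis=(0, 2, 3), keepdims=True)
        out.append((chunk - mean) / np.sqrt(var + eps))
    return np.concatenate(out, axis=0)
```

Because each chunk is normalized separately, every split ends up with (approximately) zero mean and unit variance per channel, even when the chunks came from very differently distributed augmentation branches.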
train.py
...
...
@@ -237,8 +237,9 @@ def main():

```python
    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
```
...
...
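The argument handling in the hunk above reduces to a small helper: a split count of 1 is rejected outright, and split BN is only valid with more than one split (or with re-splitting enabled). Names are taken from the diff; the standalone function itself is a sketch for illustration:

```python
def resolve_aug_splits(aug_splits, split_bn, resplit=False):
    """Sketch of the aug-split validation logic from train.py."""
    num_aug_splits = 0
    if aug_splits > 0:
        # a split of 1 makes no sense: there would be nothing to contrast
        assert aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = aug_splits
    if split_bn:
        # split BN needs multiple splits unless batches are re-split elsewhere
        assert num_aug_splits > 1 or resplit
    return num_aug_splits
```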