Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
36aeefcf
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
36aeefcf
编写于
9月 17, 2021
作者:
C
cuicheng01
提交者:
GitHub
9月 17, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1064 from TingquanGao/dev/Support_enable_cutmix_mixup
support to enable mixup and cutmix at same time
上级
94433634
b578662b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
76 addition
and
8 deletion
+76
-8
ppcls/data/preprocess/__init__.py
ppcls/data/preprocess/__init__.py
+3
-1
ppcls/data/preprocess/batch_ops/batch_operators.py
ppcls/data/preprocess/batch_ops/batch_operators.py
+73
-7
未找到文件。
ppcls/data/preprocess/__init__.py
浏览文件 @
36aeefcf
...
...
@@ -29,7 +29,7 @@ from ppcls.data.preprocess.ops.operators import NormalizeImage
from
ppcls.data.preprocess.ops.operators
import
ToCHWImage
from
ppcls.data.preprocess.ops.operators
import
AugMix
from
ppcls.data.preprocess.batch_ops.batch_operators
import
MixupOperator
,
CutmixOperator
,
FmixOperator
from
ppcls.data.preprocess.batch_ops.batch_operators
import
MixupOperator
,
CutmixOperator
,
OpSampler
,
FmixOperator
import
six
import
numpy
as
np
...
...
@@ -45,6 +45,7 @@ def transform(data, ops=[]):
class
AutoAugment
(
RawImageNetPolicy
):
""" ImageNetPolicy wrapper to auto fit different img types """
def
__init__
(
self
,
*
args
,
**
kwargs
):
if
six
.
PY2
:
super
(
AutoAugment
,
self
).
__init__
(
*
args
,
**
kwargs
)
...
...
@@ -69,6 +70,7 @@ class AutoAugment(RawImageNetPolicy):
class
RandAugment
(
RawRandAugment
):
""" RandAugment wrapper to auto fit different img types """
def
__init__
(
self
,
*
args
,
**
kwargs
):
if
six
.
PY2
:
super
(
RandAugment
,
self
).
__init__
(
*
args
,
**
kwargs
)
...
...
ppcls/data/preprocess/batch_ops/batch_operators.py
浏览文件 @
36aeefcf
...
...
@@ -16,13 +16,17 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
import
random
import
numpy
as
np
from
ppcls.utils
import
logger
from
ppcls.data.preprocess.ops.fmix
import
sample_mask
class
BatchOperator
(
object
):
""" BatchOperator """
def
__init__
(
self
,
*
args
,
**
kwargs
):
pass
...
...
@@ -46,9 +50,20 @@ class BatchOperator(object):
class
MixupOperator
(
BatchOperator
):
""" Mixup operator """
def
__init__
(
self
,
alpha
=
0.2
):
assert
alpha
>
0.
,
\
'parameter alpha[%f] should > 0.0'
%
(
alpha
)
def
__init__
(
self
,
alpha
:
float
=
1.
):
"""Build Mixup operator
Args:
alpha (float, optional): The parameter alpha of mixup. Defaults to 1..
Raises:
Exception: The value of parameter is illegal.
"""
if
alpha
<=
0
:
raise
Exception
(
f
"Parameter
\"
alpha
\"
of Mixup should be greater than 0.
\"
alpha
\"
:
{
alpha
}
."
)
self
.
_alpha
=
alpha
def
__call__
(
self
,
batch
):
...
...
@@ -62,9 +77,20 @@ class MixupOperator(BatchOperator):
class
CutmixOperator
(
BatchOperator
):
""" Cutmix operator """
def
__init__
(
self
,
alpha
=
0.2
):
assert
alpha
>
0.
,
\
'parameter alpha[%f] should > 0.0'
%
(
alpha
)
"""Build Cutmix operator
Args:
alpha (float, optional): The parameter alpha of cutmix. Defaults to 0.2.
Raises:
Exception: The value of parameter is illegal.
"""
if
alpha
<=
0
:
raise
Exception
(
f
"Parameter
\"
alpha
\"
of Cutmix should be greater than 0.
\"
alpha
\"
:
{
alpha
}
."
)
self
.
_alpha
=
alpha
def
_rand_bbox
(
self
,
size
,
lam
):
...
...
@@ -72,8 +98,8 @@ class CutmixOperator(BatchOperator):
w
=
size
[
2
]
h
=
size
[
3
]
cut_rat
=
np
.
sqrt
(
1.
-
lam
)
cut_w
=
np
.
int
(
w
*
cut_rat
)
cut_h
=
np
.
int
(
h
*
cut_rat
)
cut_w
=
int
(
w
*
cut_rat
)
cut_h
=
int
(
h
*
cut_rat
)
# uniform
cx
=
np
.
random
.
randint
(
w
)
...
...
@@ -101,6 +127,7 @@ class CutmixOperator(BatchOperator):
class
FmixOperator
(
BatchOperator
):
""" Fmix operator """
def
__init__
(
self
,
alpha
=
1
,
decay_power
=
3
,
max_soft
=
0.
,
reformulate
=
False
):
self
.
_alpha
=
alpha
self
.
_decay_power
=
decay_power
...
...
@@ -115,3 +142,42 @@ class FmixOperator(BatchOperator):
size
,
self
.
_max_soft
,
self
.
_reformulate
)
imgs
=
mask
*
imgs
+
(
1
-
mask
)
*
imgs
[
idx
]
return
list
(
zip
(
imgs
,
labels
,
labels
[
idx
],
[
lam
]
*
bs
))
class
OpSampler
(
object
):
""" Sample a operator from """
def
__init__
(
self
,
**
op_dict
):
"""Build OpSampler
Raises:
Exception: The parameter
\"
prob
\"
of operator(s) are be set error.
"""
if
len
(
op_dict
)
<
1
:
msg
=
f
"ConfigWarning: No operator in
\"
OpSampler
\"
.
\"
OpSampler
\"
has been skipped."
self
.
ops
=
{}
total_prob
=
0
for
op_name
in
op_dict
:
param
=
op_dict
[
op_name
]
if
"prob"
not
in
param
:
msg
=
f
"ConfigWarning: Parameter
\"
prob
\"
should be set when use operator in
\"
OpSampler
\"
. The operator
\"
{
op_name
}
\"
's prob has been set
\"
0
\"
."
logger
.
warning
(
msg
)
prob
=
param
.
pop
(
"prob"
,
0
)
total_prob
+=
prob
op
=
eval
(
op_name
)(
**
param
)
self
.
ops
.
update
({
op
:
prob
})
if
total_prob
>
1
:
msg
=
f
"ConfigError: The total prob of operators in
\"
OpSampler
\"
should be less 1."
logger
.
error
(
msg
)
raise
Exception
(
msg
)
# add "None Op" when total_prob < 1, "None Op" do nothing
self
.
ops
[
None
]
=
1
-
total_prob
def
__call__
(
self
,
batch
):
op
=
random
.
choices
(
list
(
self
.
ops
.
keys
()),
weights
=
list
(
self
.
ops
.
values
()),
k
=
1
)[
0
]
# return batch directly when None Op
return
op
(
batch
)
if
op
else
batch
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录