Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
79666238
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
79666238
编写于
10月 08, 2021
作者:
G
gaotingquan
提交者:
Tingquan Gao
10月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor: adapt to static graph in deprecating MixCELoss
上级
873869dd
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
26 addition
and
17 deletion
+26
-17
ppcls/engine/engine.py
ppcls/engine/engine.py
+1
-0
ppcls/static/program.py
ppcls/static/program.py
+22
-17
ppcls/static/train.py
ppcls/static/train.py
+3
-0
未找到文件。
ppcls/engine/engine.py
浏览文件 @
79666238
...
@@ -112,6 +112,7 @@ class Engine(object):
...
@@ -112,6 +112,7 @@ class Engine(object):
}
}
paddle
.
fluid
.
set_flags
(
AMP_RELATED_FLAGS_SETTING
)
paddle
.
fluid
.
set_flags
(
AMP_RELATED_FLAGS_SETTING
)
#TODO(gaotingquan): support rec
class_num
=
config
[
"Arch"
].
get
(
"class_num"
,
None
)
class_num
=
config
[
"Arch"
].
get
(
"class_num"
,
None
)
self
.
config
[
"DataLoader"
].
update
({
"class_num"
:
class_num
})
self
.
config
[
"DataLoader"
].
update
({
"class_num"
:
class_num
})
# build dataloader
# build dataloader
...
...
ppcls/static/program.py
浏览文件 @
79666238
...
@@ -41,13 +41,14 @@ from ppcls.utils.misc import AverageMeter
...
@@ -41,13 +41,14 @@ from ppcls.utils.misc import AverageMeter
from
ppcls.utils
import
logger
,
profiler
from
ppcls.utils
import
logger
,
profiler
def
create_feeds
(
image_shape
,
use_mix
=
None
,
dtype
=
"float32"
):
def
create_feeds
(
image_shape
,
use_mix
=
False
,
class_num
=
None
,
dtype
=
"float32"
):
"""
"""
Create feeds as model input
Create feeds as model input
Args:
Args:
image_shape(list[int]): model input shape, such as [3, 224, 224]
image_shape(list[int]): model input shape, such as [3, 224, 224]
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
class_num(int): the class number of network, required if use_mix
Returns:
Returns:
feeds(dict): dict of model input variables
feeds(dict): dict of model input variables
...
@@ -55,13 +56,14 @@ def create_feeds(image_shape, use_mix=None, dtype="float32"):
...
@@ -55,13 +56,14 @@ def create_feeds(image_shape, use_mix=None, dtype="float32"):
feeds
=
OrderedDict
()
feeds
=
OrderedDict
()
feeds
[
'data'
]
=
paddle
.
static
.
data
(
feeds
[
'data'
]
=
paddle
.
static
.
data
(
name
=
"data"
,
shape
=
[
None
]
+
image_shape
,
dtype
=
dtype
)
name
=
"data"
,
shape
=
[
None
]
+
image_shape
,
dtype
=
dtype
)
if
use_mix
:
if
use_mix
:
feeds
[
'y_a'
]
=
paddle
.
static
.
data
(
if
class_num
is
None
:
name
=
"y_a"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
)
msg
=
"When use MixUp, CutMix and so on, you must set class_num."
feeds
[
'y_b'
]
=
paddle
.
static
.
data
(
logger
.
error
(
msg
)
name
=
"y_b"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
)
raise
Exception
(
msg
)
feeds
[
'
lam
'
]
=
paddle
.
static
.
data
(
feeds
[
'
target
'
]
=
paddle
.
static
.
data
(
name
=
"
lam"
,
shape
=
[
None
,
1
],
dtype
=
dtype
)
name
=
"
target"
,
shape
=
[
None
,
class_num
],
dtype
=
"float32"
)
else
:
else
:
feeds
[
'label'
]
=
paddle
.
static
.
data
(
feeds
[
'label'
]
=
paddle
.
static
.
data
(
name
=
"label"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
)
name
=
"label"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
)
...
@@ -74,6 +76,7 @@ def create_fetchs(out,
...
@@ -74,6 +76,7 @@ def create_fetchs(out,
architecture
,
architecture
,
topk
=
5
,
topk
=
5
,
epsilon
=
None
,
epsilon
=
None
,
class_num
=
None
,
use_mix
=
False
,
use_mix
=
False
,
config
=
None
,
config
=
None
,
mode
=
"Train"
):
mode
=
"Train"
):
...
@@ -88,6 +91,7 @@ def create_fetchs(out,
...
@@ -88,6 +91,7 @@ def create_fetchs(out,
name(such as ResNet50) is needed
name(such as ResNet50) is needed
topk(int): usually top5
topk(int): usually top5
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
class_num(int): the class number of network, required if use_mix
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
config(dict): model config
config(dict): model config
...
@@ -97,18 +101,16 @@ def create_fetchs(out,
...
@@ -97,18 +101,16 @@ def create_fetchs(out,
fetchs
=
OrderedDict
()
fetchs
=
OrderedDict
()
# build loss
# build loss
if
use_mix
:
if
use_mix
:
y_a
=
paddle
.
reshape
(
feeds
[
'y_a'
],
[
-
1
,
1
])
if
class_num
is
None
:
y_b
=
paddle
.
reshape
(
feeds
[
'y_b'
],
[
-
1
,
1
])
msg
=
"When use MixUp, CutMix and so on, you must set class_num."
lam
=
paddle
.
reshape
(
feeds
[
'lam'
],
[
-
1
,
1
])
logger
.
error
(
msg
)
raise
Exception
(
msg
)
target
=
paddle
.
reshape
(
feeds
[
'target'
],
[
-
1
,
class_num
])
else
:
else
:
target
=
paddle
.
reshape
(
feeds
[
'label'
],
[
-
1
,
1
])
target
=
paddle
.
reshape
(
feeds
[
'label'
],
[
-
1
,
1
])
loss_func
=
build_loss
(
config
[
"Loss"
][
mode
])
loss_func
=
build_loss
(
config
[
"Loss"
][
mode
])
loss_dict
=
loss_func
(
out
,
target
)
if
use_mix
:
loss_dict
=
loss_func
(
out
,
[
y_a
,
y_b
,
lam
])
else
:
loss_dict
=
loss_func
(
out
,
target
)
loss_out
=
loss_dict
[
"loss"
]
loss_out
=
loss_dict
[
"loss"
]
fetchs
[
'loss'
]
=
(
loss_out
,
AverageMeter
(
'loss'
,
'7.4f'
,
need_avg
=
True
))
fetchs
[
'loss'
]
=
(
loss_out
,
AverageMeter
(
'loss'
,
'7.4f'
,
need_avg
=
True
))
...
@@ -218,6 +220,7 @@ def mixed_precision_optimizer(config, optimizer):
...
@@ -218,6 +220,7 @@ def mixed_precision_optimizer(config, optimizer):
def
build
(
config
,
def
build
(
config
,
main_prog
,
main_prog
,
startup_prog
,
startup_prog
,
class_num
=
None
,
step_each_epoch
=
100
,
step_each_epoch
=
100
,
is_train
=
True
,
is_train
=
True
,
is_distributed
=
True
):
is_distributed
=
True
):
...
@@ -233,6 +236,7 @@ def build(config,
...
@@ -233,6 +236,7 @@ def build(config,
config(dict): config
config(dict): config
main_prog(): main program
main_prog(): main program
startup_prog(): startup program
startup_prog(): startup program
class_num(int): the class number of network, required if use_mix
is_train(bool): train or eval
is_train(bool): train or eval
is_distributed(bool): whether to use distributed training method
is_distributed(bool): whether to use distributed training method
...
@@ -245,10 +249,10 @@ def build(config,
...
@@ -245,10 +249,10 @@ def build(config,
mode
=
"Train"
if
is_train
else
"Eval"
mode
=
"Train"
if
is_train
else
"Eval"
use_mix
=
"batch_transform_ops"
in
config
[
"DataLoader"
][
mode
][
use_mix
=
"batch_transform_ops"
in
config
[
"DataLoader"
][
mode
][
"dataset"
]
"dataset"
]
use_dali
=
config
[
"Global"
].
get
(
'use_dali'
,
False
)
feeds
=
create_feeds
(
feeds
=
create_feeds
(
config
[
"Global"
][
"image_shape"
],
config
[
"Global"
][
"image_shape"
],
use_mix
=
use_mix
,
use_mix
,
class_num
=
class_num
,
dtype
=
"float32"
)
dtype
=
"float32"
)
# build model
# build model
...
@@ -264,6 +268,7 @@ def build(config,
...
@@ -264,6 +268,7 @@ def build(config,
feeds
,
feeds
,
config
[
"Arch"
],
config
[
"Arch"
],
epsilon
=
config
.
get
(
'ls_epsilon'
),
epsilon
=
config
.
get
(
'ls_epsilon'
),
class_num
=
class_num
,
use_mix
=
use_mix
,
use_mix
=
use_mix
,
config
=
config
,
config
=
config
,
mode
=
mode
)
mode
=
mode
)
...
...
ppcls/static/train.py
浏览文件 @
79666238
...
@@ -112,6 +112,8 @@ def main(args):
...
@@ -112,6 +112,8 @@ def main(args):
eval_dataloader
=
None
eval_dataloader
=
None
use_dali
=
global_config
.
get
(
'use_dali'
,
False
)
use_dali
=
global_config
.
get
(
'use_dali'
,
False
)
class_num
=
config
[
"Arch"
].
get
(
"class_num"
,
None
)
config
[
"DataLoader"
].
update
({
"class_num"
:
class_num
})
train_dataloader
=
build_dataloader
(
train_dataloader
=
build_dataloader
(
config
[
"DataLoader"
],
"Train"
,
device
=
device
,
use_dali
=
use_dali
)
config
[
"DataLoader"
],
"Train"
,
device
=
device
,
use_dali
=
use_dali
)
if
global_config
[
"eval_during_train"
]:
if
global_config
[
"eval_during_train"
]:
...
@@ -131,6 +133,7 @@ def main(args):
...
@@ -131,6 +133,7 @@ def main(args):
config
,
config
,
train_prog
,
train_prog
,
startup_prog
,
startup_prog
,
class_num
,
step_each_epoch
=
step_each_epoch
,
step_each_epoch
=
step_each_epoch
,
is_train
=
True
,
is_train
=
True
,
is_distributed
=
global_config
.
get
(
"is_distributed"
,
True
))
is_distributed
=
global_config
.
get
(
"is_distributed"
,
True
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录