Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
da9530f7
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
da9530f7
编写于
6月 13, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 13, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2090 resnet quant dataset aug change
Merge pull request !2090 from panfengfeng/resnet_quant_data_aug_change
上级
65351963
690db9a5
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
67 addition
and
6 deletion
+67
-6
example/resnet50_quant/eval.py
example/resnet50_quant/eval.py
+3
-3
example/resnet50_quant/src/dataset.py
example/resnet50_quant/src/dataset.py
+61
-0
example/resnet50_quant/train.py
example/resnet50_quant/train.py
+3
-3
未找到文件。
example/resnet50_quant/eval.py
浏览文件 @
da9530f7
...
...
@@ -17,7 +17,7 @@ eval.
"""
import
os
import
argparse
from
src.dataset
import
create_dataset
from
src.dataset
import
create_dataset
_py
from
src.config
import
config
from
src.crossentropy
import
CrossEntropy
from
src.utils
import
_load_param_into_net
...
...
@@ -49,8 +49,8 @@ if __name__ == '__main__':
loss
=
CrossEntropy
(
smooth_factor
=
config
.
label_smooth_factor
,
num_classes
=
config
.
class_num
)
if
args_opt
.
do_eval
:
dataset
=
create_dataset
(
dataset_path
=
args_opt
.
dataset_path
,
do_train
=
False
,
batch_size
=
config
.
batch_size
,
target
=
target
)
dataset
=
create_dataset
_py
(
dataset_path
=
args_opt
.
dataset_path
,
do_train
=
False
,
batch_size
=
config
.
batch_size
,
target
=
target
)
step_size
=
dataset
.
get_dataset_size
()
if
args_opt
.
checkpoint_path
:
...
...
example/resnet50_quant/src/dataset.py
浏览文件 @
da9530f7
...
...
@@ -20,6 +20,7 @@ import mindspore.common.dtype as mstype
import
mindspore.dataset.engine
as
de
import
mindspore.dataset.transforms.vision.c_transforms
as
C
import
mindspore.dataset.transforms.c_transforms
as
C2
import
mindspore.dataset.transforms.vision.py_transforms
as
P
from
mindspore.communication.management
import
init
,
get_rank
,
get_group_size
def
create_dataset
(
dataset_path
,
do_train
,
repeat_num
=
1
,
batch_size
=
32
,
target
=
"Ascend"
):
...
...
@@ -83,3 +84,63 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
ds
=
ds
.
repeat
(
repeat_num
)
return
ds
def
create_dataset_py
(
dataset_path
,
do_train
,
repeat_num
=
1
,
batch_size
=
32
,
target
=
"Ascend"
):
"""
create a train or eval dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
if
target
==
"Ascend"
:
device_num
=
int
(
os
.
getenv
(
"RANK_SIZE"
))
rank_id
=
int
(
os
.
getenv
(
"RANK_ID"
))
else
:
init
(
"nccl"
)
rank_id
=
get_rank
()
device_num
=
get_group_size
()
if
do_train
:
if
device_num
==
1
:
ds
=
de
.
ImageFolderDatasetV2
(
dataset_path
,
num_parallel_workers
=
8
,
shuffle
=
True
)
else
:
ds
=
de
.
ImageFolderDatasetV2
(
dataset_path
,
num_parallel_workers
=
8
,
shuffle
=
True
,
num_shards
=
device_num
,
shard_id
=
rank_id
)
else
:
ds
=
de
.
ImageFolderDatasetV2
(
dataset_path
,
num_parallel_workers
=
8
,
shuffle
=
False
)
image_size
=
224
# define map operations
decode_op
=
P
.
Decode
()
resize_crop_op
=
P
.
RandomResizedCrop
(
image_size
,
scale
=
(
0.08
,
1.0
),
ratio
=
(
0.75
,
1.333
))
horizontal_flip_op
=
P
.
RandomHorizontalFlip
(
prob
=
0.5
)
resize_op
=
P
.
Resize
(
256
)
center_crop
=
P
.
CenterCrop
(
image_size
)
to_tensor
=
P
.
ToTensor
()
normalize_op
=
P
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
])
# define map operations
if
do_train
:
trans
=
[
decode_op
,
resize_crop_op
,
horizontal_flip_op
,
to_tensor
,
normalize_op
]
else
:
trans
=
[
decode_op
,
resize_op
,
center_crop
,
to_tensor
,
normalize_op
]
compose
=
P
.
ComposeOp
(
trans
)
ds
=
ds
.
map
(
input_columns
=
"image"
,
operations
=
compose
(),
num_parallel_workers
=
8
,
python_multiprocessing
=
True
)
# apply batch operations
ds
=
ds
.
batch
(
batch_size
,
drop_remainder
=
True
)
# apply dataset repeat operation
ds
=
ds
.
repeat
(
repeat_num
)
return
ds
example/resnet50_quant/train.py
浏览文件 @
da9530f7
...
...
@@ -27,7 +27,7 @@ from mindspore.communication.management import init
import
mindspore.nn
as
nn
import
mindspore.common.initializer
as
weight_init
from
models.resnet_quant
import
resnet50_quant
from
src.dataset
import
create_dataset
from
src.dataset
import
create_dataset
_py
from
src.lr_generator
import
get_lr
from
src.config
import
config
from
src.crossentropy
import
CrossEntropy
...
...
@@ -85,8 +85,8 @@ if __name__ == '__main__':
loss
=
CrossEntropy
(
smooth_factor
=
config
.
label_smooth_factor
,
num_classes
=
config
.
class_num
)
if
args_opt
.
do_train
:
dataset
=
create_dataset
(
dataset_path
=
args_opt
.
dataset_path
,
do_train
=
True
,
repeat_num
=
epoch_size
,
batch_size
=
config
.
batch_size
,
target
=
target
)
dataset
=
create_dataset
_py
(
dataset_path
=
args_opt
.
dataset_path
,
do_train
=
True
,
repeat_num
=
epoch_size
,
batch_size
=
config
.
batch_size
,
target
=
target
)
step_size
=
dataset
.
get_dataset_size
()
loss_scale
=
FixedLossScaleManager
(
config
.
loss_scale
,
drop_overflow_update
=
False
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录