Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
d5adfa52
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d5adfa52
编写于
4月 29, 2020
作者:
C
chujinjin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add accuracy for resnet cifar
上级
3dd369ce
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
198 addition
and
0 deletion
+198
-0
tests/st/tbe_networks/test_resnet_cifar_1p.py
tests/st/tbe_networks/test_resnet_cifar_1p.py
+198
-0
未找到文件。
tests/st/tbe_networks/test_resnet_cifar_1p.py
0 → 100644
浏览文件 @
d5adfa52
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
pytest
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
from
mindspore.nn.optim.momentum
import
Momentum
from
mindspore.train.model
import
Model
from
mindspore
import
context
import
mindspore.common.dtype
as
mstype
import
os
import
numpy
as
np
import
mindspore.ops.functional
as
F
from
mindspore.train.callback
import
ModelCheckpoint
,
CheckpointConfig
,
Callback
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.c_transforms
as
C
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
from
resnet
import
resnet50
import
random
import
time
random
.
seed
(
1
)
np
.
random
.
seed
(
1
)
ds
.
config
.
set_seed
(
1
)
data_home
=
"/home/workspace/mindspore_dataset"
def
create_dataset
(
repeat_num
=
1
,
training
=
True
,
batch_size
=
32
):
data_dir
=
data_home
+
"/cifar-10-batches-bin"
if
not
training
:
data_dir
=
data_home
+
"/cifar-10-verify-bin"
data_set
=
ds
.
Cifar10Dataset
(
data_dir
)
resize_height
=
224
resize_width
=
224
rescale
=
1.0
/
255.0
shift
=
0.0
# define map operations
random_crop_op
=
vision
.
RandomCrop
(
(
32
,
32
),
(
4
,
4
,
4
,
4
))
# padding_mode default CONSTANT
random_horizontal_op
=
vision
.
RandomHorizontalFlip
()
# interpolation default BILINEAR
resize_op
=
vision
.
Resize
((
resize_height
,
resize_width
))
rescale_op
=
vision
.
Rescale
(
rescale
,
shift
)
normalize_op
=
vision
.
Normalize
(
(
0.4465
,
0.4822
,
0.4914
),
(
0.2010
,
0.1994
,
0.2023
))
changeswap_op
=
vision
.
HWC2CHW
()
type_cast_op
=
C
.
TypeCast
(
mstype
.
int32
)
c_trans
=
[]
if
training
:
c_trans
=
[
random_crop_op
,
random_horizontal_op
]
c_trans
+=
[
resize_op
,
rescale_op
,
normalize_op
,
changeswap_op
]
# apply map operations on images
data_set
=
data_set
.
map
(
input_columns
=
"label"
,
operations
=
type_cast_op
)
data_set
=
data_set
.
map
(
input_columns
=
"image"
,
operations
=
c_trans
)
# apply shuffle operations
data_set
=
data_set
.
shuffle
(
buffer_size
=
1000
)
# apply batch operations
data_set
=
data_set
.
batch
(
batch_size
=
batch_size
,
drop_remainder
=
True
)
# apply repeat operations
data_set
=
data_set
.
repeat
(
repeat_num
)
return
data_set
class
CrossEntropyLoss
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
CrossEntropyLoss
,
self
).
__init__
()
self
.
cross_entropy
=
P
.
SoftmaxCrossEntropyWithLogits
()
self
.
mean
=
P
.
ReduceMean
()
self
.
one_hot
=
P
.
OneHot
()
self
.
one
=
Tensor
(
1.0
,
mstype
.
float32
)
self
.
zero
=
Tensor
(
0.0
,
mstype
.
float32
)
def
construct
(
self
,
logits
,
label
):
label
=
self
.
one_hot
(
label
,
F
.
shape
(
logits
)[
1
],
self
.
one
,
self
.
zero
)
loss
=
self
.
cross_entropy
(
logits
,
label
)[
0
]
loss
=
self
.
mean
(
loss
,
(
-
1
,))
return
loss
class
LossGet
(
Callback
):
def
__init__
(
self
,
per_print_times
=
1
):
super
(
LossGet
,
self
).
__init__
()
if
not
isinstance
(
per_print_times
,
int
)
or
per_print_times
<
0
:
raise
ValueError
(
"print_step must be int and >= 0."
)
self
.
_per_print_times
=
per_print_times
self
.
_loss
=
0.0
def
step_end
(
self
,
run_context
):
cb_params
=
run_context
.
original_args
()
loss
=
cb_params
.
net_outputs
if
isinstance
(
loss
,
(
tuple
,
list
)):
if
isinstance
(
loss
[
0
],
Tensor
)
and
isinstance
(
loss
[
0
].
asnumpy
(),
np
.
ndarray
):
loss
=
loss
[
0
]
if
isinstance
(
loss
,
Tensor
)
and
isinstance
(
loss
.
asnumpy
(),
np
.
ndarray
):
loss
=
np
.
mean
(
loss
.
asnumpy
())
cur_step_in_epoch
=
(
cb_params
.
cur_step_num
-
1
)
%
cb_params
.
batch_num
+
1
if
isinstance
(
loss
,
float
)
and
(
np
.
isnan
(
loss
)
or
np
.
isinf
(
loss
)):
raise
ValueError
(
"epoch: {} step: {}. Invalid loss, terminating training."
.
format
(
cb_params
.
cur_epoch_num
,
cur_step_in_epoch
))
if
self
.
_per_print_times
!=
0
and
cb_params
.
cur_step_num
%
self
.
_per_print_times
==
0
:
self
.
_loss
=
loss
print
(
"epoch: %s step: %s, loss is %s"
%
(
cb_params
.
cur_epoch_num
,
cur_step_in_epoch
,
loss
))
def
get_loss
(
self
):
return
self
.
_loss
def
train_process
(
device_id
,
epoch_size
,
num_classes
,
device_num
,
batch_size
):
os
.
system
(
"mkdir "
+
str
(
device_id
))
os
.
chdir
(
str
(
device_id
))
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
enable_task_sink
=
True
,
device_id
=
device_id
)
context
.
set_context
(
enable_loop_sink
=
True
)
context
.
set_context
(
enable_mem_reuse
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
net
=
resnet50
(
batch_size
,
num_classes
)
loss
=
CrossEntropyLoss
()
opt
=
Momentum
(
filter
(
lambda
x
:
x
.
requires_grad
,
net
.
get_parameters
()),
0.01
,
0.9
)
model
=
Model
(
net
,
loss_fn
=
loss
,
optimizer
=
opt
,
metrics
=
{
'acc'
})
dataset
=
create_dataset
(
epoch_size
,
training
=
True
,
batch_size
=
batch_size
)
batch_num
=
dataset
.
get_dataset_size
()
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
batch_num
,
keep_checkpoint_max
=
1
)
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
"train_resnet_cifar10_device_id_"
+
str
(
device_id
),
directory
=
"./"
,
config
=
config_ck
)
loss_cb
=
LossGet
()
model
.
train
(
epoch_size
,
dataset
,
callbacks
=
[
ckpoint_cb
,
loss_cb
])
def
eval
(
batch_size
,
num_classes
):
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
enable_task_sink
=
True
,
device_id
=
0
)
context
.
set_context
(
enable_loop_sink
=
True
)
context
.
set_context
(
enable_mem_reuse
=
True
)
net
=
resnet50
(
batch_size
,
num_classes
)
loss
=
CrossEntropyLoss
()
opt
=
Momentum
(
filter
(
lambda
x
:
x
.
requires_grad
,
net
.
get_parameters
()),
0.01
,
0.9
)
model
=
Model
(
net
,
loss_fn
=
loss
,
optimizer
=
opt
,
metrics
=
{
'acc'
})
checkpoint_path
=
"./train_resnet_cifar10_device_id_0-1_1562.ckpt"
param_dict
=
load_checkpoint
(
checkpoint_path
)
load_param_into_net
(
net
,
param_dict
)
net
.
set_train
(
False
)
eval_dataset
=
create_dataset
(
1
,
training
=
False
)
res
=
model
.
eval
(
eval_dataset
)
print
(
"result: "
,
res
)
return
res
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
def
test_resnet_cifar_1p
():
device_num
=
1
epoch_size
=
1
num_classes
=
10
batch_size
=
32
device_id
=
0
train_process
(
device_id
,
epoch_size
,
num_classes
,
device_num
,
batch_size
)
time
.
sleep
(
3
)
acc
=
eval
(
batch_size
,
num_classes
)
os
.
chdir
(
"../"
)
os
.
system
(
"rm -rf "
+
str
(
device_id
))
print
(
"End training..."
)
assert
(
acc
[
'acc'
]
>
0.35
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录