Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindarmour
提交
66df8437
M
mindarmour
项目概览
MindSpore
/
mindarmour
通知
4
Star
2
Fork
3
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindarmour
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
66df8437
编写于
4月 14, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 14, 2020
浏览文件
操作
浏览文件
下载
差异文件
!9 Update mnist_lenet5 example
Merge pull request !9 from pkuliuliu/master
上级
10cc003b
1765a2a6
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
21 addition
and
38 deletion
+21
-38
example/data_processing.py
example/data_processing.py
+1
-1
example/mnist_demo/lenet5_net.py
example/mnist_demo/lenet5_net.py
+4
-5
example/mnist_demo/mnist_train.py
example/mnist_demo/mnist_train.py
+16
-32
未找到文件。
example/data_processing.py
浏览文件 @
66df8437
...
...
@@ -37,10 +37,10 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
rescale_op
=
CV
.
Rescale
(
rescale
,
shift
)
hwc2chw_op
=
CV
.
HWC2CHW
()
type_cast_op
=
C
.
TypeCast
(
mstype
.
int32
)
one_hot_enco
=
C
.
OneHot
(
10
)
# apply map operations on images
if
not
sparse
:
one_hot_enco
=
C
.
OneHot
(
10
)
ds1
=
ds1
.
map
(
input_columns
=
"label"
,
operations
=
one_hot_enco
,
num_parallel_workers
=
num_parallel_workers
)
type_cast_op
=
C
.
TypeCast
(
mstype
.
float32
)
...
...
example/mnist_demo/lenet5_net.py
浏览文件 @
66df8437
...
...
@@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
mindspore.nn
as
nn
import
mindspore.ops.operations
as
P
from
mindspore
import
nn
from
mindspore.common.initializer
import
TruncatedNormal
...
...
@@ -30,7 +29,7 @@ def fc_with_initialize(input_channels, out_channels):
def
weight_variable
():
return
TruncatedNormal
(
0.2
)
return
TruncatedNormal
(
0.
0
2
)
class
LeNet5
(
nn
.
Cell
):
...
...
@@ -46,7 +45,7 @@ class LeNet5(nn.Cell):
self
.
fc3
=
fc_with_initialize
(
84
,
10
)
self
.
relu
=
nn
.
ReLU
()
self
.
max_pool2d
=
nn
.
MaxPool2d
(
kernel_size
=
2
,
stride
=
2
)
self
.
reshape
=
P
.
Reshape
()
self
.
flatten
=
nn
.
Flatten
()
def
construct
(
self
,
x
):
x
=
self
.
conv1
(
x
)
...
...
@@ -55,7 +54,7 @@ class LeNet5(nn.Cell):
x
=
self
.
conv2
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
max_pool2d
(
x
)
x
=
self
.
reshape
(
x
,
(
-
1
,
16
*
5
*
5
)
)
x
=
self
.
flatten
(
x
)
x
=
self
.
fc1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
fc2
(
x
)
...
...
example/mnist_demo/mnist_train.py
浏览文件 @
66df8437
...
...
@@ -20,10 +20,7 @@ from mindspore import context, Tensor
from
mindspore.train.callback
import
ModelCheckpoint
,
CheckpointConfig
,
LossMonitor
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
from
mindspore.train
import
Model
import
mindspore.ops.operations
as
P
from
mindspore.nn.metrics
import
Accuracy
from
mindspore.ops
import
functional
as
F
from
mindspore.common
import
dtype
as
mstype
from
mindarmour.utils.logger
import
LogUtil
...
...
@@ -32,26 +29,7 @@ from lenet5_net import LeNet5
sys
.
path
.
append
(
".."
)
from
data_processing
import
generate_mnist_dataset
LOGGER
=
LogUtil
.
get_instance
()
TAG
=
'Lenet5_train'
class
CrossEntropyLoss
(
nn
.
Cell
):
"""
Define loss for network
"""
def
__init__
(
self
):
super
(
CrossEntropyLoss
,
self
).
__init__
()
self
.
cross_entropy
=
P
.
SoftmaxCrossEntropyWithLogits
()
self
.
mean
=
P
.
ReduceMean
()
self
.
one_hot
=
P
.
OneHot
()
self
.
on_value
=
Tensor
(
1.0
,
mstype
.
float32
)
self
.
off_value
=
Tensor
(
0.0
,
mstype
.
float32
)
def
construct
(
self
,
logits
,
label
):
label
=
self
.
one_hot
(
label
,
F
.
shape
(
logits
)[
1
],
self
.
on_value
,
self
.
off_value
)
loss
=
self
.
cross_entropy
(
logits
,
label
)[
0
]
loss
=
self
.
mean
(
loss
,
(
-
1
,))
return
loss
TAG
=
"Lenet5_train"
def
mnist_train
(
epoch_size
,
batch_size
,
lr
,
momentum
):
...
...
@@ -66,23 +44,29 @@ def mnist_train(epoch_size, batch_size, lr, momentum):
batch_size
=
batch_size
,
repeat_size
=
1
)
network
=
LeNet5
()
net
work
.
set_train
()
net_loss
=
CrossEntropyLoss
(
)
net
_loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
lr
,
momentum
)
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
1875
,
keep_checkpoint_max
=
10
)
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
"checkpoint_lenet"
,
directory
=
'./trained_ckpt_file/'
,
config
=
config_ck
)
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
1875
,
keep_checkpoint_max
=
10
)
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
"checkpoint_lenet"
,
directory
=
"./trained_ckpt_file/"
,
config
=
config_ck
)
model
=
Model
(
network
,
net_loss
,
net_opt
,
metrics
=
{
"Accuracy"
:
Accuracy
()})
LOGGER
.
info
(
TAG
,
"============== Starting Training =============="
)
model
.
train
(
epoch_size
,
ds
,
callbacks
=
[
ckpoint_cb
,
LossMonitor
()],
dataset_sink_mode
=
False
)
# train
model
.
train
(
epoch_size
,
ds
,
callbacks
=
[
ckpoint_cb
,
LossMonitor
()],
dataset_sink_mode
=
False
)
LOGGER
.
info
(
TAG
,
"============== Starting Testing =============="
)
param_dict
=
load_checkpoint
(
"trained_ckpt_file/checkpoint_lenet-10_1875.ckpt"
)
ckpt_file_name
=
"trained_ckpt_file/checkpoint_lenet-10_1875.ckpt"
param_dict
=
load_checkpoint
(
ckpt_file_name
)
load_param_into_net
(
network
,
param_dict
)
ds_eval
=
generate_mnist_dataset
(
os
.
path
.
join
(
mnist_path
,
"test"
),
batch_size
=
batch_size
)
acc
=
model
.
eval
(
ds_eval
)
ds_eval
=
generate_mnist_dataset
(
os
.
path
.
join
(
mnist_path
,
"test"
),
batch_size
=
batch_size
)
acc
=
model
.
eval
(
ds_eval
,
dataset_sink_mode
=
False
)
LOGGER
.
info
(
TAG
,
"============== Accuracy: %s =============="
,
acc
)
if
__name__
==
'__main__'
:
mnist_train
(
10
,
32
,
0.0
0
1
,
0.9
)
mnist_train
(
10
,
32
,
0.01
,
0.9
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录