magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 14241786
Authored Jun 19, 2020 by mindspore-ci-bot; committed via Gitee on Jun 19, 2020
!1663 change lenet and alexnet dir
Merge pull request !1663 from wukesong/change_network_path
Parents: b7b4333d, 7dfd3699
Showing 17 changed files with 148 additions and 70 deletions (+148, -70)
model_zoo/alexnet/README.md              +0   -0
model_zoo/alexnet/alexnet.py             +2   -3
model_zoo/alexnet/config.py              +0   -0
model_zoo/alexnet/dataset.py             +1   -1
model_zoo/alexnet/eval.py                +5   -6
model_zoo/alexnet/generator_lr.py        +0   -0
model_zoo/alexnet/train.py               +5   -5
model_zoo/lenet/README.md                +0   -0
model_zoo/lenet/config.py                +0   -0
model_zoo/lenet/dataset.py               +0   -0
model_zoo/lenet/eval.py                  +1   -1
model_zoo/lenet/lenet.py                 +2   -3
model_zoo/lenet/train.py                 +2   -2
tests/perf_test/lenet.py                 +78  -0
tests/perf_test/test_lenet.py            +2   -2
tests/st/networks/models/lenet.py        +0   -46
tests/st/networks/test_gpu_lenet.py      +50  -1
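The common thread in the script diffs below is that the training and evaluation entry points stop importing the networks from the packaged mindspore.model_zoo and instead import the copy that now sits next to them. A minimal sketch of the resulting import, assuming the script is run from its new model_zoo/lenet directory:

# Before this commit (see the removed lines in the diffs below):
#     from mindspore.model_zoo.lenet import LeNet5
# After the move, the script imports the file shipped alongside it:
from lenet import LeNet5  # resolves to model_zoo/lenet/lenet.py

network = LeNet5(num_class=10)  # constructor keyword as used in the diffs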
example/alexnet_cifar10/README.md → model_zoo/alexnet/README.md
File moved
mindspore/model_zoo/alexnet.py → model_zoo/alexnet/alexnet.py
...
@@ -36,10 +36,9 @@ class AlexNet(nn.Cell):
     """
     Alexnet
     """
-    def __init__(self, num_classes=10):
+    def __init__(self, num_classes=10, channel=3):
         super(AlexNet, self).__init__()
-        self.batch_size = 32
-        self.conv1 = conv(3, 96, 11, stride=4)
+        self.conv1 = conv(channel, 96, 11, stride=4)
         self.conv2 = conv(96, 256, 5, pad_mode="same")
         self.conv3 = conv(256, 384, 3, pad_mode="same")
         self.conv4 = conv(384, 384, 3, pad_mode="same")
...
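With the input channel count now a constructor argument (default 3, the value previously hard-coded), callers can adapt the network without editing the file. A short usage sketch, not taken from the repository:

from alexnet import AlexNet  # model_zoo/alexnet/alexnet.py after the move

net_rgb = AlexNet(num_classes=10)              # channel defaults to 3, as in the new signature
net_gray = AlexNet(num_classes=10, channel=1)  # hypothetical single-channel variant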
example/alexnet_cifar10/config.py → model_zoo/alexnet/config.py
File moved
example/alexnet_cifar10/dataset.py → model_zoo/alexnet/dataset.py
...
@@ -23,7 +23,7 @@ import mindspore.dataset.transforms.vision.c_transforms as CV
 from mindspore.common import dtype as mstype
 
 
-def create_dataset(data_path, batch_size=32, repeat_size=1, status="train"):
+def create_dataset_mnist(data_path, batch_size=32, repeat_size=1, status="train"):
     """
     create dataset for train or test
     """
...
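The rename keeps the parameter list intact, so call sites only swap the function name. A hedged example using the defaults from the signature above ("/YourDataPath" is the placeholder path used elsewhere in these scripts):

from dataset import create_dataset_mnist  # model_zoo/alexnet/dataset.py after the move

ds_train = create_dataset_mnist("/YourDataPath", batch_size=32, repeat_size=1, status="train")
ds_eval = create_dataset_mnist("/YourDataPath", batch_size=32, status="test")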
example/alexnet_cifar10/eval.py → model_zoo/alexnet/eval.py
...
@@ -20,10 +20,10 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
 import argparse
 from config import alexnet_cfg as cfg
-from dataset import create_dataset
+from dataset import create_dataset_mnist
+from alexnet import AlexNet
 import mindspore.nn as nn
 from mindspore import context
-from mindspore.model_zoo.alexnet import AlexNet
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.train import Model
 from mindspore.nn.metrics import Accuracy
...
@@ -50,9 +50,8 @@ if __name__ == "__main__":
     print("============== Starting Testing ==============")
     param_dict = load_checkpoint(args.ckpt_path)
     load_param_into_net(network, param_dict)
-    ds_eval = create_dataset(args.data_path,
+    ds_eval = create_dataset_mnist(args.data_path,
                              cfg.batch_size,
-                             1,
-                             "test")
+                             status="test")
     acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
     print("============== Accuracy:{} ==============".format(acc))
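For context, the evaluation flow around this hunk loads a checkpoint into the network and calls Model.eval on the renamed dataset helper. The sketch below assembles those pieces end to end; the network construction and loss function are assumptions, since they sit outside the hunk:

import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from alexnet import AlexNet
from dataset import create_dataset_mnist

network = AlexNet(num_classes=10)                     # assumed class count
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)  # assumed loss; not shown in the hunk
model = Model(network, loss, metrics={"Accuracy": Accuracy()})

param_dict = load_checkpoint("Your.ckpt")             # placeholder name from the hunk header
load_param_into_net(network, param_dict)
ds_eval = create_dataset_mnist("/YourDataPath", batch_size=32, status="test")
print("Accuracy:", model.eval(ds_eval, dataset_sink_mode=False))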
example/alexnet_cifar10/generator_lr.py → model_zoo/alexnet/generator_lr.py
File moved
example/alexnet_cifar10/train.py → model_zoo/alexnet/train.py
...
@@ -20,14 +20,14 @@ python train.py --data_path /YourDataPath
 import argparse
 from config import alexnet_cfg as cfg
-from dataset import create_dataset
+from dataset import create_dataset_mnist
 from generator_lr import get_lr
+from alexnet import AlexNet
 import mindspore.nn as nn
 from mindspore import context
 from mindspore import Tensor
 from mindspore.train import Model
 from mindspore.nn.metrics import Accuracy
-from mindspore.model_zoo.alexnet import AlexNet
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
...
@@ -50,9 +50,9 @@ if __name__ == "__main__":
     model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})  # test
     print("============== Starting Training ==============")
-    ds_train = create_dataset(args.data_path,
+    ds_train = create_dataset_mnist(args.data_path,
                               cfg.batch_size,
                               cfg.epoch_size)
     time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
     config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
...
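Downstream of this hunk the script feeds the dataset and callbacks into Model.train. A rough, self-contained sketch of how those names fit together; the optimizer, loss, checkpoint prefix, and numeric values here are assumptions standing in for the cfg entries the real script reads:

import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore.nn.optim import Momentum
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from alexnet import AlexNet
from dataset import create_dataset_mnist

network = AlexNet(num_classes=10)                      # assumed class count
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)   # assumed loss
opt = Momentum(network.trainable_params(), 0.01, 0.9)  # assumed optimizer settings
model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})

ds_train = create_dataset_mnist("/YourDataPath", batch_size=32)  # placeholder path
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=1000,         # assumed; the script uses cfg.*
                             keep_checkpoint_max=10)
ckpt_cb = ModelCheckpoint(prefix="checkpoint_alexnet", config=config_ck)  # assumed prefix
model.train(1, ds_train, callbacks=[time_cb, ckpt_cb, LossMonitor()], dataset_sink_mode=True)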
example/lenet_mnist/README.md → model_zoo/lenet/README.md
File moved
example/lenet_mnist/config.py → model_zoo/lenet/config.py
File moved
example/lenet_mnist/dataset.py → model_zoo/lenet/dataset.py
File moved
example/lenet_mnist/eval.py → model_zoo/lenet/eval.py
...
@@ -22,8 +22,8 @@ import os
 import argparse
 from dataset import create_dataset
 from config import mnist_cfg as cfg
+from lenet import LeNet5
 import mindspore.nn as nn
-from mindspore.model_zoo.lenet import LeNet5
 from mindspore import context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
...
mindspore/model_zoo/lenet.py → model_zoo/lenet/lenet.py
...
@@ -50,11 +50,10 @@ class LeNet5(nn.Cell):
         >>> LeNet(num_class=10)
     """
-    def __init__(self, num_class=10):
+    def __init__(self, num_class=10, channel=1):
         super(LeNet5, self).__init__()
         self.num_class = num_class
-        self.batch_size = 32
-        self.conv1 = conv(1, 6, 5)
+        self.conv1 = conv(channel, 6, 5)
         self.conv2 = conv(6, 16, 5)
         self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
         self.fc2 = fc_with_initialize(120, 84)
...
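As with AlexNet above, the hard-coded input channel (1, for MNIST) becomes a constructor argument whose default preserves the old behaviour. A small usage sketch, not taken from the repository:

from lenet import LeNet5  # model_zoo/lenet/lenet.py after the move

net_mnist = LeNet5(num_class=10)           # channel defaults to 1, matching the old code
net_rgb = LeNet5(num_class=10, channel=3)  # hypothetical three-channel variant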
example/lenet_mnist/train.py → model_zoo/lenet/train.py
...
@@ -22,8 +22,8 @@ import os
 import argparse
 from config import mnist_cfg as cfg
 from dataset import create_dataset
+from lenet import LeNet5
 import mindspore.nn as nn
-from mindspore.model_zoo.lenet import LeNet5
 from mindspore import context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.train import Model
...
@@ -36,7 +36,7 @@ if __name__ == "__main__":
                         help='device where the code will be implemented (default: Ascend)')
     parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                         help='path where the dataset is saved')
-    parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
+    parser.add_argument('--dataset_sink_mode', type=bool, default=True, help='dataset_sink_mode is False or True')
     args = parser.parse_args()
...
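The only behavioural change in this script is the flipped default for --dataset_sink_mode. A minimal sketch of just the two arguments the hunk touches, to show the effect of the new default (the real script defines more options):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                    help='path where the dataset is saved')
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
                    help='dataset_sink_mode is False or True')

args = parser.parse_args([])   # no CLI flags: defaults apply
print(args.dataset_sink_mode)  # True -> data sinking is now enabled unless overridden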
tests/perf_test/lenet.py (new file, mode 0 → 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LeNet."""
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """weight initial for conv layer"""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """weight initial for fc layer"""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """weight initial"""
    return TruncatedNormal(0.02)


class LeNet5(nn.Cell):
    """
    Lenet network
    Args:
        num_class (int): Num classes. Default: 10.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet(num_class=10)
    """
    def __init__(self, num_class=10, channel=1):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
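The copy above keeps the usual 32x32 LeNet geometry (16*5*5 features feeding fc1), so a quick forward-pass sanity check could look like the following; the batch size and input resolution are assumptions, not part of the file:

import numpy as np
from mindspore import Tensor
from lenet import LeNet5  # the test-local copy shown above

net = LeNet5(num_class=10, channel=1)
x = Tensor(np.random.randn(32, 1, 32, 32).astype(np.float32))  # assumed 32x32 single-channel inputs
y = net(x)  # forward pass; one row of 10 logits per image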
tests/perf_test/test_lenet.py
...
@@ -17,12 +17,12 @@
 import numpy as np
+from lenet import LeNet5
 import mindspore.nn as nn
 import mindspore.ops.composite as C
 from mindspore import Tensor
 from mindspore import context
 from mindspore.common.api import _executor
-from mindspore.model_zoo.lenet import LeNet
 
 context.set_context(mode=context.GRAPH_MODE)
...
@@ -61,7 +61,7 @@ def test_compile():
 def test_compile_grad():
     """Compile forward and backward graph"""
-    net = LeNet(num_class=num_class)
+    net = LeNet5(num_class=num_class)
     inp = Tensor(np.array(np.random.randn(batch_size,
                                           channel,
                                           height,
...
tests/st/networks/models/lenet.py (deleted, mode 100644 → 0)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P


class LeNet(nn.Cell):
    def __init__(self):
        super(LeNet, self).__init__()
        self.relu = P.ReLU()
        self.batch_size = 32
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.fc1 = nn.Dense(400, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)

    def construct(self, input_x):
        output = self.conv1(input_x)
        output = self.relu(output)
        output = self.pool(output)
        output = self.conv2(output)
        output = self.relu(output)
        output = self.pool(output)
        output = self.reshape(output, (self.batch_size, -1))
        output = self.fc1(output)
        output = self.relu(output)
        output = self.fc2(output)
        output = self.relu(output)
        output = self.fc3(output)
        return output
tests/st/networks/test_gpu_lenet.py
...
@@ -26,17 +26,66 @@ import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.dataset.transforms.vision import Inter
-from mindspore.model_zoo.lenet import LeNet5
 from mindspore.nn import Dense, TrainOneStepCell, WithLossCell
 from mindspore.nn.metrics import Accuracy
 from mindspore.nn.optim import Momentum
 from mindspore.ops import operations as P
 from mindspore.train import Model
 from mindspore.train.callback import LossMonitor
+from mindspore.common.initializer import TruncatedNormal
 
 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
 
+
+def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
+    """weight initial for conv layer"""
+    weight = weight_variable()
+    return nn.Conv2d(in_channels, out_channels,
+                     kernel_size=kernel_size, stride=stride, padding=padding,
+                     weight_init=weight, has_bias=False, pad_mode="valid")
+
+
+def fc_with_initialize(input_channels, out_channels):
+    """weight initial for fc layer"""
+    weight = weight_variable()
+    bias = weight_variable()
+    return nn.Dense(input_channels, out_channels, weight, bias)
+
+
+def weight_variable():
+    """weight initial"""
+    return TruncatedNormal(0.02)
+
+
+class LeNet5(nn.Cell):
+    def __init__(self, num_class=10, channel=1):
+        super(LeNet5, self).__init__()
+        self.num_class = num_class
+        self.conv1 = conv(channel, 6, 5)
+        self.conv2 = conv(6, 16, 5)
+        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
+        self.fc2 = fc_with_initialize(120, 84)
+        self.fc3 = fc_with_initialize(84, self.num_class)
+        self.relu = nn.ReLU()
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.flatten = nn.Flatten()
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.conv2(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.relu(x)
+        x = self.fc3(x)
+        return x
+
+
 class LeNet(nn.Cell):
     def __init__(self):
         super(LeNet, self).__init__()
...
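The imports this test keeps (TrainOneStepCell, WithLossCell, Momentum) are the pieces normally wrapped around the now-local LeNet5 to build a hand-rolled training step. A hedged sketch of that wiring; the loss choice and hyper-parameters are assumptions:

import mindspore.nn as nn
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from lenet import LeNet5  # stand-in for the identical class the test now defines locally

net = LeNet5(num_class=10, channel=1)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)                      # assumed loss
opt = Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)  # assumed settings

train_net = TrainOneStepCell(WithLossCell(net, loss), opt)
train_net.set_train()
# each call train_net(images, labels) then runs one forward/backward/update step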