Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
兔爷不爱我
mindspore
提交
b36094e3
M
mindspore
项目概览
兔爷不爱我
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b36094e3
编写于
4月 23, 2020
作者:
C
caojian05
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove the parameter batch_size of VGG16, for we can use flatten instead of reshape.
上级
ebd0fd33
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
7 addition
and
10 deletion
+7
-10
example/vgg16_cifar10/eval.py
example/vgg16_cifar10/eval.py
+1
-1
example/vgg16_cifar10/train.py
example/vgg16_cifar10/train.py
+1
-1
mindspore/model_zoo/vgg.py
mindspore/model_zoo/vgg.py
+5
-8
未找到文件。
example/vgg16_cifar10/eval.py
浏览文件 @
b36094e3
...
...
@@ -39,7 +39,7 @@ if __name__ == '__main__':
context
.
set_context
(
device_id
=
args_opt
.
device_id
)
context
.
set_context
(
enable_mem_reuse
=
True
,
enable_hccl
=
False
)
net
=
vgg16
(
batch_size
=
cfg
.
batch_size
,
num_classes
=
cfg
.
num_classes
)
net
=
vgg16
(
num_classes
=
cfg
.
num_classes
)
opt
=
Momentum
(
filter
(
lambda
x
:
x
.
requires_grad
,
net
.
get_parameters
()),
0.01
,
cfg
.
momentum
,
weight_decay
=
cfg
.
weight_decay
)
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
sparse
=
True
,
reduction
=
'mean'
,
is_grad
=
False
)
...
...
example/vgg16_cifar10/train.py
浏览文件 @
b36094e3
...
...
@@ -64,7 +64,7 @@ if __name__ == '__main__':
context
.
set_context
(
device_id
=
args_opt
.
device_id
)
context
.
set_context
(
enable_mem_reuse
=
True
,
enable_hccl
=
False
)
net
=
vgg16
(
batch_size
=
cfg
.
batch_size
,
num_classes
=
cfg
.
num_classes
)
net
=
vgg16
(
num_classes
=
cfg
.
num_classes
)
lr
=
lr_steps
(
0
,
lr_max
=
cfg
.
lr_init
,
total_epochs
=
cfg
.
epoch_size
,
steps_per_epoch
=
50000
//
cfg
.
batch_size
)
opt
=
Momentum
(
filter
(
lambda
x
:
x
.
requires_grad
,
net
.
get_parameters
()),
Tensor
(
lr
),
cfg
.
momentum
,
weight_decay
=
cfg
.
weight_decay
)
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
sparse
=
True
,
reduction
=
'mean'
,
is_grad
=
False
)
...
...
mindspore/model_zoo/vgg.py
浏览文件 @
b36094e3
...
...
@@ -14,7 +14,6 @@
# ============================================================================
"""VGG."""
import
mindspore.nn
as
nn
from
mindspore.ops
import
operations
as
P
from
mindspore.common.initializer
import
initializer
import
mindspore.common.dtype
as
mstype
...
...
@@ -63,8 +62,7 @@ class Vgg(nn.Cell):
def
__init__
(
self
,
base
,
num_classes
=
1000
,
batch_norm
=
False
,
batch_size
=
1
):
super
(
Vgg
,
self
).
__init__
()
self
.
layers
=
_make_layer
(
base
,
batch_norm
=
batch_norm
)
self
.
reshape
=
P
.
Reshape
()
self
.
shp
=
(
batch_size
,
-
1
)
self
.
flatten
=
nn
.
Flatten
()
self
.
classifier
=
nn
.
SequentialCell
([
nn
.
Dense
(
512
*
7
*
7
,
4096
),
nn
.
ReLU
(),
...
...
@@ -74,7 +72,7 @@ class Vgg(nn.Cell):
def
construct
(
self
,
x
):
x
=
self
.
layers
(
x
)
x
=
self
.
reshape
(
x
,
self
.
shp
)
x
=
self
.
flatten
(
x
)
x
=
self
.
classifier
(
x
)
return
x
...
...
@@ -87,20 +85,19 @@ cfg = {
}
def
vgg16
(
batch_size
=
1
,
num_classes
=
1000
):
def
vgg16
(
num_classes
=
1000
):
"""
Get Vgg16 neural network with batch normalization.
Args:
batch_size (int): Batch size. Default: 1.
num_classes (int): Class numbers. Default: 1000.
Returns:
Cell, cell instance of Vgg16 neural network with batch normalization.
Examples:
>>> vgg16(
batch_size=1,
num_classes=1000)
>>> vgg16(num_classes=1000)
"""
net
=
Vgg
(
cfg
[
'16'
],
num_classes
=
num_classes
,
batch_norm
=
True
,
batch_size
=
batch_size
)
net
=
Vgg
(
cfg
[
'16'
],
num_classes
=
num_classes
,
batch_norm
=
True
)
return
net
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录