提交 b36094e3 编写于 作者: C caojian05

remove the parameter batch_size of VGG16, for we can use flatten instead of reshape.

上级 ebd0fd33
......@@ -39,7 +39,7 @@ if __name__ == '__main__':
context.set_context(device_id=args_opt.device_id)
context.set_context(enable_mem_reuse=True, enable_hccl=False)
net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes)
net = vgg16(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
......
......@@ -64,7 +64,7 @@ if __name__ == '__main__':
context.set_context(device_id=args_opt.device_id)
context.set_context(enable_mem_reuse=True, enable_hccl=False)
net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes)
net = vgg16(num_classes=cfg.num_classes)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=50000 // cfg.batch_size)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
......
......@@ -14,7 +14,6 @@
# ============================================================================
"""VGG."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
import mindspore.common.dtype as mstype
......@@ -63,8 +62,7 @@ class Vgg(nn.Cell):
def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1):
super(Vgg, self).__init__()
self.layers = _make_layer(base, batch_norm=batch_norm)
self.reshape = P.Reshape()
self.shp = (batch_size, -1)
self.flatten = nn.Flatten()
self.classifier = nn.SequentialCell([
nn.Dense(512 * 7 * 7, 4096),
nn.ReLU(),
......@@ -74,7 +72,7 @@ class Vgg(nn.Cell):
def construct(self, x):
x = self.layers(x)
x = self.reshape(x, self.shp)
x = self.flatten(x)
x = self.classifier(x)
return x
......@@ -87,20 +85,19 @@ cfg = {
}
def vgg16(batch_size=1, num_classes=1000):
def vgg16(num_classes=1000):
"""
Get Vgg16 neural network with batch normalization.
Args:
batch_size (int): Batch size. Default: 1.
num_classes (int): Class numbers. Default: 1000.
Returns:
Cell, cell instance of Vgg16 neural network with batch normalization.
Examples:
>>> vgg16(batch_size=1, num_classes=1000)
>>> vgg16(num_classes=1000)
"""
net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True, batch_size=batch_size)
net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True)
return net
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册