提交 1bbcd81f 编写于 作者: T typhoonzero

fix

上级 8c4b45c7
......@@ -33,10 +33,12 @@ def parse_args():
parser.add_argument(
'--model',
type=str,
default='resnet_dist',
default='DistResNet',
help='The model to run.')
parser.add_argument(
'--batch_size', type=int, default=32, help='The minibatch size per device.')
parser.add_argument(
'--multi_batch_repeat', type=int, default=1, help='Batch merge repeats.')
parser.add_argument(
'--learning_rate', type=float, default=0.1, help='The learning rate.')
parser.add_argument(
......@@ -124,7 +126,7 @@ def get_model(args, is_train, main_prog, startup_prog):
name="train_reader" if is_train else "test_reader",
use_double_buffer=True)
input, label = fluid.layers.read_file(pyreader)
model_def = models.__dict__[args.model](is_train)
model_def = models.__dict__[args.model](layers=50, is_train=is_train)
predict = model_def.net(input, class_dim=class_dim)
cost = fluid.layers.cross_entropy(input=predict, label=label)
......
......@@ -3,6 +3,8 @@ from .mobilenet import MobileNet
from .googlenet import GoogleNet
from .vgg import VGG11, VGG13, VGG16, VGG19
from .resnet import ResNet50, ResNet101, ResNet152
from .resnet_dist import DistResNet
from .inception_v4 import InceptionV4
from .se_resnext import SE_ResNeXt50_32x4d, SE_ResNeXt101_32x4d, SE_ResNeXt152_32x4d
from .dpn import DPN68, DPN92, DPN98, DPN107, DPN131
import learning_rate
......@@ -5,7 +5,7 @@ import paddle
import paddle.fluid as fluid
import math
__all__ = ["ResNet", "ResNet50", "ResNet101", "ResNet152"]
__all__ = ["DistResNet"]
train_parameters = {
"input_size": [3, 224, 224],
......@@ -20,7 +20,7 @@ train_parameters = {
}
class ResNet():
class DistResNet():
def __init__(self, layers=50, is_train=True):
self.params = train_parameters
self.layers = layers
......@@ -119,18 +119,3 @@ class ResNet():
short = self.shortcut(input, num_filters * 4, stride)
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def ResNet50():
model = ResNet(layers=50)
return model
def ResNet101():
model = ResNet(layers=101)
return model
def ResNet152():
model = ResNet(layers=152)
return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册