Commit 53b23370 authored by LielinJiang

refine api

Parent 6c89d516
@@ -71,14 +71,14 @@ class DatasetFolder(Dataset):
Args:
root (string): Root directory path.
loader (callable, optional): A function to load a sample given its path.
extensions (tuple[string], optional): A list of allowed extensions.
loader (callable|optional): A function to load a sample given its path.
extensions (tuple[str]|optional): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in
transform (callable|optional): A function/transform that takes in
a sample and returns a transformed version.
target_transform (callable, optional): A function/transform that takes
target_transform (callable|optional): A function/transform that takes
in the target and transforms it.
is_valid_file (callable, optional): A function that takes path of a file
is_valid_file (callable|optional): A function that takes path of a file
and checks if the file is a valid file (used to check for corrupt files)
both extensions and is_valid_file should not be passed.
@@ -97,6 +97,8 @@ class DatasetFolder(Dataset):
target_transform=None,
is_valid_file=None):
self.root = root
self.transform = transform
self.target_transform = target_transform
if extensions is None:
extensions = IMG_EXTENSIONS
classes, class_to_idx = self._find_classes(self.root)
......
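For reference, here is a minimal usage sketch of the class documented above. The import path and directory names are assumptions, not part of this commit:

from hapi.datasets.folder import DatasetFolder  # import path is an assumption

# DatasetFolder expects a layout like root/class_a/x.jpg, root/class_b/y.jpg.
# With extensions=None it falls back to IMG_EXTENSIONS, as the diff shows.
dataset = DatasetFolder('path/to/train')  # hypothetical directory
sample, label = dataset[0]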
@@ -111,13 +111,21 @@ class MobileNetV1(Model):
Args:
scale (float): scale of channels in each layer. Default: 1.0.
class_dim (int): output dim of last fc layer. Default: 1000.
num_classes (int): output dim of last fc layer. If num_classes <= 0, the classifier layer is not created. Default: -1.
with_pool (bool): whether to add a global average pooling layer at the end. Default: False.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
"""
def __init__(self, scale=1.0, class_dim=1000):
def __init__(self,
scale=1.0,
num_classes=-1,
with_pool=False,
classifier_activation='softmax'):
super(MobileNetV1, self).__init__()
self.scale = scale
self.dwsl = []
self.num_classes = num_classes
self.with_pool = with_pool
self.conv1 = ConvBNLayer(
num_channels=3,
@@ -227,28 +235,34 @@ class MobileNetV1(Model):
name="conv6")
self.dwsl.append(dws6)
self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
if with_pool:
self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
self.out = Linear(
int(1024 * scale),
class_dim,
act='softmax',
param_attr=ParamAttr(
initializer=MSRA(), name=self.full_name() + "fc7_weights"),
bias_attr=ParamAttr(name="fc7_offset"))
if num_classes > 0:
self.out = Linear(
int(1024 * scale),
num_classes,
act=classifier_activation,
param_attr=ParamAttr(
initializer=MSRA(), name=self.full_name() + "fc7_weights"),
bias_attr=ParamAttr(name="fc7_offset"))
def forward(self, inputs):
y = self.conv1(inputs)
for dws in self.dwsl:
y = dws(y)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, 1024])
y = self.out(y)
if self.with_pool:
y = self.pool2d_avg(y)
if self.num_classes > 0:
y = fluid.layers.reshape(y, shape=[-1, 1024])
y = self.out(y)
return y
def _mobilenet(arch, pretrained=False, **kwargs):
model = MobileNetV1(**kwargs)
model = MobileNetV1(num_classes=1000, with_pool=True, **kwargs)
if pretrained:
assert arch in model_urls, "{} model does not have a pretrained model now, you should set pretrained=False".format(
arch)
@@ -262,5 +276,11 @@ def _mobilenet(arch, pretrained=False, **kwargs):
def mobilenet_v1(pretrained=False, scale=1.0):
"""MobileNetV1
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
scale (float): scale of channels in each layer. Default: 1.0.
"""
model = _mobilenet('mobilenetv1_' + str(scale), pretrained, scale=scale)
return model
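This refactor lets MobileNetV1 serve either as a bare backbone or as a classifier. A short sketch of the two modes, assuming the import path below (not shown in this diff):

from hapi.vision.models import MobileNetV1, mobilenet_v1  # import path is an assumption

# Classifier: the _mobilenet factory fills in num_classes=1000 and with_pool=True.
model = mobilenet_v1(pretrained=False, scale=1.0)

# Backbone: with the new defaults (num_classes=-1, with_pool=False), forward()
# skips the global pool and fc head and returns the last depthwise features.
backbone = MobileNetV1(scale=1.0)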
@@ -156,13 +156,20 @@ class MobileNetV2(Model):
Args:
scale (float): scale of channels in each layer. Default: 1.0.
class_dim (int): output dim of last fc layer. Default: 1000.
num_classes (int): output dim of last fc layer. If num_classes <= 0, the classifier layer is not created. Default: -1.
with_pool (bool): whether to add a global average pooling layer at the end. Default: False.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
"""
def __init__(self, scale=1.0, class_dim=1000):
def __init__(self,
scale=1.0,
num_classes=-1,
with_pool=False,
classifier_activation='softmax'):
super(MobileNetV2, self).__init__()
self.scale = scale
self.class_dim = class_dim
self.num_classes = num_classes
self.with_pool = with_pool
bottleneck_params_list = [
(1, 16, 1, 1),
@@ -174,7 +181,6 @@ class MobileNetV2(Model):
(6, 320, 1, 1),
]
#1. conv1
self._conv1 = ConvBNLayer(
num_channels=3,
num_filters=int(32 * scale),
@@ -182,7 +188,6 @@ class MobileNetV2(Model):
stride=2,
padding=1)
#2. bottleneck sequences
self._invl = []
i = 1
in_c = int(32 * scale)
@@ -196,7 +201,6 @@ class MobileNetV2(Model):
self._invl.append(tmp)
in_c = int(c * scale)
#3. last_conv
self._out_c = int(1280 * scale) if scale > 1.0 else 1280
self._conv9 = ConvBNLayer(
num_channels=in_c,
@@ -205,31 +209,34 @@ class MobileNetV2(Model):
stride=1,
padding=0)
#4. pool
self._pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
if with_pool:
self._pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
#5. fc
tmp_param = ParamAttr(name=self.full_name() + "fc10_weights")
self._fc = Linear(
self._out_c,
class_dim,
act='softmax',
param_attr=tmp_param,
bias_attr=ParamAttr(name="fc10_offset"))
if num_classes > 0:
tmp_param = ParamAttr(name=self.full_name() + "fc10_weights")
self._fc = Linear(
self._out_c,
num_classes,
act=classifier_activation,
param_attr=tmp_param,
bias_attr=ParamAttr(name="fc10_offset"))
def forward(self, inputs):
y = self._conv1(inputs, if_act=True)
for inv in self._invl:
y = inv(y)
y = self._conv9(y, if_act=True)
y = self._pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self._out_c])
y = self._fc(y)
if self.with_pool:
y = self._pool2d_avg(y)
if self.num_classes > 0:
y = fluid.layers.reshape(y, shape=[-1, self._out_c])
y = self._fc(y)
return y
def _mobilenet(arch, pretrained=False, **kwargs):
model = MobileNetV2(**kwargs)
model = MobileNetV2(num_classes=1000, with_pool=True, **kwargs)
if pretrained:
assert arch in model_urls, "{} model does not have a pretrained model now, you should set pretrained=False".format(
arch)
@@ -246,7 +253,8 @@ def mobilenet_v2(pretrained=False, scale=1.0):
"""MobileNetV2
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
scale (float): scale of channels in each layer. Default: 1.0.
"""
model = _mobilenet('mobilenetv2_' + str(scale), pretrained, scale=scale)
return model
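Note that mobilenet_v2 exposes only pretrained and scale, while the factory hard-codes num_classes=1000 and with_pool=True; a headless feature extractor has to be constructed directly. A sketch, with the import path again an assumption:

from hapi.vision.models import MobileNetV2  # import path is an assumption

# Headless: with num_classes <= 0 the fc layer is never created, and forward()
# returns the conv9 feature maps, globally pooled because with_pool=True.
backbone = MobileNetV2(scale=1.0, num_classes=-1, with_pool=True)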
@@ -164,11 +164,21 @@ class ResNet(Model):
Block (BasicBlock|BottleneckBlock): block module of model.
depth (int): depth of the network, one of 18, 34, 50, 101 or 152. Default: 50.
num_classes (int): output dim of last fc layer. If num_classes <= 0, the classifier layer is not created. Default: -1.
with_pool (bool): whether to add a global average pooling layer before the classifier. Default: False.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
"""
def __init__(self, Block, depth=50, num_classes=1000):
def __init__(self,
Block,
depth=50,
num_classes=-1,
with_pool=False,
classifier_activation='softmax'):
super(ResNet, self).__init__()
self.num_classes = num_classes
self.with_pool = with_pool
layer_config = {
18: [2, 2, 2, 2],
34: [3, 4, 6, 3],
@@ -212,31 +222,37 @@ class ResNet(Model):
Sequential(*blocks))
self.layers.append(layer)
self.global_pool = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
if with_pool:
self.global_pool = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(out_channels[-1] * Block.expansion * 1.0)
self.fc_input_dim = out_channels[-1] * Block.expansion * 1 * 1
self.fc = Linear(
self.fc_input_dim,
num_classes,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
if num_classes > 0:
stdv = 1.0 / math.sqrt(out_channels[-1] * Block.expansion * 1.0)
self.fc_input_dim = out_channels[-1] * Block.expansion * 1 * 1
self.fc = Linear(
self.fc_input_dim,
num_classes,
act=classifier_activation,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
def forward(self, inputs):
x = self.conv(inputs)
x = self.pool(x)
for layer in self.layers:
x = layer(x)
x = self.global_pool(x)
x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim])
x = self.fc(x)
if self.with_pool:
x = self.global_pool(x)
if self.num_classes > 0:
x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim])
x = self.fc(x)
return x
def _resnet(arch, Block, depth, pretrained):
model = ResNet(Block, depth)
model = ResNet(Block, depth, num_classes=1000, with_pool=True)
if pretrained:
assert arch in model_urls, "{} model does not have a pretrained model now, you should set pretrained=False".format(
arch)
......
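The same pattern applies to ResNet: the constructor now defaults to a feature extractor, and the internal _resnet factory restores the old classifier behaviour explicitly. A sketch, assuming BottleneckBlock is importable from the same module:

from hapi.vision.models.resnet import ResNet, BottleneckBlock  # paths are assumptions

# Feature extractor: the new defaults (num_classes=-1, with_pool=False) skip
# the global pool and the fc head entirely.
features = ResNet(BottleneckBlock, depth=50)

# Classifier, matching what _resnet constructs for the pretrained models.
classifier = ResNet(BottleneckBlock, depth=50, num_classes=1000, with_pool=True)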
@@ -23,12 +23,8 @@ from .download import get_weights_path
__all__ = [
'VGG',
'vgg11',
'vgg11_bn',
'vgg13',
'vgg13_bn',
'vgg16',
'vgg16_bn',
'vgg19_bn',
'vgg19',
]
@@ -39,11 +35,11 @@ model_urls = {
class Classifier(fluid.dygraph.Layer):
def __init__(self, num_classes):
def __init__(self, num_classes, classifier_activation='softmax'):
super(Classifier, self).__init__()
self.linear1 = Linear(512 * 7 * 7, 4096)
self.linear2 = Linear(4096, 4096)
self.linear3 = Linear(4096, num_classes, act='softmax')
self.linear3 = Linear(4096, num_classes, act=classifier_activation)
def forward(self, x):
x = self.linear1(x)
@@ -62,20 +58,29 @@ class VGG(Model):
Args:
features (fluid.dygraph.Layer): vgg features, created by the make_layers function.
num_classes (int): output dim of last fc layer. Default: 1000.
num_classes (int): output dim of last fc layer. If num_classes <= 0, the classifier is not created. Default: -1.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
"""
def __init__(self, features, num_classes=1000):
def __init__(self,
features,
num_classes=-1,
classifier_activation='softmax'):
super(VGG, self).__init__()
self.features = features
classifier = Classifier(num_classes)
self.classifier = self.add_sublayer("classifier",
Sequential(classifier))
self.num_classes = num_classes
if num_classes > 0:
classifier = Classifier(num_classes, classifier_activation)
self.classifier = self.add_sublayer("classifier",
Sequential(classifier))
def forward(self, x):
x = self.features(x)
x = fluid.layers.flatten(x, 1)
x = self.classifier(x)
if self.num_classes > 0:
x = fluid.layers.flatten(x, 1)
x = self.classifier(x)
return x
@@ -114,7 +119,10 @@ cfgs = {
def _vgg(arch, cfg, batch_norm, pretrained, **kwargs):
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
model = VGG(make_layers(
cfgs[cfg], batch_norm=batch_norm),
num_classes=1000,
**kwargs)
if pretrained:
assert arch in model_urls, "{} model does not have a pretrained model now, you should set pretrained=False".format(
@@ -128,73 +136,53 @@ def _vgg(arch, cfg, batch_norm, pretrained, **kwargs):
return model
def vgg11(pretrained=False, **kwargs):
def vgg11(pretrained=False, batch_norm=False):
"""VGG 11-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
"""
return _vgg('vgg11', 'A', False, pretrained, **kwargs)
model_name = 'vgg11'
if batch_norm:
model_name += '_bn'
return _vgg(model_name, 'A', batch_norm, pretrained)
def vgg11_bn(pretrained=False, **kwargs):
"""VGG 11-layer model with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg('vgg11_bn', 'A', True, pretrained, **kwargs)
def vgg13(pretrained=False, **kwargs):
def vgg13(pretrained=False, batch_norm=False):
"""VGG 13-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg('vgg13', 'B', False, pretrained, **kwargs)
def vgg13_bn(pretrained=False, **kwargs):
"""VGG 13-layer model with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
"""
return _vgg('vgg13_bn', 'B', True, pretrained, **kwargs)
model_name = 'vgg13'
if batch_norm:
model_name += '_bn'
return _vgg(model_name, 'B', batch_norm, pretrained)
def vgg16(pretrained=False, **kwargs):
def vgg16(pretrained=False, batch_norm=False):
"""VGG 16-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg('vgg16', 'D', False, pretrained, **kwargs)
def vgg16_bn(pretrained=False, **kwargs):
"""VGG 16-layer with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
"""
return _vgg('vgg16_bn', 'D', True, pretrained, **kwargs)
model_name = 'vgg16'
if batch_norm:
model_name += '_bn'
return _vgg(model_name, 'D', batch_norm, pretrained)
def vgg19(pretrained=False, **kwargs):
def vgg19(pretrained=False, batch_norm=False):
"""VGG 19-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg('vgg19', 'E', False, pretrained, **kwargs)
def vgg19_bn(pretrained=False, **kwargs):
"""VGG 19-layer model with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
"""
return _vgg('vgg19_bn', 'E', True, pretrained, **kwargs)
model_name = 'vgg19'
if batch_norm:
model_name += '_bn'
return _vgg(model_name, 'E', batch_norm, pretrained)
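With the vgg*_bn entry points dropped from __all__, batch normalization is now chosen by a flag on the four remaining constructors. A usage sketch (import path is an assumption):

from hapi.vision.models import vgg16  # import path is an assumption

model = vgg16(pretrained=False)                       # plain VGG-16, config 'D'
model_bn = vgg16(pretrained=False, batch_norm=True)   # replaces the old vgg16_bn()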