Commit 53b23370 authored by LielinJiang

refine api

Parent 6c89d516
@@ -71,14 +71,14 @@ class DatasetFolder(Dataset):

     Args:
         root (string): Root directory path.
-        loader (callable, optional): A function to load a sample given its path.
-        extensions (tuple[string], optional): A list of allowed extensions.
+        loader (callable|optional): A function to load a sample given its path.
+        extensions (tuple[str]|optional): A list of allowed extensions.
             both extensions and is_valid_file should not be passed.
-        transform (callable, optional): A function/transform that takes in
+        transform (callable|optional): A function/transform that takes in
             a sample and returns a transformed version.
-        target_transform (callable, optional): A function/transform that takes
+        target_transform (callable|optional): A function/transform that takes
             in the target and transforms it.
-        is_valid_file (callable, optional): A function that takes path of a file
+        is_valid_file (callable|optional): A function that takes path of a file
             and check if the file is a valid file (used to check of corrupt files)
             both extensions and is_valid_file should not be passed.

@@ -97,6 +97,8 @@ class DatasetFolder(Dataset):
                  target_transform=None,
                  is_valid_file=None):
         self.root = root
+        self.transform = transform
+        self.target_transform = target_transform
         if extensions is None:
             extensions = IMG_EXTENSIONS
         classes, class_to_idx = self._find_classes(self.root)
...
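The refined DatasetFolder stores transform and target_transform on the instance and expects callers to pass either extensions or is_valid_file, never both. A minimal usage sketch, assuming a PIL-based loader and a hapi.datasets.folder import path (neither the path nor the directory layout appears in this diff):

# Usage sketch for the refined DatasetFolder API. The import path,
# the directory layout, and pil_loader are assumptions for illustration.
from PIL import Image

from hapi.datasets.folder import DatasetFolder  # assumed import path


def pil_loader(path):
    # A `loader` callable: load one sample given its path.
    with open(path, 'rb') as f:
        return Image.open(f).convert('RGB')


# Pass either `extensions` or `is_valid_file`, never both.
dataset = DatasetFolder(
    root='data/flowers',  # expects root/<class_name>/<sample files>
    loader=pil_loader,
    extensions=('.jpg', '.jpeg', '.png'),
    transform=lambda img: img.resize((224, 224)))

sample, target = dataset[0]  # transform applied to sample; target is the class index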
@@ -111,13 +111,21 @@ class MobileNetV1(Model):

     Args:
         scale (float): scale of channels in each layer. Default: 1.0.
-        class_dim (int): output dim of last fc layer. Default: 1000.
+        num_classes (int): output dim of last fc layer. Default: -1.
+        with_pool (bool): use pool or not. Default: False.
+        classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
     """

-    def __init__(self, scale=1.0, class_dim=1000):
+    def __init__(self,
+                 scale=1.0,
+                 num_classes=-1,
+                 with_pool=False,
+                 classifier_activation='softmax'):
         super(MobileNetV1, self).__init__()
         self.scale = scale
         self.dwsl = []
+        self.num_classes = num_classes
+        self.with_pool = with_pool

         self.conv1 = ConvBNLayer(
             num_channels=3,

@@ -227,28 +235,34 @@ class MobileNetV1(Model):
             name="conv6")
         self.dwsl.append(dws6)

-        self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
-
-        self.out = Linear(
-            int(1024 * scale),
-            class_dim,
-            act='softmax',
-            param_attr=ParamAttr(
-                initializer=MSRA(), name=self.full_name() + "fc7_weights"),
-            bias_attr=ParamAttr(name="fc7_offset"))
+        if with_pool:
+            self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
+
+        if num_classes > -1:
+            self.out = Linear(
+                int(1024 * scale),
+                num_classes,
+                act=classifier_activation,
+                param_attr=ParamAttr(
+                    initializer=MSRA(), name=self.full_name() + "fc7_weights"),
+                bias_attr=ParamAttr(name="fc7_offset"))

     def forward(self, inputs):
         y = self.conv1(inputs)
         for dws in self.dwsl:
             y = dws(y)
-        y = self.pool2d_avg(y)
-        y = fluid.layers.reshape(y, shape=[-1, 1024])
-        y = self.out(y)
+
+        if self.with_pool:
+            y = self.pool2d_avg(y)
+
+        if self.num_classes > -1:
+            y = fluid.layers.reshape(y, shape=[-1, 1024])
+            y = self.out(y)
         return y


 def _mobilenet(arch, pretrained=False, **kwargs):
-    model = MobileNetV1(**kwargs)
+    model = MobileNetV1(num_classes=1000, with_pool=True, **kwargs)
     if pretrained:
         assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
             arch)

@@ -262,5 +276,11 @@ def _mobilenet(arch, pretrained=False, **kwargs):


 def mobilenet_v1(pretrained=False, scale=1.0):
+    """MobileNetV1
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
+        scale: (float): scale of channels in each layer. Default: 1.0.
+    """
     model = _mobilenet('mobilenetv1_' + str(scale), pretrained, scale=scale)
     return model
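With num_classes=-1 and with_pool=False as the new constructor defaults, a bare MobileNetV1 is a headless feature extractor, and _mobilenet wires the ImageNet head back in for the public mobilenet_v1 entry point. A sketch of both paths in dygraph mode (import paths are assumptions; the feature width follows the int(1024 * scale) used above):

# Sketch: backbone vs. classifier construction after this refactor.
# Import paths are assumptions; they do not appear in the diff.
import numpy as np
import paddle.fluid as fluid

from hapi.vision.models import MobileNetV1, mobilenet_v1  # assumed

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(
        np.random.randn(1, 3, 224, 224).astype('float32'))

    # New defaults (num_classes=-1, with_pool=False): forward skips
    # pooling and fc, returning the last depthwise-separable feature map.
    backbone = MobileNetV1(scale=1.0)
    feat = backbone(x)  # roughly [1, 1024, 7, 7] for a 224x224 input

    # The factory passes num_classes=1000 and with_pool=True, so the
    # public entry point still produces softmax class probabilities.
    model = mobilenet_v1(pretrained=False, scale=1.0)
    probs = model(x)  # [1, 1000]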
@@ -156,13 +156,20 @@ class MobileNetV2(Model):

     Args:
         scale (float): scale of channels in each layer. Default: 1.0.
-        class_dim (int): output dim of last fc layer. Default: 1000.
+        num_classes (int): output dim of last fc layer. Default: -1.
+        with_pool (bool): use pool or not. Default: False.
+        classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
     """

-    def __init__(self, scale=1.0, class_dim=1000):
+    def __init__(self,
+                 scale=1.0,
+                 num_classes=-1,
+                 with_pool=False,
+                 classifier_activation='softmax'):
         super(MobileNetV2, self).__init__()
         self.scale = scale
-        self.class_dim = class_dim
+        self.num_classes = num_classes
+        self.with_pool = with_pool

         bottleneck_params_list = [
             (1, 16, 1, 1),

@@ -174,7 +181,6 @@ class MobileNetV2(Model):
             (6, 320, 1, 1),
         ]

-        #1. conv1
         self._conv1 = ConvBNLayer(
             num_channels=3,
             num_filters=int(32 * scale),

@@ -182,7 +188,6 @@ class MobileNetV2(Model):
             stride=2,
             padding=1)

-        #2. bottleneck sequences
         self._invl = []
         i = 1
         in_c = int(32 * scale)

@@ -196,7 +201,6 @@ class MobileNetV2(Model):
             self._invl.append(tmp)
             in_c = int(c * scale)

-        #3. last_conv
         self._out_c = int(1280 * scale) if scale > 1.0 else 1280
         self._conv9 = ConvBNLayer(
             num_channels=in_c,

@@ -205,31 +209,34 @@ class MobileNetV2(Model):
             stride=1,
             padding=0)

-        #4. pool
-        self._pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
-        #5. fc
-        tmp_param = ParamAttr(name=self.full_name() + "fc10_weights")
-        self._fc = Linear(
-            self._out_c,
-            class_dim,
-            act='softmax',
-            param_attr=tmp_param,
-            bias_attr=ParamAttr(name="fc10_offset"))
+        if with_pool:
+            self._pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
+
+        if num_classes > 0:
+            tmp_param = ParamAttr(name=self.full_name() + "fc10_weights")
+            self._fc = Linear(
+                self._out_c,
+                num_classes,
+                act=classifier_activation,
+                param_attr=tmp_param,
+                bias_attr=ParamAttr(name="fc10_offset"))

     def forward(self, inputs):
         y = self._conv1(inputs, if_act=True)
         for inv in self._invl:
             y = inv(y)
         y = self._conv9(y, if_act=True)
-        y = self._pool2d_avg(y)
-        y = fluid.layers.reshape(y, shape=[-1, self._out_c])
-        y = self._fc(y)
+
+        if self.with_pool:
+            y = self._pool2d_avg(y)
+
+        if self.num_classes > 0:
+            y = fluid.layers.reshape(y, shape=[-1, self._out_c])
+            y = self._fc(y)
         return y


 def _mobilenet(arch, pretrained=False, **kwargs):
-    model = MobileNetV2(**kwargs)
+    model = MobileNetV2(num_classes=1000, with_pool=True, **kwargs)
     if pretrained:
         assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
             arch)

@@ -246,7 +253,8 @@ def mobilenet_v2(pretrained=False, scale=1.0):
     """MobileNetV2

     Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
+        scale: (float): scale of channels in each layer. Default: 1.0.
     """
     model = _mobilenet('mobilenetv2_' + str(scale), pretrained, scale=scale)
     return model
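As with V1, the pretrained branch depends on an arch key derived from scale, so only scales that actually have released weights survive the assert in _mobilenet. A short sketch of the key construction (the exact contents of model_urls are not shown in this diff, so which scales are available is an assumption):

# Sketch of the arch-key logic used by mobilenet_v2 above.
scale = 1.0
arch = 'mobilenetv2_' + str(scale)  # -> 'mobilenetv2_1.0'

# _mobilenet asserts `arch in model_urls` before loading weights, so a
# scale without released weights must be built with pretrained=False,
# e.g. mobilenet_v2(pretrained=False, scale=0.5).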
@@ -164,11 +164,21 @@ class ResNet(Model):
         Block (BasicBlock|BottleneckBlock): block module of model.
         depth (int): layers of resnet, default: 50.
         num_classes (int): output dim of last fc layer, default: 1000.
+        with_pool (bool): use pool or not. Default: False.
+        classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
     """

-    def __init__(self, Block, depth=50, num_classes=1000):
+    def __init__(self,
+                 Block,
+                 depth=50,
+                 num_classes=-1,
+                 with_pool=False,
+                 classifier_activation='softmax'):
         super(ResNet, self).__init__()
+        self.num_classes = num_classes
+        self.with_pool = with_pool

         layer_config = {
             18: [2, 2, 2, 2],
             34: [3, 4, 6, 3],

@@ -212,31 +222,37 @@ class ResNet(Model):
                     Sequential(*blocks))
             self.layers.append(layer)

-        self.global_pool = Pool2D(
-            pool_size=7, pool_type='avg', global_pooling=True)
-
-        stdv = 1.0 / math.sqrt(out_channels[-1] * Block.expansion * 1.0)
-        self.fc_input_dim = out_channels[-1] * Block.expansion * 1 * 1
-        self.fc = Linear(
-            self.fc_input_dim,
-            num_classes,
-            act='softmax',
-            param_attr=fluid.param_attr.ParamAttr(
-                initializer=fluid.initializer.Uniform(-stdv, stdv)))
+        if with_pool:
+            self.global_pool = Pool2D(
+                pool_size=7, pool_type='avg', global_pooling=True)
+
+        if num_classes > 0:
+            stdv = 1.0 / math.sqrt(out_channels[-1] * Block.expansion * 1.0)
+            self.fc_input_dim = out_channels[-1] * Block.expansion * 1 * 1
+            self.fc = Linear(
+                self.fc_input_dim,
+                num_classes,
+                act=classifier_activation,
+                param_attr=fluid.param_attr.ParamAttr(
+                    initializer=fluid.initializer.Uniform(-stdv, stdv)))

     def forward(self, inputs):
         x = self.conv(inputs)
         x = self.pool(x)
         for layer in self.layers:
             x = layer(x)
-        x = self.global_pool(x)
-        x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim])
-        x = self.fc(x)
+
+        if self.with_pool:
+            x = self.global_pool(x)
+
+        if self.num_classes > -1:
+            x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim])
+            x = self.fc(x)
         return x


 def _resnet(arch, Block, depth, pretrained):
-    model = ResNet(Block, depth)
+    model = ResNet(Block, depth, num_classes=1000, with_pool=True)
     if pretrained:
         assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
             arch)
...
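ResNet follows the same pattern, with one wrinkle visible in the hunk above: the constructor creates self.fc only when num_classes > 0, while forward gates the head on num_classes > -1, so num_classes=0 would reach a missing self.fc; the factory's num_classes=1000 sidesteps that edge. A construction sketch (import names are assumptions):

# Sketch: headless vs. classifying ResNet after this refactor.
# Import names are assumptions; BottleneckBlock is the block type
# the deeper factories (50/101/152) pass as `Block`.
import numpy as np
import paddle.fluid as fluid

from hapi.vision.models import ResNet, BottleneckBlock, resnet50  # assumed

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(
        np.random.randn(1, 3, 224, 224).astype('float32'))

    # Defaults (num_classes=-1, with_pool=False): raw conv features.
    backbone = ResNet(BottleneckBlock, depth=50)
    feat = backbone(x)  # roughly [1, 2048, 7, 7] for a 224x224 input

    # _resnet restores the old classifying behaviour for the entry points.
    model = resnet50(pretrained=False)
    probs = model(x)  # [1, 1000]
    # Caveat from the diff: num_classes=0 would skip fc creation
    # (gated on > 0) yet forward's `> -1` check would still call self.fc.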
@@ -23,12 +23,8 @@ from .download import get_weights_path
 __all__ = [
     'VGG',
     'vgg11',
-    'vgg11_bn',
     'vgg13',
-    'vgg13_bn',
     'vgg16',
-    'vgg16_bn',
-    'vgg19_bn',
     'vgg19',
 ]

@@ -39,11 +35,11 @@ model_urls = {


 class Classifier(fluid.dygraph.Layer):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes, classifier_activation='softmax'):
         super(Classifier, self).__init__()
         self.linear1 = Linear(512 * 7 * 7, 4096)
         self.linear2 = Linear(4096, 4096)
-        self.linear3 = Linear(4096, num_classes, act='softmax')
+        self.linear3 = Linear(4096, num_classes, act=classifier_activation)

     def forward(self, x):
         x = self.linear1(x)

@@ -62,20 +58,29 @@ class VGG(Model):

     Args:
         features (fluid.dygraph.Layer): vgg features create by function make_layers.
-        num_classes (int): output dim of last fc layer. Default: 1000.
+        num_classes (int): output dim of last fc layer. Default: -1.
+        classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
     """

-    def __init__(self, features, num_classes=1000):
+    def __init__(self,
+                 features,
+                 num_classes=-1,
+                 classifier_activation='softmax'):
         super(VGG, self).__init__()
         self.features = features
-        classifier = Classifier(num_classes)
-        self.classifier = self.add_sublayer("classifier",
-                                            Sequential(classifier))
+        self.num_classes = num_classes
+
+        if num_classes > 0:
+            classifier = Classifier(num_classes, classifier_activation)
+            self.classifier = self.add_sublayer("classifier",
+                                                Sequential(classifier))

     def forward(self, x):
         x = self.features(x)
-        x = fluid.layers.flatten(x, 1)
-        x = self.classifier(x)
+
+        if self.num_classes > 0:
+            x = fluid.layers.flatten(x, 1)
+            x = self.classifier(x)
         return x

@@ -114,7 +119,10 @@ cfgs = {


 def _vgg(arch, cfg, batch_norm, pretrained, **kwargs):
-    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
+    model = VGG(make_layers(
+        cfgs[cfg], batch_norm=batch_norm),
+        num_classes=1000,
+        **kwargs)
     if pretrained:
         assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(

@@ -128,73 +136,53 @@ def _vgg(arch, cfg, batch_norm, pretrained, **kwargs):
     return model


-def vgg11(pretrained=False, **kwargs):
+def vgg11(pretrained=False, batch_norm=False):
     """VGG 11-layer model

     Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
+        batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
     """
-    return _vgg('vgg11', 'A', False, pretrained, **kwargs)
-
-
-def vgg11_bn(pretrained=False, **kwargs):
-    """VGG 11-layer model with batch normalization
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    return _vgg('vgg11_bn', 'A', True, pretrained, **kwargs)
+    model_name = 'vgg11'
+    if batch_norm:
+        model_name += ('_bn')
+    return _vgg(model_name, 'A', batch_norm, pretrained)


-def vgg13(pretrained=False, **kwargs):
+def vgg13(pretrained=False, batch_norm=False):
     """VGG 13-layer model

     Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
+        batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
     """
-    return _vgg('vgg13', 'B', False, pretrained, **kwargs)
-
-
-def vgg13_bn(pretrained=False, **kwargs):
-    """VGG 13-layer model with batch normalization
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    return _vgg('vgg13_bn', 'B', True, pretrained, **kwargs)
+    model_name = 'vgg13'
+    if batch_norm:
+        model_name += ('_bn')
+    return _vgg(model_name, 'B', batch_norm, pretrained)


-def vgg16(pretrained=False, **kwargs):
+def vgg16(pretrained=False, batch_norm=False):
     """VGG 16-layer model

     Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
+        batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
     """
-    return _vgg('vgg16', 'D', False, pretrained, **kwargs)
-
-
-def vgg16_bn(pretrained=False, **kwargs):
-    """VGG 16-layer with batch normalization
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    return _vgg('vgg16_bn', 'D', True, pretrained, **kwargs)
+    model_name = 'vgg16'
+    if batch_norm:
+        model_name += ('_bn')
+    return _vgg(model_name, 'D', batch_norm, pretrained)


-def vgg19(pretrained=False, **kwargs):
+def vgg19(pretrained=False, batch_norm=False):
     """VGG 19-layer model

     Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
+        batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
     """
-    return _vgg('vgg19', 'E', False, pretrained, **kwargs)
-
-
-def vgg19_bn(pretrained=False, **kwargs):
-    """VGG 19-layer model with batch normalization
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    return _vgg('vgg19_bn', 'E', True, pretrained, **kwargs)
+    model_name = 'vgg19'
+    if batch_norm:
+        model_name += ('_bn')
+    return _vgg(model_name, 'E', batch_norm, pretrained)