From 76876f52f2eae540067ac176b9671a2c5955e71a Mon Sep 17 00:00:00 2001 From: LielinJiang Date: Fri, 10 Apr 2020 10:09:08 +0000 Subject: [PATCH] fix reviews --- models/mobilenetv1.py | 9 +++++---- models/mobilenetv2.py | 9 +++++---- models/resnet.py | 9 +++++---- models/vgg.py | 5 +++-- transform/transforms.py | 3 ++- 5 files changed, 20 insertions(+), 15 deletions(-) diff --git a/models/mobilenetv1.py b/models/mobilenetv1.py index 2ac0408..11f8799 100644 --- a/models/mobilenetv1.py +++ b/models/mobilenetv1.py @@ -111,15 +111,16 @@ class MobileNetV1(Model): Args: scale (float): scale of channels in each layer. Default: 1.0. - num_classes (int): output dim of last fc layer. Default: -1. - with_pool (bool): use pool or not. Default: False. + num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer + will not be defined. Default: 1000. + with_pool (bool): use pool before the last fc layer or not. Default: True. classifier_activation (str): activation for the last fc layer. Default: 'softmax'. """ def __init__(self, scale=1.0, - num_classes=-1, - with_pool=False, + num_classes=1000, + with_pool=True, classifier_activation='softmax'): super(MobileNetV1, self).__init__() self.scale = scale diff --git a/models/mobilenetv2.py b/models/mobilenetv2.py index 59201f8..1d592fb 100644 --- a/models/mobilenetv2.py +++ b/models/mobilenetv2.py @@ -156,15 +156,16 @@ class MobileNetV2(Model): Args: scale (float): scale of channels in each layer. Default: 1.0. - num_classes (int): output dim of last fc layer. Default: -1. - with_pool (bool): use pool or not. Default: False. + num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer + will not be defined. Default: 1000. + with_pool (bool): use pool before the last fc layer or not. Default: True. classifier_activation (str): activation for the last fc layer. Default: 'softmax'. """ def __init__(self, scale=1.0, - num_classes=-1, - with_pool=False, + num_classes=1000, + with_pool=True, classifier_activation='softmax'): super(MobileNetV2, self).__init__() self.scale = scale diff --git a/models/resnet.py b/models/resnet.py index 3e75b13..6999fb7 100644 --- a/models/resnet.py +++ b/models/resnet.py @@ -163,16 +163,17 @@ class ResNet(Model): Args: Block (BasicBlock|BottleneckBlock): block module of model. depth (int): layers of resnet, default: 50. - num_classes (int): output dim of last fc layer, default: 1000. - with_pool (bool): use pool or not. Default: False. + num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer + will not be defined. Default: 1000. + with_pool (bool): use pool before the last fc layer or not. Default: True. classifier_activation (str): activation for the last fc layer. Default: 'softmax'. """ def __init__(self, Block, depth=50, - num_classes=-1, - with_pool=False, + num_classes=1000, + with_pool=True, classifier_activation='softmax'): super(ResNet, self).__init__() diff --git a/models/vgg.py b/models/vgg.py index d5991b3..324ddc0 100644 --- a/models/vgg.py +++ b/models/vgg.py @@ -58,13 +58,14 @@ class VGG(Model): Args: features (fluid.dygraph.Layer): vgg features create by function make_layers. - num_classes (int): output dim of last fc layer. Default: -1. + num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer + will not be defined. Default: 1000. classifier_activation (str): activation for the last fc layer. Default: 'softmax'. """ def __init__(self, features, - num_classes=-1, + num_classes=1000, classifier_activation='softmax'): super(VGG, self).__init__() self.features = features diff --git a/transform/transforms.py b/transform/transforms.py index 47a5454..17ae5ed 100644 --- a/transform/transforms.py +++ b/transform/transforms.py @@ -289,7 +289,8 @@ class Normalize(object): class Permute(object): """Change input data to a target mode. For example, most transforms use HWC mode image, - while the Neural Network might use CHW mode input tensor + while the Neural Network might use CHW mode input tensor. + Input image should be HWC mode and an instance of numpy.ndarray. Args: mode: Output mode of input. Use "CHW" mode by default. -- GitLab