diff --git a/PaddleCV/image_classification/README.md b/PaddleCV/image_classification/README.md index 7f8ac0a49c3d1a7b722a0c42a74b52baf07ef6b9..d7b4fe7b7974a9bd77d16487e64167f4203866a4 100644 --- a/PaddleCV/image_classification/README.md +++ b/PaddleCV/image_classification/README.md @@ -166,16 +166,23 @@ The image classification models currently supported by PaddlePaddle are listed i As the activation functions ```swish``` (used in ShuffleNetV2_swish) and ```relu6``` (used in MobileNetV2) are not supported by Paddle TensorRT, their inference acceleration is not significant. Pretrained models can be downloaded by clicking the related model names. - Note1: ResNet50_vd_v2 is the distilled version of ResNet50_vd. -- Note2: In addition to the image resolution feeded in InceptionV4 and Xception net is ```299x299```, others are ```224x224```. +- Note2: The input image resolution fed to InceptionV4 and Xception is ```299x299```, Fix_ResNeXt101_32x48d_wsl is ```320x320```, DarkNet is ```256x256```, and all other models use ```224x224```. At test time, the resize_short_size of the DarkNet53 and Fix_ResNeXt101_32x48d_wsl networks equals the width or height of the input image resolution, the resize_short_size of InceptionV4 and Xception is 320, and that of the other networks is 256. - Note3: It's necessary to convert the trained model to a binary model when applying the dynamic link library for inference. One can do it by running the following command: ```python infer.py --save_inference=True``` +- Note4: The pretrained models of the ResNeXt101_wsl series are converted from PyTorch models. Please see [RESNEXT WSL](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/) for details.
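The resize_short_size convention described in Note2 can be sketched in a few lines of Python; the helper names `resize_short_dims` and `center_crop_box` and the example 480x640 input are illustrative only, not part of this repository:

```python
def resize_short_dims(height, width, resize_short_size):
    """Scale so the shorter side equals resize_short_size, preserving aspect ratio."""
    scale = float(resize_short_size) / min(height, width)
    return int(round(height * scale)), int(round(width * scale))

def center_crop_box(height, width, crop_size):
    """Corners (top, left, bottom, right) of a centered crop_size x crop_size window."""
    top = (height - crop_size) // 2
    left = (width - crop_size) // 2
    return top, left, top + crop_size, left + crop_size

# e.g. InceptionV4: resize_short_size is 320, then a 299x299 center crop
h, w = resize_short_dims(480, 640, 320)
box = center_crop_box(h, w, 299)
```

Under this sketch, a 480x640 image for InceptionV4 is first resized so its short side is 320, then center-cropped to 299x299.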
### AlexNet |model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | |- |:-: |:-: |:-: | |[AlexNet](http://paddle-imagenet-models-name.bj.bcebos.com/AlexNet_pretrained.tar) | 56.72%/79.17% | 3.083 | 2.728 | +### SqueezeNet +|model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | +|- |:-: |:-: |:-: | +|[SqueezeNet1_0](https://paddle-imagenet-models-name.bj.bcebos.com/SqueezeNet1_0_pretrained.tar) | 59.60%/81.66% | 2.740 | 1.688 | +|[SqueezeNet1_1](https://paddle-imagenet-models-name.bj.bcebos.com/SqueezeNet1_1_pretrained.tar) | 60.08%/81.85% | 2.751 | 1.270 | + ### VGG |model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | |- |:-: |:-: |:-: | @@ -224,12 +231,23 @@ As the activation function ```swish``` and ```relu6``` which separately used in |model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | |- |:-: |:-: |:-: | |[ResNeXt50_32x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt50_32x4d_pretrained.tar) | 77.75%/93.82% | 12.863 | 9.837 | +|[ResNeXt50_vd_32x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt50_vd_32x4d_pretrained.tar) | 79.56%/94.62% | 13.673 | 9.991 | |[ResNeXt50_64x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt50_64x4d_pretrained.tar) | 78.43%/94.13% | 28.162 | 18.271 | |[ResNeXt50_vd_64x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt50_vd_64x4d_pretrained.tar) | 80.12%/94.86% | 20.888 | 17.687 | |[ResNeXt101_32x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_32x4d_pretrained.tar) | 78.65%/94.19% | 24.154 | 21.387 | |[ResNeXt101_64x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt50_64x4d_pretrained.tar) | 78.43%/94.13% | 41.073 | 38.736 | |[ResNeXt101_vd_64x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_vd_64x4d_pretrained.tar) | 80.78%/95.20% | 
42.277 | 40.929 | |[ResNeXt152_32x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt152_32x4d_pretrained.tar) | 78.98%/94.33% | 37.007 | 31.301 | +|[ResNeXt152_64x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt152_64x4d_pretrained.tar) | 79.51%/94.71% | 58.966 | 57.267 | + +### DenseNet +|model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | +|- |:-: |:-: |:-: | +|[DenseNet121](https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet121_pretrained.tar) | 75.66%/92.58% | 12.437 | 5.813 | +|[DenseNet161](https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet161_pretrained.tar) | 78.57%/94.14% | 27.717 | 12.861 | +|[DenseNet169](https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet169_pretrained.tar) | 76.81%/93.31% | 18.941 | 8.146 | +|[DenseNet201](https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet201_pretrained.tar) | 77.63%/93.66% | 26.583 | 10.549 | +|[DenseNet264](https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet264_pretrained.tar) | 77.96%/93.85% | 41.495 | 15.574 | ### SENet |model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | @@ -245,8 +263,19 @@ As the activation function ```swish``` and ```relu6``` which separately used in |[Xception_41](https://paddle-imagenet-models-name.bj.bcebos.com/Xception41_pretrained.tar) | 79.30%/94.53% | 13.757 | 10.831 | |[InceptionV4](https://paddle-imagenet-models-name.bj.bcebos.com/InceptionV4_pretrained.tar) | 80.77%/95.26% | 32.413 | 18.154 | +### DarkNet +|model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | +|- |:-: |:-: |:-: | +|[DarkNet53](https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_ImageNet1k_pretrained.tar) | 78.04%/94.05% | 11.969 | 7.153 | - +### ResNeXt101_wsl +|model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | +|- |:-: |:-: |:-: | 
+|[ResNeXt101_32x8d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_32x8d_wsl_pretrained.tar) | 82.55%/96.74% | 33.310 | 27.648 | +|[ResNeXt101_32x16d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_32x16d_wsl_pretrained.tar) | 84.24%/97.26% | 54.320 | 46.064 | +|[ResNeXt101_32x32d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_32x32d_wsl_pretrained.tar) | 84.97%/97.59% | 97.734 | 87.961 | +|[ResNeXt101_32x48d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_32x48d_wsl_pretrained.tar) | 85.37%/97.69% | 161.722 | | +|[Fix_ResNeXt101_32x48d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/Fix_ResNeXt101_32x48d_wsl_pretrained.tar) | 86.26%/97.97% | 236.091 | | ## FAQ @@ -272,7 +301,11 @@ Enforce failed. Expected x_dims[1] == labels_dims[1], but received x_dims[1]:100 - GoogLeNet: [Going Deeper with Convolutions](https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf), Christian Szegedy, Wei Liu, Yangqing Jia - Xception: [Xception: Deep Learning with Depthwise Separable Convolutions](https://arxiv.org/abs/1610.02357), François Chollet - InceptionV4: [Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning](https://arxiv.org/abs/1602.07261), Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi - +- DarkNet: [YOLOv3: An Incremental Improvement](https://pjreddie.com/media/files/papers/YOLOv3.pdf), Joseph Redmon, Ali Farhadi +- DenseNet: [Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993), Gao Huang, Zhuang Liu, Laurens van der Maaten +- SqueezeNet: [SQUEEZENET: ALEXNET-LEVEL ACCURACY WITH 50X FEWER PARAMETERS AND <0.5MB MODEL SIZE](https://arxiv.org/abs/1602.07360), Forrest N. Iandola, Song Han, Matthew W. Moskewicz, Khalid Ashraf, William J.
Dally, Kurt Keutzer +- ResNeXt101_wsl: [Exploring the Limits of Weakly Supervised Pretraining](https://arxiv.org/abs/1805.00932), Dhruv Mahajan, Ross Girshick, Vignesh Ramanathan, Kaiming He, Manohar Paluri, Yixuan Li, Ashwin Bharambe, Laurens van der Maaten +- Fix_ResNeXt101_wsl: [Fixing the train-test resolution discrepancy](https://arxiv.org/abs/1906.06423), Hugo Touvron, Andrea Vedaldi, Matthijs Douze, Hervé Jégou ## Update @@ -284,6 +317,7 @@ Enforce failed. Expected x_dims[1] == labels_dims[1], but received x_dims[1]:100 - 2019/06/22 Update ResNet50_vd_v2 - 2019/07/02 **Stage5**: Update MobileNetV2_x0_5, ResNeXt50_32x4d, ResNeXt50_64x4d, Xception_41, ResNet101_vd - 2019/07/19 **Stage6**: Update ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2_x1_0, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, MobileNetV2_x0_25, MobileNetV2_x1_5, MobileNetV2_x2_0, ResNeXt50_vd_64x4d, ResNeXt101_32x4d, ResNeXt152_32x4d +- 2019/08/01 **Stage7**: Update DarkNet53, DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264, SqueezeNet1_0, SqueezeNet1_1, ResNeXt50_vd_32x4d, ResNeXt152_64x4d, ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl, Fix_ResNeXt101_32x48d_wsl ## Contribute diff --git a/PaddleCV/image_classification/README_cn.md b/PaddleCV/image_classification/README_cn.md index 77e603c2cf020199ccc638e86fc8c98d3c3f9e1c..ec17a708512d9a54191a1663dae79f2e4378a11f 100644 --- a/PaddleCV/image_classification/README_cn.md +++ b/PaddleCV/image_classification/README_cn.md @@ -154,17 +154,25 @@ python infer.py \ 试GPU型号为Tesla P4)。由于Paddle TensorRT对ShuffleNetV2_swish使用的激活函数swish,MobileNetV2使用的激活函数relu6不支持,因此预测加速不明显。可以通过点击相应模型的名称下载对应的预训练模型。 - 注意 - 1:ResNet50_vd_v2是ResNet50_vd蒸馏版本。 - 2:除了InceptionV4和Xception采用的输入图像的分辨率为299x299,其余模型测试时使用的分辨率均为224x224。 - 3:调用动态链接库预测时需要将训练模型转换为二进制模型 + - 1:ResNet50_vd_v2是ResNet50_vd蒸馏版本。 + -
2:InceptionV4和Xception采用的输入图像的分辨率为299x299,DarkNet53为256x256,Fix_ResNeXt101_32x48d_wsl为320x320,其余模型使用的分辨率均为224x224。在预测时,DarkNet53与Fix_ResNeXt101_32x48d_wsl系列网络resize_short_size与输入的图像分辨率的宽或高相同,InceptionV4和Xception网络resize_short_size为320,其余网络resize_short_size均为256。 + - 3:调用动态链接库预测时需要将训练模型转换为二进制模型 ```python infer.py --save_inference=True``` + - 4: ResNeXt101_wsl系列的预训练模型转自pytorch模型,详情请移步[RESNEXT WSL](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/)。 + ### AlexNet |model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | |- |:-: |:-: |:-: | |[AlexNet](http://paddle-imagenet-models-name.bj.bcebos.com/AlexNet_pretrained.tar) | 56.72%/79.17% | 3.083 | 2.728 | +### SqueezeNet +|model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | +|- |:-: |:-: |:-: | +|[SqueezeNet1_0](https://paddle-imagenet-models-name.bj.bcebos.com/SqueezeNet1_0_pretrained.tar) | 59.60%/81.66% | 2.740 | 1.688 | +|[SqueezeNet1_1](https://paddle-imagenet-models-name.bj.bcebos.com/SqueezeNet1_1_pretrained.tar) | 60.08%/81.85% | 2.751 | 1.270 | + ### VGG |model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | |- |:-: |:-: |:-: | @@ -213,12 +221,23 @@ python infer.py \ |model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | |- |:-: |:-: |:-: | |[ResNeXt50_32x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt50_32x4d_pretrained.tar) | 77.75%/93.82% | 12.863 | 9.837 | +|[ResNeXt50_vd_32x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt50_vd_32x4d_pretrained.tar) | 79.56%/94.62% | 13.673 | 9.991 | |[ResNeXt50_64x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt50_64x4d_pretrained.tar) | 78.43%/94.13% | 28.162 | 18.271 | |[ResNeXt50_vd_64x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt50_vd_64x4d_pretrained.tar) | 80.12%/94.86% | 20.888 | 17.687 | 
|[ResNeXt101_32x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_32x4d_pretrained.tar) | 78.65%/94.19% | 24.154 | 21.387 | |[ResNeXt101_64x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt50_64x4d_pretrained.tar) | 78.43%/94.13% | 41.073 | 38.736 | |[ResNeXt101_vd_64x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_vd_64x4d_pretrained.tar) | 80.78%/95.20% | 42.277 | 40.929 | |[ResNeXt152_32x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt152_32x4d_pretrained.tar) | 78.98%/94.33% | 37.007 | 31.301 | +|[ResNeXt152_64x4d](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt152_64x4d_pretrained.tar) | 79.51%/94.71% | 58.966 | 57.267 | + +### DenseNet +|model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | +|- |:-: |:-: |:-: | +|[DenseNet121](https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet121_pretrained.tar) | 75.66%/92.58% | 12.437 | 5.813 | +|[DenseNet161](https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet161_pretrained.tar) | 78.57%/94.14% | 27.717 | 12.861 | +|[DenseNet169](https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet169_pretrained.tar) | 76.81%/93.31% | 18.941 | 8.146 | +|[DenseNet201](https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet201_pretrained.tar) | 77.63%/93.66% | 26.583 | 10.549 | +|[DenseNet264](https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet264_pretrained.tar) | 77.96%/93.85% | 41.495 | 15.574 | ### SENet |model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | @@ -234,7 +253,19 @@ python infer.py \ |[Xception_41](https://paddle-imagenet-models-name.bj.bcebos.com/Xception41_pretrained.tar) | 79.30%/94.53% | 13.757 | 10.831 | |[InceptionV4](https://paddle-imagenet-models-name.bj.bcebos.com/InceptionV4_pretrained.tar) | 80.77%/95.26% | 32.413 | 18.154 | +### DarkNet +|model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | 
Paddle TensorRT inference time(ms) | +|- |:-: |:-: |:-: | +|[DarkNet53](https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_ImageNet1k_pretrained.tar) | 78.04%/94.05% | 11.969 | 7.153 | +### ResNeXt101_wsl +|model | top-1/top-5 accuracy(CV2) | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | +|- |:-: |:-: |:-: | +|[ResNeXt101_32x8d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_32x8d_wsl_pretrained.tar) | 82.55%/96.74% | 33.310 | 27.648 | +|[ResNeXt101_32x16d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_32x16d_wsl_pretrained.tar) | 84.24%/97.26% | 54.320 | 46.064 | +|[ResNeXt101_32x32d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_32x32d_wsl_pretrained.tar) | 84.97%/97.59% | 97.734 | 87.961 | +|[ResNeXt101_32x48d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_32x48d_wsl_pretrained.tar) | 85.37%/97.69% | 161.722 | | +|[Fix_ResNeXt101_32x48d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/Fix_ResNeXt101_32x48d_wsl_pretrained.tar) | 86.26%/97.97% | 236.091 | | ## FAQ @@ -256,6 +287,11 @@ python infer.py \ - GoogLeNet: [Going Deeper with Convolutions](https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf), Christian Szegedy, Wei Liu, Yangqing Jia - Xception: [Xception: Deep Learning with Depthwise Separable Convolutions](https://arxiv.org/abs/1610.02357), François Chollet - InceptionV4: [Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning](https://arxiv.org/abs/1602.07261), Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi +- DarkNet: [YOLOv3: An Incremental Improvement](https://pjreddie.com/media/files/papers/YOLOv3.pdf), Joseph Redmon, Ali Farhadi +- DenseNet: [Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993), Gao Huang, Zhuang Liu, Laurens van der Maaten +- SqueezeNet: [SQUEEZENET: ALEXNET-LEVEL ACCURACY WITH 50X FEWER PARAMETERS AND <0.5MB MODEL
SIZE](https://arxiv.org/abs/1602.07360), Forrest N. Iandola, Song Han, Matthew W. Moskewicz, Khalid Ashraf, William J. Dally, Kurt Keutzer +- ResNeXt101_wsl: [Exploring the Limits of Weakly Supervised Pretraining](https://arxiv.org/abs/1805.00932), Dhruv Mahajan, Ross Girshick, Vignesh Ramanathan, Kaiming He, Manohar Paluri, Yixuan Li, Ashwin Bharambe, Laurens van der Maaten +- Fix_ResNeXt101_wsl: [Fixing the train-test resolution discrepancy](https://arxiv.org/abs/1906.06423), Hugo Touvron, Andrea Vedaldi, Matthijs Douze, Hervé Jégou ## 版本更新 - 2018/12/03 **Stage1**: 更新AlexNet,ResNet50,ResNet101,MobileNetV1 @@ -266,6 +302,7 @@ python infer.py \ - 2019/06/22 更新ResNet50_vd_v2 - 2019/07/02 **Stage5**: 更新MobileNetV2_x0_5, ResNeXt50_32x4d, ResNeXt50_64x4d, Xception_41, ResNet101_vd - 2019/07/19 **Stage6**: 更新ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2_x1_0, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, MobileNetV2_x0_25, MobileNetV2_x1_5, MobileNetV2_x2_0, ResNeXt50_vd_64x4d, ResNeXt101_32x4d, ResNeXt152_32x4d +- 2019/08/01 **Stage7**: 更新DarkNet53, DenseNet121,
DenseNet161, DenseNet169, DenseNet201, DenseNet264, SqueezeNet1_0, SqueezeNet1_1, ResNeXt50_vd_32x4d, ResNeXt152_64x4d, ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl, Fix_ResNeXt101_32x48d_wsl ## 如何贡献代码 diff --git a/PaddleCV/image_classification/models/__init__.py b/PaddleCV/image_classification/models/__init__.py index 86ce67483634d315508224a7ada193f96b5468cf..1ab3c60906d2c7b3766a66bee93d51f71da325fc 100644 --- a/PaddleCV/image_classification/models/__init__.py +++ b/PaddleCV/image_classification/models/__init__.py @@ -17,3 +17,8 @@ from .shufflenet_v2_swish import ShuffleNetV2, ShuffleNetV2_x0_5_swish, ShuffleN from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2_x1_0, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0 from .fast_imagenet import FastImageNet from .xception import Xception_41, Xception_65, Xception_71 +from .densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264 +from .squeezenet import SqueezeNet1_0, SqueezeNet1_1 +from .darknet import DarkNet53 +from .resnext101_wsl import ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl, Fix_ResNeXt101_32x48d_wsl + diff --git a/PaddleCV/image_classification/models/darknet.py b/PaddleCV/image_classification/models/darknet.py new file mode 100644 index 0000000000000000000000000000000000000000..b545ed7396e643b4456bdc82ba954b7b6c9d4d65 --- /dev/null +++ b/PaddleCV/image_classification/models/darknet.py @@ -0,0 +1,115 @@ +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +import math +__all__ = ["DarkNet53"] + +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [30, 60, 90], + "steps": [0.1, 0.01, 0.001, 0.0001] + } +} + + +class DarkNet53(): + def __init__(self): + self.params =
train_parameters + + def net(self, input, class_dim=1000): + DarkNet_cfg = {53: ([1, 2, 8, 8, 4], self.basicblock)} + stages, block_func = DarkNet_cfg[53] + stages = stages[0:5] + conv1 = self.conv_bn_layer( + input, + ch_out=32, + filter_size=3, + stride=1, + padding=1, + name="yolo_input") + conv = self.downsample( + conv1, + ch_out=conv1.shape[1] * 2, + name="yolo_input.downsample") + + for i, stage in enumerate(stages): + conv = self.layer_warp( + block_func, + conv, + 32 * (2**i), + stage, + name="stage.{}".format(i)) + if i < len(stages) - 1: # do not downsample in the last stage + conv = self.downsample( + conv, + ch_out=conv.shape[1] * 2, + name="stage.{}.downsample".format(i)) + pool = fluid.layers.pool2d( + input=conv, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc(input=pool, + size=class_dim, + param_attr=ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv),name='fc_weights'), + bias_attr=ParamAttr(name='fc_offset')) + return out + + + + + + def conv_bn_layer(self, input, ch_out, filter_size, stride, padding, name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=ch_out, + filter_size=filter_size, + stride=stride, + padding=padding, + act=None, + param_attr=ParamAttr(name=name + ".conv.weights"), + bias_attr=False) + + bn_name = name + ".bn" + out = fluid.layers.batch_norm( + input=conv, + act='relu', + param_attr=ParamAttr(name=bn_name + '.scale'), + bias_attr=ParamAttr(name=bn_name + '.offset'), + moving_mean_name=bn_name + '.mean', + moving_variance_name=bn_name + '.var') + return out + + + + def downsample(self, input, ch_out, filter_size=3, stride=2, padding=1, name=None): + return self.conv_bn_layer( + input, + ch_out=ch_out, + filter_size=filter_size, + stride=stride, + padding=padding, + name=name) + + + def basicblock(self, input, ch_out, name=None): + conv1 = self.conv_bn_layer( + input, ch_out, 1, 1, 0, name=name + ".0") + conv2 = self.conv_bn_layer( +
conv1, ch_out * 2, 3, 1, 1, name=name + ".1") + out = fluid.layers.elementwise_add(x=input, y=conv2, act=None) + return out + + + def layer_warp(self, block_func, input, ch_out, count, name=None): + res_out = block_func( + input, ch_out, name='{}.0'.format(name)) + for j in range(1, count): + res_out = block_func( + res_out, ch_out, name='{}.{}'.format(name, j)) + return res_out + + diff --git a/PaddleCV/image_classification/models/densenet.py b/PaddleCV/image_classification/models/densenet.py new file mode 100644 index 0000000000000000000000000000000000000000..95173e1f1def01c71329535d18474ef1d2b90684 --- /dev/null +++ b/PaddleCV/image_classification/models/densenet.py @@ -0,0 +1,167 @@ +import paddle +import paddle.fluid as fluid +import math +from paddle.fluid.param_attr import ParamAttr + +__all__ = ["DenseNet", "DenseNet121", "DenseNet161", "DenseNet169", "DenseNet201", "DenseNet264"] + +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [30, 60, 90], + "steps": [0.1, 0.01, 0.001, 0.0001] + } +} + +class DenseNet(): + def __init__(self, layers=121): + self.params = train_parameters + self.layers = layers + + + def net(self, input, bn_size=4, dropout=0, class_dim=1000): + layers = self.layers + supported_layers = [121, 161, 169, 201, 264] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + densenet_spec = {121: (64, 32, [6, 12, 24, 16]), + 161: (96, 48, [6, 12, 36, 24]), + 169: (64, 32, [6, 12, 32, 32]), + 201: (64, 32, [6, 12, 48, 32]), + 264: (64, 32, [6, 12, 64, 48])} + + + num_init_features, growth_rate, block_config = densenet_spec[layers] + conv = fluid.layers.conv2d( + input=input, + num_filters=num_init_features, + filter_size=7, + stride=2, + padding=3, + act=None, + param_attr=ParamAttr(name="conv1_weights"), + 
bias_attr=False) + conv = fluid.layers.batch_norm(input=conv, + act='relu', + param_attr=ParamAttr(name='conv1_bn_scale'), + bias_attr=ParamAttr(name='conv1_bn_offset'), + moving_mean_name='conv1_bn_mean', + moving_variance_name='conv1_bn_variance') + conv = fluid.layers.pool2d(input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max') + num_features = num_init_features + for i, num_layers in enumerate(block_config): + conv = self.make_dense_block(conv, num_layers, bn_size, growth_rate, dropout, name='conv'+str(i+2)) + num_features = num_features + num_layers * growth_rate + if i != len(block_config) - 1: + conv = self.make_transition(conv, num_features // 2, name='conv'+str(i+2)+'_blk') + num_features = num_features // 2 + conv = fluid.layers.batch_norm(input=conv, + act='relu', + param_attr=ParamAttr(name='conv5_blk_bn_scale'), + bias_attr=ParamAttr(name='conv5_blk_bn_offset'), + moving_mean_name='conv5_blk_bn_mean', + moving_variance_name='conv5_blk_bn_variance') + conv = fluid.layers.pool2d(input=conv, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(conv.shape[1] * 1.0) + out = fluid.layers.fc(input=conv, + size=class_dim, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), name="fc_weights"), + bias_attr=ParamAttr(name='fc_offset')) + return out + + + + def make_transition(self, input, num_output_features, name=None): + bn_ac = fluid.layers.batch_norm(input, + act='relu', + param_attr=ParamAttr(name=name + '_bn_scale'), + bias_attr=ParamAttr(name + '_bn_offset'), + moving_mean_name=name + '_bn_mean', + moving_variance_name=name + '_bn_variance' + ) + + bn_ac_conv = fluid.layers.conv2d( + input=bn_ac, + num_filters=num_output_features, + filter_size=1, + stride=1, + act=None, + bias_attr=False, + param_attr=ParamAttr(name=name + "_weights") + ) + pool = fluid.layers.pool2d(input=bn_ac_conv, pool_size=2, pool_stride=2, pool_type='avg') + return pool + + def make_dense_block(self, input, 
num_layers, bn_size, growth_rate, dropout, name=None): + conv = input + for layer in range(num_layers): + conv = self.make_dense_layer(conv, growth_rate, bn_size, dropout, name=name + '_' + str(layer+1)) + return conv + + + def make_dense_layer(self, input, growth_rate, bn_size, dropout, name=None): + bn_ac = fluid.layers.batch_norm(input, + act='relu', + param_attr=ParamAttr(name=name + '_x1_bn_scale'), + bias_attr=ParamAttr(name + '_x1_bn_offset'), + moving_mean_name=name + '_x1_bn_mean', + moving_variance_name=name + '_x1_bn_variance') + bn_ac_conv = fluid.layers.conv2d( + input=bn_ac, + num_filters=bn_size * growth_rate, + filter_size=1, + stride=1, + act=None, + bias_attr=False, + param_attr=ParamAttr(name=name + "_x1_weights")) + bn_ac = fluid.layers.batch_norm(bn_ac_conv, + act='relu', + param_attr=ParamAttr(name=name + '_x2_bn_scale'), + bias_attr=ParamAttr(name + '_x2_bn_offset'), + moving_mean_name=name + '_x2_bn_mean', + moving_variance_name=name + '_x2_bn_variance') + bn_ac_conv = fluid.layers.conv2d( + input=bn_ac, + num_filters=growth_rate, + filter_size=3, + stride=1, + padding=1, + act=None, + bias_attr=False, + param_attr=ParamAttr(name=name + "_x2_weights")) + if dropout: + bn_ac_conv = fluid.layers.dropout(x=bn_ac_conv, dropout_prob=dropout) + bn_ac_conv = fluid.layers.concat([input, bn_ac_conv], axis=1) + return bn_ac_conv + + +def DenseNet121(): + model=DenseNet(layers=121) + return model + +def DenseNet161(): + model=DenseNet(layers=161) + return model + +def DenseNet169(): + model=DenseNet(layers=169) + return model + +def DenseNet201(): + model=DenseNet(layers=201) + return model + +def DenseNet264(): + model=DenseNet(layers=264) + return model + + + + + + diff --git a/PaddleCV/image_classification/models/resnext101_wsl.py b/PaddleCV/image_classification/models/resnext101_wsl.py new file mode 100644 index 0000000000000000000000000000000000000000..9c5a65e4253d1eac439dc72756cb734fe7179eb7 --- /dev/null +++ 
b/PaddleCV/image_classification/models/resnext101_wsl.py @@ -0,0 +1,171 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle +import paddle.fluid as fluid +import math +from paddle.fluid.param_attr import ParamAttr + +__all__ = ["ResNeXt101_32x8d_wsl", "ResNeXt101_32x16d_wsl", "ResNeXt101_32x32d_wsl", "ResNeXt101_32x48d_wsl", "Fix_ResNeXt101_32x48d_wsl"] + +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [30, 60, 90], + "steps": [0.1, 0.01, 0.001, 0.0001] + } +} + + +class ResNeXt101_wsl(): + def __init__(self, layers=101, cardinality=32, width=48): + self.params = train_parameters + self.layers = layers + self.cardinality = cardinality + self.width = width + + def net(self, input, class_dim=1000): + layers = self.layers + cardinality = self.cardinality + width = self.width + + depth = [3, 4, 23, 3] + base_width = cardinality * width + num_filters = [base_width * i for i in [1, 2, 4, 8]] + + + conv = self.conv_bn_layer( + input=input, + num_filters=64, + filter_size=7, + stride=2, + act='relu', + name="conv1") #debug + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + + for block in range(len(depth)): + for i in range(depth[block]): + conv_name = 'layer' + str(block+1) + "." 
+ str(i) + conv = self.bottleneck_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + cardinality=cardinality, + name=conv_name) + + pool = fluid.layers.pool2d( + input=conv, pool_size=7, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc(input=pool, + size=class_dim, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv),name='fc.weight'), + bias_attr=fluid.param_attr.ParamAttr(name='fc.bias')) + return out + + def conv_bn_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + if "downsample" in name: + conv_name = name + '.0' + else: + conv_name = name + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + param_attr=ParamAttr(name=conv_name + ".weight"), + bias_attr=False) + if "downsample" in name: + bn_name = name[:9] + 'downsample' + '.1' + else: + if "conv1" == name: + bn_name = 'bn' + name[-1] + else: + bn_name = (name[:10] if name[7:9].isdigit() else name[:9]) + 'bn' + name[-1] + return fluid.layers.batch_norm( + input=conv, + act=act, + param_attr=ParamAttr(name=bn_name + '.weight'), + bias_attr=ParamAttr(bn_name + '.bias'), + moving_mean_name=bn_name + '.running_mean', + moving_variance_name=bn_name + '.running_var', ) + + def shortcut(self, input, ch_out, stride, name): + ch_in = input.shape[1] + if ch_in != ch_out or stride != 1: + return self.conv_bn_layer(input, ch_out, 1, stride, name=name) + else: + return input + + def bottleneck_block(self, input, num_filters, stride, cardinality, name): + cardinality = self.cardinality + width = self.width + conv0 = self.conv_bn_layer( + input=input, + num_filters=num_filters, + filter_size=1, + act='relu', + name=name + ".conv1") + conv1 = self.conv_bn_layer( + input=conv0, + 
num_filters=num_filters, + filter_size=3, + stride=stride, + groups=cardinality, + act='relu', + name=name + ".conv2") + conv2 = self.conv_bn_layer( + input=conv1, + num_filters=num_filters//(width//8), + filter_size=1, + act=None, + name=name + ".conv3") + + short = self.shortcut( + input, num_filters//(width//8), stride, name=name + ".downsample") + + return fluid.layers.elementwise_add( + x=short, y=conv2, act='relu') + + + + +def ResNeXt101_32x8d_wsl(): + model = ResNeXt101_wsl(cardinality=32, width=8) + return model + +def ResNeXt101_32x16d_wsl(): + model = ResNeXt101_wsl(cardinality=32, width=16) + return model + +def ResNeXt101_32x32d_wsl(): + model = ResNeXt101_wsl(cardinality=32, width=32) + return model + +def ResNeXt101_32x48d_wsl(): + model = ResNeXt101_wsl(cardinality=32, width=48) + return model + + +def Fix_ResNeXt101_32x48d_wsl(): + model = ResNeXt101_wsl(cardinality=32, width=48) + return model diff --git a/PaddleCV/image_classification/models/resnext_vd.py b/PaddleCV/image_classification/models/resnext_vd.py index df01d5e28a957a301a875939e5764edd910a749d..9b1a52969f225429cfd7c5e6a8025fc3b825b4cc 100644 --- a/PaddleCV/image_classification/models/resnext_vd.py +++ b/PaddleCV/image_classification/models/resnext_vd.py @@ -198,7 +198,7 @@ def ResNeXt101_vd_64x4d(): return model def ResNeXt101_vd_32x4d(): - model = ResNeXt(layers=50, cardinality=32, is_3x3 = True) + model = ResNeXt(layers=101, cardinality=32, is_3x3 = True) return model def ResNeXt152_vd_64x4d(): @@ -206,6 +206,6 @@ def ResNeXt152_vd_64x4d(): return model def ResNeXt152_vd_32x4d(): - model = ResNeXt(layers=50, cardinality=32, is_3x3 = True) + model = ResNeXt(layers=152, cardinality=32, is_3x3 = True) return model diff --git a/PaddleCV/image_classification/models/shufflenet_v2.py b/PaddleCV/image_classification/models/shufflenet_v2.py index f5d68166b88655b873ce8169b5ff320b24e02856..bd20ee25bd56e101a5d2cd7f9d75849bfd06b9b4 100644 --- a/PaddleCV/image_classification/models/shufflenet_v2.py 
+++ b/PaddleCV/image_classification/models/shufflenet_v2.py
@@ -46,7 +46,7 @@ class ShuffleNetV2():
         scale = self.scale
         stage_repeats = [4, 8, 4]
 
-        if scale == 0.25:
+        if scale == 0.25:
             stage_out_channels = [-1, 24, 24, 48, 96, 512]
         elif scale == 0.33:
             stage_out_channels = [-1, 24, 32, 64, 128, 512]
diff --git a/PaddleCV/image_classification/models/squeezenet.py b/PaddleCV/image_classification/models/squeezenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1123b880e7c7c2e48a595a8212ebf0042ee848f3
--- /dev/null
+++ b/PaddleCV/image_classification/models/squeezenet.py
@@ -0,0 +1,103 @@
+import paddle
+import paddle.fluid as fluid
+import math
+from paddle.fluid.param_attr import ParamAttr
+
+__all__ = ["SqueezeNet", "SqueezeNet1_0", "SqueezeNet1_1"]
+
+train_parameters = {
+    "input_size": [3, 224, 224],
+    "input_mean": [0.485, 0.456, 0.406],
+    "input_std": [0.229, 0.224, 0.225],
+    "learning_strategy": {
+        "name": "piecewise_decay",
+        "batch_size": 256,
+        "epochs": [30, 60, 90],
+        "steps": [0.1, 0.01, 0.001, 0.0001]
+    }
+}
+
+class SqueezeNet():
+    def __init__(self, version='1.0'):
+        self.params = train_parameters
+        self.version = version
+
+    def net(self, input, class_dim=1000):
+        version = self.version
+        assert version in ['1.0', '1.1'], \
+            "supported versions are {} but input version is {}".format(['1.0', '1.1'], version)
+        if version == '1.0':
+            conv = fluid.layers.conv2d(input,
+                                       num_filters=96,
+                                       filter_size=7,
+                                       stride=2,
+                                       act='relu',
+                                       param_attr=fluid.param_attr.ParamAttr(name="conv1_weights"),
+                                       bias_attr=ParamAttr(name='conv1_offset'))
+            conv = fluid.layers.pool2d(conv, pool_size=3, pool_stride=2, pool_type='max')
+            conv = self.make_fire(conv, 16, 64, 64, name='fire2')
+            conv = self.make_fire(conv, 16, 64, 64, name='fire3')
+            conv = self.make_fire(conv, 32, 128, 128, name='fire4')
+            conv = fluid.layers.pool2d(conv, pool_size=3, pool_stride=2, pool_type='max')
+            conv = self.make_fire(conv, 32, 128, 128, name='fire5')
+            conv = self.make_fire(conv, 48, 192, 192, name='fire6')
+            conv = self.make_fire(conv, 48, 192, 192, name='fire7')
+            conv = self.make_fire(conv, 64, 256, 256, name='fire8')
+            conv = fluid.layers.pool2d(conv, pool_size=3, pool_stride=2, pool_type='max')
+            conv = self.make_fire(conv, 64, 256, 256, name='fire9')
+        else:
+            conv = fluid.layers.conv2d(input,
+                                       num_filters=64,
+                                       filter_size=3,
+                                       stride=2,
+                                       padding=1,
+                                       act='relu',
+                                       param_attr=fluid.param_attr.ParamAttr(name="conv1_weights"),
+                                       bias_attr=ParamAttr(name='conv1_offset'))
+            conv = fluid.layers.pool2d(conv, pool_size=3, pool_stride=2, pool_type='max')
+            conv = self.make_fire(conv, 16, 64, 64, name='fire2')
+            conv = self.make_fire(conv, 16, 64, 64, name='fire3')
+            conv = fluid.layers.pool2d(conv, pool_size=3, pool_stride=2, pool_type='max')
+            conv = self.make_fire(conv, 32, 128, 128, name='fire4')
+            conv = self.make_fire(conv, 32, 128, 128, name='fire5')
+            conv = fluid.layers.pool2d(conv, pool_size=3, pool_stride=2, pool_type='max')
+            conv = self.make_fire(conv, 48, 192, 192, name='fire6')
+            conv = self.make_fire(conv, 48, 192, 192, name='fire7')
+            conv = self.make_fire(conv, 64, 256, 256, name='fire8')
+            conv = self.make_fire(conv, 64, 256, 256, name='fire9')
+        conv = fluid.layers.dropout(conv, dropout_prob=0.5)
+        conv = fluid.layers.conv2d(conv,
+                                   num_filters=class_dim,
+                                   filter_size=1,
+                                   act='relu',
+                                   param_attr=fluid.param_attr.ParamAttr(name="conv10_weights"),
+                                   bias_attr=ParamAttr(name='conv10_offset'))
+        conv = fluid.layers.pool2d(conv, pool_type='avg', global_pooling=True)
+        out = fluid.layers.flatten(conv)
+        return out
+
+    def make_fire_conv(self, input, num_filters, filter_size, padding=0, name=None):
+        conv = fluid.layers.conv2d(input,
+                                   num_filters=num_filters,
+                                   filter_size=filter_size,
+                                   padding=padding,
+                                   act='relu',
+                                   param_attr=fluid.param_attr.ParamAttr(name=name + "_weights"),
+                                   bias_attr=ParamAttr(name=name + '_offset'))
+        return conv
+
+    def make_fire(self, input, squeeze_channels, expand1x1_channels,
+                  expand3x3_channels, name=None):
+        conv = self.make_fire_conv(input, squeeze_channels, 1, name=name + '_squeeze1x1')
+        conv_path1 = self.make_fire_conv(conv, expand1x1_channels, 1, name=name + '_expand1x1')
+        conv_path2 = self.make_fire_conv(conv, expand3x3_channels, 3, 1, name=name + '_expand3x3')
+        out = fluid.layers.concat([conv_path1, conv_path2], axis=1)
+        return out
+
+def SqueezeNet1_0():
+    model = SqueezeNet(version='1.0')
+    return model
+
+def SqueezeNet1_1():
+    model = SqueezeNet(version='1.1')
+    return model
diff --git a/PaddleCV/image_classification/models/xception.py b/PaddleCV/image_classification/models/xception.py
index d7a301c3e3df860c46f581da62e7bd323a8badf0..2a0874c688beb365fef21e5ea0cc6b1a1145c94c 100644
--- a/PaddleCV/image_classification/models/xception.py
+++ b/PaddleCV/image_classification/models/xception.py
@@ -227,7 +227,7 @@ class Xception(object):
             num_filters=num_filters,
             filter_size=filter_size,
             stride=stride,
-            padding=(filter_size - 1) / 2,
+            padding=(filter_size - 1) // 2,
             groups=groups,
             act=None,
             param_attr=ParamAttr(name=name + "_weights"),
diff --git a/PaddleCV/image_classification/reader_cv2.py b/PaddleCV/image_classification/reader_cv2.py
index dadaf37af55634120afd8610f9833e8c751e5c7f..371d2b5d1b3ff24136c46c3f4ac7a756d6957ce0 100644
--- a/PaddleCV/image_classification/reader_cv2.py
+++ b/PaddleCV/image_classification/reader_cv2.py
@@ -192,7 +192,6 @@ def process_image(sample,
     if crop_size > 0:
         target_size = settings.resize_short_size
         img = resize_short(img, target_size)
-        img = crop_image(img, target_size=crop_size, center=True)
 
     img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
 
@@ -209,10 +208,11 @@ def process_image(sample,
 
 def process_batch_data(input_data, settings, mode, color_jitter, rotate):
     batch_data = []
+    crop_size = int(settings.image_shape.split(',')[-1])
     for sample in input_data:
         if os.path.isfile(sample[0]):
             batch_data.append(
-                process_image(sample, settings, mode, color_jitter, rotate))
+                process_image(sample, settings, mode, color_jitter, rotate, crop_size))
         else:
             print("File not exist : %s" % sample[0])
     return batch_data
diff --git a/PaddleCV/image_classification/run.sh b/PaddleCV/image_classification/run.sh
index 97edf28fdf753d4fda7bff1507e31c9e8079f149..b2466a79d8aa9449e40f7d95359ec496145807ea 100644
--- a/PaddleCV/image_classification/run.sh
+++ b/PaddleCV/image_classification/run.sh
@@ -28,6 +28,34 @@ python train.py \
 # --lr=0.01 \
 # --l2_decay=1e-4
 
+#SqueezeNet1_0
+#python train.py \
+# --model=SqueezeNet1_0 \
+# --batch_size=256 \
+# --total_images=1281167 \
+# --image_shape=3,224,224 \
+# --lr_strategy=cosine_decay \
+# --class_dim=1000 \
+# --model_save_dir=output/ \
+# --lr=0.02 \
+# --num_epochs=120 \
+# --with_mem_opt=True \
+# --l2_decay=1e-4
+
+#SqueezeNet1_1
+#python train.py \
+# --model=SqueezeNet1_1 \
+# --batch_size=256 \
+# --total_images=1281167 \
+# --image_shape=3,224,224 \
+# --lr_strategy=cosine_decay \
+# --class_dim=1000 \
+# --model_save_dir=output/ \
+# --lr=0.02 \
+# --num_epochs=120 \
+# --with_mem_opt=True \
+# --l2_decay=1e-4
+
 #VGG11:
 #python train.py \
 # --model=VGG11 \
@@ -448,6 +476,22 @@ python train.py \
 # --model_save_dir=output/ \
 # --l2_decay=1e-4
 
+#ResNeXt50_vd_32x4d
+#python train.py \
+# --model=ResNeXt50_vd_32x4d \
+# --batch_size=256 \
+# --total_images=1281167 \
+# --image_shape=3,224,224 \
+# --class_dim=1000 \
+# --lr_strategy=cosine_decay \
+# --lr=0.1 \
+# --num_epochs=200 \
+# --with_mem_opt=True \
+# --model_save_dir=output/ \
+# --l2_decay=1e-4 \
+# --use_mixup=True \
+# --use_label_smoothing=True \
+# --label_smoothing_epsilon=0.1 \
 
 #ResNeXt50_64x4d
 #python train.py \
@@ -539,6 +583,90 @@ python train.py \
 # --model_save_dir=output/ \
 # --l2_decay=1e-4
 
+#ResNeXt152_64x4d
+#python train.py \
+# --model=ResNeXt152_64x4d \
+# --batch_size=256 \
+# --total_images=1281167 \
+# --image_shape=3,224,224 \
+# --class_dim=1000 \
+# --lr_strategy=piecewise_decay \
+# --lr=0.1 \
+# --num_epochs=120 \
+# --with_mem_opt=True \
+# --model_save_dir=output/ \
+# --l2_decay=18e-5
+
+# DenseNet121
+# python train.py \
+# --model=DenseNet121 \
+# --batch_size=256 \
+# --total_images=1281167 \
+# --image_shape=3,224,224 \
+# --class_dim=1000 \
+# --lr_strategy=piecewise_decay \
+# --lr=0.1 \
+# --num_epochs=120 \
+# --with_mem_opt=True \
+# --model_save_dir=output/ \
+# --l2_decay=1e-4
+
+# DenseNet161
+# python train.py \
+# --model=DenseNet161 \
+# --batch_size=256 \
+# --total_images=1281167 \
+# --image_shape=3,224,224 \
+# --class_dim=1000 \
+# --lr_strategy=piecewise_decay \
+# --lr=0.1 \
+# --num_epochs=120 \
+# --with_mem_opt=True \
+# --model_save_dir=output/ \
+# --l2_decay=1e-4
+
+# DenseNet169
+# python train.py \
+# --model=DenseNet169 \
+# --batch_size=256 \
+# --total_images=1281167 \
+# --image_shape=3,224,224 \
+# --class_dim=1000 \
+# --lr_strategy=piecewise_decay \
+# --lr=0.1 \
+# --num_epochs=120 \
+# --with_mem_opt=True \
+# --model_save_dir=output/ \
+# --l2_decay=1e-4
+
+# DenseNet201
+# python train.py \
+# --model=DenseNet201 \
+# --batch_size=256 \
+# --total_images=1281167 \
+# --image_shape=3,224,224 \
+# --class_dim=1000 \
+# --lr_strategy=piecewise_decay \
+# --lr=0.1 \
+# --num_epochs=120 \
+# --with_mem_opt=True \
+# --model_save_dir=output/ \
+# --l2_decay=1e-4
+
+# DenseNet264
+# python train.py \
+# --model=DenseNet264 \
+# --batch_size=256 \
+# --total_images=1281167 \
+# --image_shape=3,224,224 \
+# --class_dim=1000 \
+# --lr_strategy=piecewise_decay \
+# --lr=0.1 \
+# --num_epochs=120 \
+# --with_mem_opt=True \
+# --model_save_dir=output/ \
+# --l2_decay=1e-4
+
 #SE_ResNeXt50_32x4d:
 #python train.py \
 # --model=SE_ResNeXt50_32x4d \
@@ -631,6 +759,24 @@ python train.py \
 # --use_label_smoothing=True \
 # --label_smoothing_epsilon=0.1 \
 
+#DarkNet53
+#python train.py \
+# --model=DarkNet53 \
+# --batch_size=256 \
+# --total_images=1281167 \
+# --image_shape=3,256,256 \
+# --class_dim=1000 \
+# --lr_strategy=cosine_decay \
+# --lr=0.1 \
+# --num_epochs=200 \
+# --with_mem_opt=True \
+# --model_save_dir=output/ \
+# --l2_decay=1e-4 \
+# --use_mixup=True \
+# --resize_short_size=256 \
+# --use_label_smoothing=True \
+# --label_smoothing_epsilon=0.1 \
+
 #ResNet50 nGraph:
 # Training:
 #OMP_NUM_THREADS=`nproc` FLAGS_use_ngraph=true python train.py \