diff --git a/deploy/configs/inference_cartoon.yaml b/deploy/configs/inference_cartoon.yaml index 7d93d98cc0696d8e1508e02db2cc864d6f917d19..e79da55090130223466fd6b6a078b9909d6e26f2 100644 --- a/deploy/configs/inference_cartoon.yaml +++ b/deploy/configs/inference_cartoon.yaml @@ -8,7 +8,7 @@ Global: image_shape: [3, 640, 640] threshold: 0.2 max_det_results: 5 - labe_list: + label_list: - foreground use_gpu: True diff --git a/deploy/configs/inference_det.yaml b/deploy/configs/inference_det.yaml index c809a0257bc7c5b774f20fb3edb50a08e7d67bbb..dab7908ef7f59bfed077d9189811aedb650b0e92 100644 --- a/deploy/configs/inference_det.yaml +++ b/deploy/configs/inference_det.yaml @@ -5,7 +5,7 @@ Global: image_shape: [3, 640, 640] threshold: 0.2 max_det_results: 1 - labe_list: + label_list: - foreground # inference engine config diff --git a/deploy/configs/inference_drink.yaml b/deploy/configs/inference_drink.yaml index d044965f446634dcc151fd496a9d7b403b869d68..1c3e2c29aa8ddd5db46bbc8660c9f45942696a9c 100644 --- a/deploy/configs/inference_drink.yaml +++ b/deploy/configs/inference_drink.yaml @@ -8,7 +8,7 @@ Global: image_shape: [3, 640, 640] threshold: 0.2 max_det_results: 5 - labe_list: + label_list: - foreground use_gpu: True diff --git a/deploy/configs/inference_general.yaml b/deploy/configs/inference_general.yaml index 6b397b5047b427d02014060380112b096e0b2da2..8fb8ae3a56697b882be00da554f33750ead42f70 100644 --- a/deploy/configs/inference_general.yaml +++ b/deploy/configs/inference_general.yaml @@ -8,7 +8,7 @@ Global: image_shape: [3, 640, 640] threshold: 0.2 max_det_results: 5 - labe_list: + label_list: - foreground use_gpu: True diff --git a/deploy/configs/inference_general_binary.yaml b/deploy/configs/inference_general_binary.yaml index d76dae8f8f7c70f27996f6b20fd623bdc00bc441..72ec31fc438d1f884bada59507a90d172ab4a416 100644 --- a/deploy/configs/inference_general_binary.yaml +++ b/deploy/configs/inference_general_binary.yaml @@ -8,7 +8,7 @@ Global: image_shape: [3, 640, 640] threshold: 0.2 max_det_results: 5 - labe_list: + label_list: - foreground use_gpu: True diff --git a/deploy/configs/inference_logo.yaml b/deploy/configs/inference_logo.yaml index f78ca25a042b3224a973d81f7b0242ace7c25430..2b8228eab772f8b1488275163518a6e059a49c53 100644 --- a/deploy/configs/inference_logo.yaml +++ b/deploy/configs/inference_logo.yaml @@ -8,7 +8,7 @@ Global: image_shape: [3, 640, 640] threshold: 0.2 max_det_results: 5 - labe_list: + label_list: - foreground use_gpu: True diff --git a/deploy/configs/inference_product.yaml b/deploy/configs/inference_product.yaml index e7b494c383aa5f42b4515446805b1357ba43107c..78ba32068cb696e897c39d516e66b323bd12ad61 100644 --- a/deploy/configs/inference_product.yaml +++ b/deploy/configs/inference_product.yaml @@ -8,7 +8,7 @@ Global: image_shape: [3, 640, 640] threshold: 0.2 max_det_results: 5 - labe_list: + label_list: - foreground # inference engine config diff --git a/deploy/configs/inference_vehicle.yaml b/deploy/configs/inference_vehicle.yaml index d99f42ad684150f1efeaf65f031ee1ea707fee37..e289e9f523b061dd26b8d687e594499dd7cdec37 100644 --- a/deploy/configs/inference_vehicle.yaml +++ b/deploy/configs/inference_vehicle.yaml @@ -8,7 +8,7 @@ Global: image_shape: [3, 640, 640] threshold: 0.2 max_det_results: 5 - labe_list: + label_list: - foreground use_gpu: True diff --git a/deploy/cpp_shitu/include/object_detector.h b/deploy/cpp_shitu/include/object_detector.h index 5bfc56253b1845a50f3b6b093db314e97505cfef..6855a0dcc84c2711283fe8d23ba1d2afe376fb0e 100644 --- 
a/deploy/cpp_shitu/include/object_detector.h +++ b/deploy/cpp_shitu/include/object_detector.h @@ -33,106 +33,106 @@ using namespace paddle_infer; namespace Detection { // Object Detection Result - struct ObjectResult { - // Rectangle coordinates of detected object: left, right, top, down - std::vector<int> rect; - // Class id of detected object - int class_id; - // Confidence of detected object - float confidence; - }; +struct ObjectResult { + // Rectangle coordinates of detected object: left, right, top, down + std::vector<int> rect; + // Class id of detected object + int class_id; + // Confidence of detected object + float confidence; +}; // Generate visualization colormap for each class - std::vector<int> GenerateColorMap(int num_class); +std::vector<int> GenerateColorMap(int num_class); // Visualization Detection Result - cv::Mat VisualizeResult(const cv::Mat &img, - const std::vector<ObjectResult> &results, - const std::vector<std::string> &labels, - const std::vector<int> &colormap, const bool is_rbox); - - class ObjectDetector { - public: - explicit ObjectDetector(const YAML::Node &config_file) { - this->use_gpu_ = config_file["Global"]["use_gpu"].as<bool>(); - if (config_file["Global"]["gpu_id"].IsDefined()) - this->gpu_id_ = config_file["Global"]["gpu_id"].as<int>(); - this->gpu_mem_ = config_file["Global"]["gpu_mem"].as<int>(); - this->cpu_math_library_num_threads_ = - config_file["Global"]["cpu_num_threads"].as<int>(); - this->use_mkldnn_ = config_file["Global"]["enable_mkldnn"].as<bool>(); - this->use_tensorrt_ = config_file["Global"]["use_tensorrt"].as<bool>(); - this->use_fp16_ = config_file["Global"]["use_fp16"].as<bool>(); - this->model_dir_ = - config_file["Global"]["det_inference_model_dir"].as<std::string>(); - this->threshold_ = config_file["Global"]["threshold"].as<float>(); - this->max_det_results_ = config_file["Global"]["max_det_results"].as<int>(); - this->image_shape_ = - config_file["Global"]["image_shape"].as < std::vector < int >> (); - this->label_list_ = - config_file["Global"]["labe_list"].as < std::vector < std::string >> (); - this->ir_optim_ = config_file["Global"]["ir_optim"].as<bool>(); - this->batch_size_ = config_file["Global"]["batch_size"].as<int>(); - - preprocessor_.Init(config_file["DetPreProcess"]["transform_ops"]); - LoadModel(model_dir_, batch_size_, run_mode); - } - - // Load Paddle inference model - void LoadModel(const std::string &model_dir, const int batch_size = 1, - const std::string &run_mode = "fluid"); - - // Run predictor - void Predict(const std::vector<cv::Mat> imgs, const int warmup = 0, - const int repeats = 1, - std::vector<ObjectResult> *result = nullptr, - std::vector<int> *bbox_num = nullptr, - std::vector<double> *times = nullptr); - - const std::vector<std::string> &GetLabelList() const { - return this->label_list_; - } - - const float &GetThreshold() const { return this->threshold_; } - - private: - bool use_gpu_ = true; - int gpu_id_ = 0; - int gpu_mem_ = 800; - int cpu_math_library_num_threads_ = 6; - std::string run_mode = "fluid"; - bool use_mkldnn_ = false; - bool use_tensorrt_ = false; - int batch_size_ = 1; - bool use_fp16_ = false; - std::string model_dir_; - float threshold_ = 0.5; - float max_det_results_ = 5; - std::vector<int> image_shape_ = {3, 640, 640}; - std::vector<std::string> label_list_; - bool ir_optim_ = true; - bool det_permute_ = true; - bool det_postprocess_ = true; - int min_subgraph_size_ = 30; - bool use_dynamic_shape_ = false; - int trt_min_shape_ = 1; - int trt_max_shape_ = 1280; - int trt_opt_shape_ = 640; - bool trt_calib_mode_ = false; - - // Preprocess image and copy data to input buffer - void Preprocess(const cv::Mat &image_mat); - - // Postprocess result - void Postprocess(const std::vector<cv::Mat> mats, - std::vector<ObjectResult> *result, std::vector<int> bbox_num, - bool is_rbox); - - std::shared_ptr<Predictor> predictor_; - Preprocessor preprocessor_; - ImageBlob inputs_; - std::vector<float> output_data_; - std::vector<int> out_bbox_num_data_; - };
+cv::Mat VisualizeResult(const cv::Mat &img, + const std::vector<ObjectResult> &results, + const std::vector<std::string> &labels, + const std::vector<int> &colormap, const bool is_rbox); + +class ObjectDetector { +public: + explicit ObjectDetector(const YAML::Node &config_file) { + this->use_gpu_ = config_file["Global"]["use_gpu"].as<bool>(); + if (config_file["Global"]["gpu_id"].IsDefined()) + this->gpu_id_ = config_file["Global"]["gpu_id"].as<int>(); + this->gpu_mem_ = config_file["Global"]["gpu_mem"].as<int>(); + this->cpu_math_library_num_threads_ = + config_file["Global"]["cpu_num_threads"].as<int>(); + this->use_mkldnn_ = config_file["Global"]["enable_mkldnn"].as<bool>(); + this->use_tensorrt_ = config_file["Global"]["use_tensorrt"].as<bool>(); + this->use_fp16_ = config_file["Global"]["use_fp16"].as<bool>(); + this->model_dir_ = + config_file["Global"]["det_inference_model_dir"].as<std::string>(); + this->threshold_ = config_file["Global"]["threshold"].as<float>(); + this->max_det_results_ = config_file["Global"]["max_det_results"].as<int>(); + this->image_shape_ = + config_file["Global"]["image_shape"].as<std::vector<int>>(); + this->label_list_ = + config_file["Global"]["label_list"].as<std::vector<std::string>>(); + this->ir_optim_ = config_file["Global"]["ir_optim"].as<bool>(); + this->batch_size_ = config_file["Global"]["batch_size"].as<int>(); + + preprocessor_.Init(config_file["DetPreProcess"]["transform_ops"]); + LoadModel(model_dir_, batch_size_, run_mode); + } + + // Load Paddle inference model + void LoadModel(const std::string &model_dir, const int batch_size = 1, + const std::string &run_mode = "fluid"); + + // Run predictor + void Predict(const std::vector<cv::Mat> imgs, const int warmup = 0, + const int repeats = 1, + std::vector<ObjectResult> *result = nullptr, + std::vector<int> *bbox_num = nullptr, + std::vector<double> *times = nullptr); + + const std::vector<std::string> &GetLabelList() const { + return this->label_list_; + } + + const float &GetThreshold() const { return this->threshold_; } + +private: + bool use_gpu_ = true; + int gpu_id_ = 0; + int gpu_mem_ = 800; + int cpu_math_library_num_threads_ = 6; + std::string run_mode = "fluid"; + bool use_mkldnn_ = false; + bool use_tensorrt_ = false; + int batch_size_ = 1; + bool use_fp16_ = false; + std::string model_dir_; + float threshold_ = 0.5; + float max_det_results_ = 5; + std::vector<int> image_shape_ = {3, 640, 640}; + std::vector<std::string> label_list_; + bool ir_optim_ = true; + bool det_permute_ = true; + bool det_postprocess_ = true; + int min_subgraph_size_ = 30; + bool use_dynamic_shape_ = false; + int trt_min_shape_ = 1; + int trt_max_shape_ = 1280; + int trt_opt_shape_ = 640; + bool trt_calib_mode_ = false; + + // Preprocess image and copy data to input buffer + void Preprocess(const cv::Mat &image_mat); + + // Postprocess result + void Postprocess(const std::vector<cv::Mat> mats, + std::vector<ObjectResult> *result, std::vector<int> bbox_num, + bool is_rbox); + + std::shared_ptr<Predictor> predictor_; + Preprocessor preprocessor_; + ImageBlob inputs_; + std::vector<float> output_data_; + std::vector<int> out_bbox_num_data_; +}; } // namespace Detection
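The `labe_list` → `label_list` rename above spans eight deploy configs plus the C++ and Python consumers, so any downstream tooling that still writes the old key would silently lose its labels. As a reference, a minimal sketch (assuming PyYAML and the config layout shown above; `load_global_config` is a hypothetical helper, not part of the repo) of a loader that tolerates both spellings during the transition:

```python
import yaml

def load_global_config(path):
    """Read a deploy config, normalizing the misspelled `labe_list` key."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    glob = cfg["Global"]
    # Accept the pre-rename spelling so older configs keep working.
    if "label_list" not in glob and "labe_list" in glob:
        glob["label_list"] = glob.pop("labe_list")
    return glob

glob = load_global_config("deploy/configs/inference_det.yaml")
print(glob["label_list"])  # ['foreground']
```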
diff --git a/deploy/lite_shitu/generate_json_config.py b/deploy/lite_shitu/generate_json_config.py index 37d06c47e686daf5335dbbf1a193658c4ac20ac3..642dfcd9d6a46e2894ec0f01f0914a5347bc8d72 100644 --- a/deploy/lite_shitu/generate_json_config.py +++ b/deploy/lite_shitu/generate_json_config.py @@ -95,7 +95,7 @@ def main(): config_json["Global"]["det_model_path"] = args.det_model_path config_json["Global"]["rec_model_path"] = args.rec_model_path config_json["Global"]["rec_label_path"] = args.rec_label_path - config_json["Global"]["label_list"] = config_yaml["Global"]["labe_list"] + config_json["Global"]["label_list"] = config_yaml["Global"]["label_list"] config_json["Global"]["rec_nms_thresold"] = config_yaml["Global"][ "rec_nms_thresold"] config_json["Global"]["max_det_results"] = config_yaml["Global"][
diff --git a/deploy/python/predict_det.py b/deploy/python/predict_det.py index e4e0a24a6dbc6c62f82810c865096f768ebd182b..37a7bf5018c3b5dc78e897b532303f70b0d3957d 100644 --- a/deploy/python/predict_det.py +++ b/deploy/python/predict_det.py @@ -128,13 +128,10 @@ class DetPredictor(Predictor): results = [] if reduce(lambda x, y: x * y, np_boxes.shape) < 6: print('[WARNNING] No object detected.') - results = np.array([]) else: - results = np_boxes - - results = self.parse_det_results(results, - self.config["Global"]["threshold"], - self.config["Global"]["labe_list"]) + results = self.parse_det_results( + np_boxes, self.config["Global"]["threshold"], + self.config["Global"]["label_list"]) return results
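For context on the `predict_det.py` hunk: `np_boxes` follows the usual PaddleDetection layout of one `[class_id, score, x_min, y_min, x_max, y_max]` row per detection, and the warning fires when fewer than six values come back, i.e. no boxes. A hedged sketch of the kind of thresholding `parse_det_results` performs — a hypothetical re-implementation for illustration, not the repo's exact code:

```python
import numpy as np

def parse_det_results(np_boxes, threshold, label_list):
    # Each row: [class_id, score, x_min, y_min, x_max, y_max]
    keep = np_boxes[:, 1] >= threshold
    return [{
        "class_id": int(row[0]),
        "label_name": label_list[int(row[0])],
        "score": float(row[1]),
        "bbox": row[2:].tolist(),
    } for row in np_boxes[keep]]
```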
diff --git a/docs/zh_CN/algorithm_introduction/ImageNet_models.md b/docs/zh_CN/algorithm_introduction/ImageNet_models.md index ee98de442a40fb7c37b2274b756a728f7dcfc5af..4c26ea105453e954457aca71edb66394c5037153 100644 --- a/docs/zh_CN/algorithm_introduction/ImageNet_models.md +++ b/docs/zh_CN/algorithm_introduction/ImageNet_models.md @@ -10,7 +10,7 @@ - [2.1 服务器端知识蒸馏模型](#2.1) - [2.2 移动端知识蒸馏模型](#2.2) - [2.3 Intel CPU 端知识蒸馏模型](#2.3) -- [3. PP-LCNet 系列](#3) +- [3. PP-LCNet & PP-LCNetV2 系列](#3) - [4. ResNet 系列](#4) - [5. 移动端系列](#5) - [6. SEResNeXt 与 Res2Net 系列](#6) @@ -106,9 +106,9 @@ <a name="3"></a> -## 3. PP-LCNet 系列 [[28](#ref28)] +## 3. PP-LCNet & PP-LCNetV2 系列 [[28](#ref28)] -PP-LCNet 系列模型的精度、速度指标如下表所示,更多关于该系列的模型介绍可以参考:[PP-LCNet 系列模型文档](../models/PP-LCNet.md)。 +PP-LCNet 系列模型的精度、速度指标如下表所示,更多关于该系列的模型介绍可以参考:[PP-LCNet 系列模型文档](../models/PP-LCNet.md),[PP-LCNetV2 系列模型文档](../models/PP-LCNetV2.md)。 | 模型 | Top-1 Acc | Top-5 Acc | Intel-Xeon-Gold-6148 time(ms)<br>bs=1 | FLOPs(M) | Params(M) | 预训练模型下载地址 | inference模型下载地址 | |:--:|:--:|:--:|:--:|----|----|----|:--:| @@ -121,6 +121,10 @@ PP-LCNet 系列模型的精度、速度指标如下表所示,更多关于该 | PPLCNet_x2_0 |0.7518 | 0.9227 | 20.1667 | 590 | 6.54 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_0_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_0_infer.tar) | | PPLCNet_x2_5 |0.7660 | 0.9300 | 29.595 | 906 | 9.04 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_5_infer.tar) | +| 模型 | Top-1 Acc | Top-5 Acc | Intel-Xeon-Gold-6271C<br>bs=1<br>OpenVINO 2021.4.2<br>time(ms) | FLOPs(M) | Params(M) | 预训练模型下载地址 | inference模型下载地址 | +|:--:|:--:|:--:|:--:|----|----|----|:--:| +| PPLCNetV2_base | 77.04 | 93.27 | 4.32 | 604 | 6.6 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNetV2_base_infer.tar) | + <a name="4"></a> ## 4. ResNet 系列 [[1](#ref1)]
diff --git a/docs/zh_CN/models/PP-HGNet.md b/docs/zh_CN/models/PP-HGNet.md new file mode 100644 index 0000000000000000000000000000000000000000..a0216edea954fe927bf8e916169c2062ff94998d --- /dev/null +++ b/docs/zh_CN/models/PP-HGNet.md @@ -0,0 +1,24 @@ +# PP-HGNet 系列 +--- +## 目录 + +* [1. 概述](#1) +* [2. 精度、FLOPs 和参数量](#2) + +<a name="1"></a> + +## 1. 概述 + +PP-HGNet是百度自研的一个在 GPU 端上高性能的网络,该网络在 VOVNet 的基础上融合了 ResNet_vd、PPLCNet 的优点,使用了可学习的下采样层,组合成了一个在 GPU 设备上速度快、精度高的网络,超越其他 GPU 端 SOTA 模型。 + +<a name="2"></a> + +## 2. 精度、FLOPs 和参数量 + +| Models | Top1 | Top5 | FLOPs<br>(G) | Params<br>(M) | +|:--:|:--:|:--:|:--:|:--:| +| PPHGNet_tiny | 79.83 | 95.04 | 4.54 | 14.75 | +| PPHGNet_tiny_ssld | 81.95 | 96.12 | 4.54 | 14.75 | +| PPHGNet_small | 81.51 | 95.82 | 8.53 | 24.38 | + +关于 Inference speed 等信息,敬请期待。
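The PP-HGNet factories land in `ppcls/arch/backbone/legendary_models/pp_hgnet.py` below and are re-exported from `ppcls.arch.backbone` via the `__init__.py` hunk. A quick smoke-test sketch (assuming a working PaddlePaddle install; weights are only downloaded when `pretrained=True`):

```python
import paddle
from ppcls.arch.backbone import PPHGNet_tiny

model = PPHGNet_tiny(pretrained=False)
model.eval()
x = paddle.rand([1, 3, 224, 224])  # NCHW, the training resolution
with paddle.no_grad():
    logits = model(x)
print(logits.shape)  # [1, 1000] class logits
```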
diff --git a/docs/zh_CN/models/PP-LCNetV2.md b/docs/zh_CN/models/PP-LCNetV2.md new file mode 100644 index 0000000000000000000000000000000000000000..7563574694696247d553669e363df68fa00148dc --- /dev/null +++ b/docs/zh_CN/models/PP-LCNetV2.md @@ -0,0 +1,15 @@ +# PP-LCNetV2 系列 + +--- + +## 概述 + +PP-LCNetV2 是在 [PP-LCNet 系列模型](./PP-LCNet.md)的基础上,所提出的针对 Intel CPU 硬件平台设计的计算机视觉骨干网络。 + +在不使用额外数据的前提下,PPLCNetV2_base 模型在图像分类 ImageNet 数据集上能够取得超过 77% 的 Top1 Acc,同时在 Intel CPU 平台仅有 4.4 ms 以下的延迟,如下表所示,其中延时测试基于 Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz 硬件平台,OpenVINO 2021.4.2推理平台。 + +| Model | Params(M) | FLOPs(M) | Top-1 Acc(\%) | Top-5 Acc(\%) | Latency(ms) | +|-------|-----------|----------|---------------|---------------|-------------| +| PPLCNetV2_base | 6.6 | 604 | 77.04 | 93.27 | 4.32 | + +关于 PP-LCNetV2 系列模型的更多信息,敬请关注。
diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index b62b5a64df348e257beee174eeb5bff1007f1d3e..7cbef1b10cde79a0db5ef0c8c28f63880cb11119 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -22,7 +22,9 @@ from ppcls.arch.backbone.legendary_models.vgg import VGG11, VGG13, VGG16, VGG19 from ppcls.arch.backbone.legendary_models.inception_v3 import InceptionV3 from ppcls.arch.backbone.legendary_models.hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W64_C from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x0_25, PPLCNet_x0_35, PPLCNet_x0_5, PPLCNet_x0_75, PPLCNet_x1_0, PPLCNet_x1_5, PPLCNet_x2_0, PPLCNet_x2_5 +from ppcls.arch.backbone.legendary_models.pp_lcnet_v2 import PPLCNetV2_base from ppcls.arch.backbone.legendary_models.esnet import ESNet_x0_25, ESNet_x0_5, ESNet_x0_75, ESNet_x1_0 +from ppcls.arch.backbone.legendary_models.pp_hgnet import PPHGNet_tiny, PPHGNet_small, PPHGNet_base from ppcls.arch.backbone.model_zoo.resnet_vc import ResNet50_vc from ppcls.arch.backbone.model_zoo.resnext import ResNeXt50_32x4d, ResNeXt50_64x4d, ResNeXt101_32x4d, ResNeXt101_64x4d, ResNeXt152_32x4d, ResNeXt152_64x4d
diff --git a/ppcls/arch/backbone/legendary_models/pp_hgnet.py b/ppcls/arch/backbone/legendary_models/pp_hgnet.py new file mode 100644 index 0000000000000000000000000000000000000000..3e0412dfb210c7dc44bc98854dbb96fca526ab1f --- /dev/null +++ b/ppcls/arch/backbone/legendary_models/pp_hgnet.py @@ -0,0 +1,372 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn.initializer import KaimingNormal, Constant +from paddle.nn import Conv2D, BatchNorm2D, ReLU, AdaptiveAvgPool2D, MaxPool2D +from paddle.regularizer import L2Decay +from paddle import ParamAttr + +from ppcls.arch.backbone.base.theseus_layer import TheseusLayer +from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url + +MODEL_URLS = { + "PPHGNet_tiny": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_tiny_pretrained.pdparams", + "PPHGNet_small": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_small_pretrained.pdparams" +} + +__all__ = list(MODEL_URLS.keys()) + +kaiming_normal_ = KaimingNormal() +zeros_ = Constant(value=0.) +ones_ = Constant(value=1.) + + +class ConvBNAct(TheseusLayer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + groups=1, + use_act=True): + super().__init__() + self.use_act = use_act + self.conv = Conv2D( + in_channels, + out_channels, + kernel_size, + stride, + padding=(kernel_size - 1) // 2, + groups=groups, + bias_attr=False) + self.bn = BatchNorm2D( + out_channels, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + if self.use_act: + self.act = ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.use_act: + x = self.act(x) + return x + + +class ESEModule(TheseusLayer): + def __init__(self, channels): + super().__init__() + self.avg_pool = AdaptiveAvgPool2D(1) + self.conv = Conv2D( + in_channels=channels, + out_channels=channels, + kernel_size=1, + stride=1, + padding=0) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + identity = x + x = self.avg_pool(x) + x = self.conv(x) + x = self.sigmoid(x) + return paddle.multiply(x=identity, y=x) + + +class HG_Block(TheseusLayer): + def __init__( + self, + in_channels, + mid_channels, + out_channels, + layer_num, + identity=False, ): + super().__init__() + self.identity = identity + + self.layers = nn.LayerList() + self.layers.append( + ConvBNAct( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=3, + stride=1)) + for _ in range(layer_num - 1): + self.layers.append( + ConvBNAct( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=3, + stride=1)) + + # feature aggregation + total_channels = in_channels + layer_num * mid_channels + self.aggregation_conv = ConvBNAct( + in_channels=total_channels, + out_channels=out_channels, + kernel_size=1, + stride=1) + self.att = ESEModule(out_channels) + + def forward(self, x): + identity = x + output = [] + output.append(x) + for layer in self.layers: + x = layer(x) + output.append(x) + x = paddle.concat(output, axis=1) + x = self.aggregation_conv(x) + x = self.att(x) + if self.identity: + x += identity + return x + + +class HG_Stage(TheseusLayer): + def __init__(self, + in_channels, + mid_channels, + out_channels, + block_num, + layer_num, + downsample=True): + super().__init__() + self.downsample = downsample + if downsample: + self.downsample = ConvBNAct( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=2, + groups=in_channels, + use_act=False) + + blocks_list = [] + blocks_list.append( + HG_Block( + in_channels, + mid_channels, + out_channels, + layer_num, + identity=False)) + for _ in range(block_num - 1): + blocks_list.append( + HG_Block( + out_channels, + mid_channels, + out_channels, + layer_num, +
identity=True)) + self.blocks = nn.Sequential(*blocks_list) + + def forward(self, x): + if self.downsample: + x = self.downsample(x) + x = self.blocks(x) + return x + + +class PPHGNet(TheseusLayer): + """ + PPHGNet + Args: + stem_channels: list. Stem channel list of PPHGNet. + stage_config: dict. The configuration of each stage of PPHGNet, such as the number of channels, stride, etc. + layer_num: int. Number of layers of HG_Block. + use_last_conv: boolean. Whether to use a 1x1 convolutional layer before the classification layer. + class_expand: int=2048. Number of channels for the last 1x1 convolutional layer. + dropout_prob: float. Probability of dropout, 0.0 means dropout is not used. + class_num: int=1000. The number of classes. + Returns: + model: nn.Layer. Specific PPHGNet model depends on args. + """ + def __init__(self, + stem_channels, + stage_config, + layer_num, + use_last_conv=True, + class_expand=2048, + dropout_prob=0.0, + class_num=1000): + super().__init__() + self.use_last_conv = use_last_conv + self.class_expand = class_expand + + # stem + stem_channels.insert(0, 3) + self.stem = nn.Sequential(* [ + ConvBNAct( + in_channels=stem_channels[i], + out_channels=stem_channels[i + 1], + kernel_size=3, + stride=2 if i == 0 else 1) for i in range( + len(stem_channels) - 1) + ]) + self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + + # stages + self.stages = nn.LayerList() + for k in stage_config: + in_channels, mid_channels, out_channels, block_num, downsample = stage_config[ + k] + self.stages.append( + HG_Stage(in_channels, mid_channels, out_channels, block_num, + layer_num, downsample)) + + self.avg_pool = AdaptiveAvgPool2D(1) + if self.use_last_conv: + self.last_conv = Conv2D( + in_channels=out_channels, + out_channels=self.class_expand, + kernel_size=1, + stride=1, + padding=0, + bias_attr=False) + self.act = nn.ReLU() + self.dropout = nn.Dropout( + p=dropout_prob, mode="downscale_in_infer") + + self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) + self.fc = nn.Linear(self.class_expand + if self.use_last_conv else out_channels, class_num) + + self._init_weights() + + def _init_weights(self): + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + kaiming_normal_(m.weight) + elif isinstance(m, (nn.BatchNorm2D)): + ones_(m.weight) + zeros_(m.bias) + elif isinstance(m, nn.Linear): + zeros_(m.bias) + + def forward(self, x): + x = self.stem(x) + x = self.pool(x) + + for stage in self.stages: + x = stage(x) + + x = self.avg_pool(x) + if self.use_last_conv: + x = self.last_conv(x) + x = self.act(x) + x = self.dropout(x) + x = self.flatten(x) + x = self.fc(x) + return x + + +def _load_pretrained(pretrained, model, model_url, use_ssld): + if pretrained is False: + pass + elif pretrained is True: + load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) + elif isinstance(pretrained, str): + load_dygraph_pretrain(model, pretrained) + else: + raise RuntimeError( + "pretrained type is not available. Please use `string` or `boolean` type." + ) + + +def PPHGNet_tiny(pretrained=False, use_ssld=False, **kwargs): + """ + PPHGNet_tiny + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `PPHGNet_tiny` model depends on args.
+ """ + stage_config = { + # in_channels, mid_channels, out_channels, blocks, downsample + "stage1": [96, 96, 224, 1, False], + "stage2": [224, 128, 448, 1, True], + "stage3": [448, 160, 512, 2, True], + "stage4": [512, 192, 768, 1, True], + } + + model = PPHGNet( + stem_channels=[48, 48, 96], + stage_config=stage_config, + layer_num=5, + **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["PPHGNet_tiny"], use_ssld) + return model + + +def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs): + """ + PPHGNet_small + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `PPHGNet_small` model depends on args. + """ + stage_config = { + # in_channels, mid_channels, out_channels, blocks, downsample + "stage1": [128, 128, 256, 1, False], + "stage2": [256, 160, 512, 1, True], + "stage3": [512, 192, 768, 2, True], + "stage4": [768, 224, 1024, 1, True], + } + + model = PPHGNet( + stem_channels=[64, 64, 128], + stage_config=stage_config, + layer_num=6, + **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["PPHGNet_small"], use_ssld) + return model + + +def PPHGNet_base(pretrained=False, use_ssld=False, **kwargs): + """ + PPHGNet_base + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `PPHGNet_base` model depends on args. + """ + stage_config = { + # in_channels, mid_channels, out_channels, blocks, downsample + "stage1": [160, 192, 320, 1, False], + "stage2": [320, 224, 640, 2, True], + "stage3": [640, 256, 960, 3, True], + "stage4": [960, 288, 1280, 2, True], + } + + model = PPHGNet( + stem_channels=[96, 96, 160], + stage_config=stage_config, + layer_num=7, + dropout_prob=0.2, + **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["PPHGNet_base"], use_ssld) + return model diff --git a/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py b/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..459d84275ac63af54fb9ad10af2bcf2f7759052d --- /dev/null +++ b/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py @@ -0,0 +1,352 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import, division, print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Dropout, Linear +from paddle.regularizer import L2Decay +from paddle.nn.initializer import KaimingNormal +from ppcls.arch.backbone.base.theseus_layer import TheseusLayer +from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url + +MODEL_URLS = { + "PPLCNetV2_base": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams", +} + +__all__ = list(MODEL_URLS.keys()) + +NET_CONFIG = { + # in_channels, kernel_size, split_pw, use_rep, use_se, use_shortcut + "stage1": [64, 3, False, False, False, False], + "stage2": [128, 3, False, False, False, False], + "stage3": [256, 5, True, True, True, False], + "stage4": [512, 5, False, True, False, True], +} + + +def make_divisible(v, divisor=8, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNLayer(TheseusLayer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + groups=1, + use_act=True): + super().__init__() + self.use_act = use_act + self.conv = Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(initializer=KaimingNormal()), + bias_attr=False) + + self.bn = BatchNorm2D( + out_channels, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + if self.use_act: + self.act = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.use_act: + x = self.act(x) + return x + + +class SEModule(TheseusLayer): + def __init__(self, channel, reduction=4): + super().__init__() + self.avg_pool = AdaptiveAvgPool2D(1) + self.conv1 = Conv2D( + in_channels=channel, + out_channels=channel // reduction, + kernel_size=1, + stride=1, + padding=0) + self.relu = nn.ReLU() + self.conv2 = Conv2D( + in_channels=channel // reduction, + out_channels=channel, + kernel_size=1, + stride=1, + padding=0) + self.hardsigmoid = nn.Sigmoid() + + def forward(self, x): + identity = x + x = self.avg_pool(x) + x = self.conv1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.hardsigmoid(x) + x = paddle.multiply(x=identity, y=x) + return x + + +class RepDepthwiseSeparable(TheseusLayer): + def __init__(self, + in_channels, + out_channels, + stride, + dw_size=3, + split_pw=False, + use_rep=False, + use_se=False, + use_shortcut=False): + super().__init__() + self.is_repped = False + + self.dw_size = dw_size + self.split_pw = split_pw + self.use_rep = use_rep + self.use_se = use_se + self.use_shortcut = True if use_shortcut and stride == 1 and in_channels == out_channels else False + + if self.use_rep: + self.dw_conv_list = nn.LayerList() + for kernel_size in range(self.dw_size, 0, -2): + if kernel_size == 1 and stride != 1: + continue + dw_conv = ConvBNLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=stride, + groups=in_channels, + use_act=False) + self.dw_conv_list.append(dw_conv) + self.dw_conv = nn.Conv2D( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=dw_size, + stride=stride, + padding=(dw_size - 1) // 2, + groups=in_channels) + 
else: + self.dw_conv = ConvBNLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=dw_size, + stride=stride, + groups=in_channels) + + self.act = nn.ReLU() + + if use_se: + self.se = SEModule(in_channels) + + if self.split_pw: + pw_ratio = 0.5 + self.pw_conv_1 = ConvBNLayer( + in_channels=in_channels, + kernel_size=1, + out_channels=int(out_channels * pw_ratio), + stride=1) + self.pw_conv_2 = ConvBNLayer( + in_channels=int(out_channels * pw_ratio), + kernel_size=1, + out_channels=out_channels, + stride=1) + else: + self.pw_conv = ConvBNLayer( + in_channels=in_channels, + kernel_size=1, + out_channels=out_channels, + stride=1) + + def forward(self, x): + if self.use_rep: + input_x = x + if self.is_repped: + x = self.act(self.dw_conv(x)) + else: + y = self.dw_conv_list[0](x) + for dw_conv in self.dw_conv_list[1:]: + y += dw_conv(x) + x = self.act(y) + else: + x = self.dw_conv(x) + + if self.use_se: + x = self.se(x) + if self.split_pw: + x = self.pw_conv_1(x) + x = self.pw_conv_2(x) + else: + x = self.pw_conv(x) + if self.use_shortcut: + x = x + input_x + return x + + def rep(self): + if self.use_rep: + self.is_repped = True + kernel, bias = self._get_equivalent_kernel_bias() + self.dw_conv.weight.set_value(kernel) + self.dw_conv.bias.set_value(bias) + + def _get_equivalent_kernel_bias(self): + kernel_sum = 0 + bias_sum = 0 + for dw_conv in self.dw_conv_list: + kernel, bias = self._fuse_bn_tensor(dw_conv) + kernel = self._pad_tensor(kernel, to_size=self.dw_size) + kernel_sum += kernel + bias_sum += bias + return kernel_sum, bias_sum + + def _fuse_bn_tensor(self, branch): + kernel = branch.conv.weight + running_mean = branch.bn._mean + running_var = branch.bn._variance + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn._epsilon + std = (running_var + eps).sqrt() + t = (gamma / std).reshape((-1, 1, 1, 1)) + return kernel * t, beta - running_mean * gamma / std + + def _pad_tensor(self, tensor, to_size): + from_size = tensor.shape[-1] + if from_size == to_size: + return tensor + pad = (to_size - from_size) // 2 + return F.pad(tensor, [pad, pad, pad, pad]) + + +class PPLCNetV2(TheseusLayer): + def __init__(self, + scale, + depths, + class_num=1000, + dropout_prob=0, + use_last_conv=True, + class_expand=1280): + super().__init__() + self.scale = scale + self.use_last_conv = use_last_conv + self.class_expand = class_expand + + self.stem = nn.Sequential(* [ + ConvBNLayer( + in_channels=3, + kernel_size=3, + out_channels=make_divisible(32 * scale), + stride=2), RepDepthwiseSeparable( + in_channels=make_divisible(32 * scale), + out_channels=make_divisible(64 * scale), + stride=1, + dw_size=3) + ]) + + # stages + self.stages = nn.LayerList() + for depth_idx, k in enumerate(NET_CONFIG): + in_channels, kernel_size, split_pw, use_rep, use_se, use_shortcut = NET_CONFIG[ + k] + self.stages.append( + nn.Sequential(* [ + RepDepthwiseSeparable( + in_channels=make_divisible((in_channels if i == 0 else + in_channels * 2) * scale), + out_channels=make_divisible(in_channels * 2 * scale), + stride=2 if i == 0 else 1, + dw_size=kernel_size, + split_pw=split_pw, + use_rep=use_rep, + use_se=use_se, + use_shortcut=use_shortcut) + for i in range(depths[depth_idx]) + ])) + + self.avg_pool = AdaptiveAvgPool2D(1) + + if self.use_last_conv: + self.last_conv = Conv2D( + in_channels=make_divisible(NET_CONFIG["stage4"][0] * 2 * + scale), + out_channels=self.class_expand, + kernel_size=1, + stride=1, + padding=0, + bias_attr=False) + self.act = nn.ReLU() + self.dropout = 
Dropout(p=dropout_prob, mode="downscale_in_infer") + + self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) + in_features = self.class_expand if self.use_last_conv else NET_CONFIG[ + "stage4"][0] * 2 * scale + self.fc = Linear(in_features, class_num) + + def forward(self, x): + x = self.stem(x) + for stage in self.stages: + x = stage(x) + x = self.avg_pool(x) + if self.use_last_conv: + x = self.last_conv(x) + x = self.act(x) + x = self.dropout(x) + x = self.flatten(x) + x = self.fc(x) + return x + + +def _load_pretrained(pretrained, model, model_url, use_ssld): + if pretrained is False: + pass + elif pretrained is True: + load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) + elif isinstance(pretrained, str): + load_dygraph_pretrain(model, pretrained) + else: + raise RuntimeError( + "pretrained type is not available. Please use `string` or `boolean` type." + ) + + +def PPLCNetV2_base(pretrained=False, use_ssld=False, **kwargs): + """ + PPLCNetV2_base + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `PPLCNetV2_base` model depends on args. + """ + model = PPLCNetV2( + scale=1.0, depths=[2, 2, 6, 2], dropout_prob=0.2, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["PPLCNetV2_base"], use_ssld) + return model
diff --git a/ppcls/arch/backbone/model_zoo/repvgg.py b/ppcls/arch/backbone/model_zoo/repvgg.py index 8ff662a7f88086abeee6b7f6e0260d2d3b3cd0c1..12f65549fad60adae6a412d8adb05f9846922c81 100644 --- a/ppcls/arch/backbone/model_zoo/repvgg.py +++ b/ppcls/arch/backbone/model_zoo/repvgg.py @@ -124,13 +124,7 @@ class RepVGGBlock(nn.Layer): groups=groups) def forward(self, inputs): - if not self.training and not self.is_repped: - self.rep() - self.is_repped = True - if self.training and self.is_repped: - self.is_repped = False - - if not self.training: + if self.is_repped: return self.nonlinearity(self.rbr_reparam(inputs)) if self.rbr_identity is None: @@ -154,6 +148,7 @@ class RepVGGBlock(nn.Layer): kernel, bias = self.get_equivalent_kernel_bias() self.rbr_reparam.weight.set_value(kernel) self.rbr_reparam.bias.set_value(bias) + self.is_repped = True def get_equivalent_kernel_bias(self): kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
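`RepDepthwiseSeparable` above and the revised `RepVGGBlock` share the same contract: branches stay separate during training, and an explicit `rep()` call fuses them into a single conv (the `engine.py` hunk near the end of this PR invokes it over `sublayers()` before export). A hedged usage sketch, assuming the `PPLCNetV2_base` factory from this file:

```python
import paddle
from ppcls.arch.backbone import PPLCNetV2_base

model = PPLCNetV2_base(pretrained=False)
model.eval()
# Fold every re-parameterizable branch into its fused kernel, mirroring
# what Engine.export() now does; forward() then takes the is_repped path.
for layer in model.sublayers():
    if hasattr(layer, "rep"):
        layer.rep()
logits = model(paddle.rand([1, 3, 224, 224]))
```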
diff --git a/ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml b/ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eabccd4b712ab48886c74caf6b784b4c193f6913 --- /dev/null +++ b/ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml @@ -0,0 +1,164 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 600 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + use_dali: False + +# mixed precision training +AMP: + scale_loss: 128.0 + use_dynamic_loss_scaling: True + # O1: mixed fp16 + level: O1 + +# model architecture +Arch: + name: PPHGNet_small + class_num: 1000 + +# loss function config for training/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.5 + warmup_epoch: 5 + regularizer: + name: 'L2' + coeff: 0.00004 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - TimmAutoAugment: + config_str: rand-m7-mstd0.5-inc1 + interpolation: bicubic + img_size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.25 + sl: 0.02 + sh: 1.0/3.0 + r1: 0.3 + attempt: 10 + use_log_aspect: True + mode: pixel + batch_transform_ops: + - OpSampler: + MixupOperator: + alpha: 0.2 + prob: 0.5 + CutmixOperator: + alpha: 1.0 + prob: 0.5 + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 16 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 236 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 16 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 236 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5]
diff --git a/ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml b/ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e423c866b131aefda13b0186eca7ac27d3c84733 --- /dev/null +++ b/ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml @@ -0,0 +1,164 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 600 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + use_dali: False + +# mixed precision training +AMP: + scale_loss: 128.0 + use_dynamic_loss_scaling: True + # O1: mixed fp16 + level: O1 + +# model architecture +Arch: + name: PPHGNet_tiny + class_num: 1000 + +# loss function config for training/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.5 + warmup_epoch: 5 + regularizer: + name: 'L2' + coeff: 0.00004 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops:
+ - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - TimmAutoAugment: + config_str: rand-m7-mstd0.5-inc1 + interpolation: bicubic + img_size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.25 + sl: 0.02 + sh: 1.0/3.0 + r1: 0.3 + attempt: 10 + use_log_aspect: True + mode: pixel + batch_transform_ops: + - OpSampler: + MixupOperator: + alpha: 0.2 + prob: 0.5 + CutmixOperator: + alpha: 1.0 + prob: 0.5 + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 16 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 232 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 16 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 232 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5]
diff --git a/ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml b/ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..640833938bd81d8dd24c8bdd0ae1de86d8697a10 --- /dev/null +++ b/ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml @@ -0,0 +1,133 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 480 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + +# model architecture +Arch: + name: PPLCNetV2_base + class_num: 1000 + +# loss function config for training/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.8 + warmup_epoch: 5 + regularizer: + name: 'L2' + coeff: 0.00004 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: MultiScaleDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + # support to specify width and height respectively: + # scales: [(160,160), (192,192), (224,224), (288,288), (320,320)] + sampler: + name: MultiScaleSampler + scales: [160, 192, 224, 288, 320] + # first_bs: batch size for the first
image resolution in the scales list + # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple + first_bs: 500 + divided_factor: 32 + is_training: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5]
diff --git a/ppcls/configs/multi_scale/MobileNetV1_multi_scale.yaml b/ppcls/configs/multi_scale/MobileNetV1_multi_scale.yaml new file mode 100644 index 0000000000000000000000000000000000000000..530e7507519ed37dd1126633738c903769fe697e --- /dev/null +++ b/ppcls/configs/multi_scale/MobileNetV1_multi_scale.yaml @@ -0,0 +1,138 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 120 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + +# model architecture +Arch: + name: MobileNetV1 + class_num: 1000 + +# loss function config for training/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Piecewise + learning_rate: 0.1 + decay_epochs: [30, 60, 90] + values: [0.1, 0.01, 0.001, 0.0001] + regularizer: + name: 'L2' + coeff: 0.00003 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: MultiScaleDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + # support to specify width and height respectively: + # scales: [(160,160), (192,192), (224,224), (288,288), (320,320)] + sampler: + name: MultiScaleSampler + scales: [160, 192, 224, 288, 320] + # first_bs: batch size for the first image resolution in the scales list + # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple + first_bs: 64 + divided_factor: 32 + is_training: True + + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + -
ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/whl/demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/configs/Pedestrian/strong_baseline_baseline.yaml b/ppcls/configs/reid/strong_baseline/baseline.yaml similarity index 96% rename from ppcls/configs/Pedestrian/strong_baseline_baseline.yaml rename to ppcls/configs/reid/strong_baseline/baseline.yaml index bc022158b0e80603f65e3d9e7aa4ceaefe9bb50c..35980206b19bab76f46df54e143adaecc1f4b566 100644 --- a/ppcls/configs/Pedestrian/strong_baseline_baseline.yaml +++ b/ppcls/configs/reid/strong_baseline/baseline.yaml @@ -79,6 +79,7 @@ DataLoader: - ResizeImage: size: [128, 256] return_numpy: False + interpolation: 'bilinear' backend: "pil" - RandFlipImage: flip_code: 1 @@ -110,6 +111,7 @@ DataLoader: - ResizeImage: size: [128, 256] return_numpy: False + interpolation: 'bilinear' backend: "pil" - ToTensor: - Normalize: @@ -134,6 +136,7 @@ DataLoader: - ResizeImage: size: [128, 256] return_numpy: False + interpolation: 'bilinear' backend: "pil" - ToTensor: - Normalize: diff --git a/ppcls/configs/Pedestrian/strong_baseline_m1.yaml b/ppcls/configs/reid/strong_baseline/softmax_triplet.yaml similarity index 97% rename from ppcls/configs/Pedestrian/strong_baseline_m1.yaml rename to ppcls/configs/reid/strong_baseline/softmax_triplet.yaml index 23a9b9d5e80459c9e67b08ab2a6a6ddd66e3e03e..6f9cd955626316fe5267e3f9289b93b4317f736f 100644 --- a/ppcls/configs/Pedestrian/strong_baseline_m1.yaml +++ b/ppcls/configs/reid/strong_baseline/softmax_triplet.yaml @@ -91,6 +91,7 @@ DataLoader: - ResizeImage: size: [128, 256] return_numpy: False + interpolation: 'bilinear' backend: "pil" - RandFlipImage: flip_code: 1 @@ -128,6 +129,7 @@ DataLoader: - ResizeImage: size: [128, 256] return_numpy: False + interpolation: 'bilinear' backend: "pil" - ToTensor: - Normalize: @@ -152,6 +154,7 @@ DataLoader: - ResizeImage: size: [128, 256] return_numpy: False + interpolation: 'bilinear' backend: "pil" - ToTensor: - Normalize: diff --git a/ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml b/ppcls/configs/reid/strong_baseline/softmax_triplet_with_center.yaml similarity index 97% rename from ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml rename to ppcls/configs/reid/strong_baseline/softmax_triplet_with_center.yaml index 97b6fda9fc27c3a983203312dd3a483802d91c1e..22af5e516ca4b9945bc8413ed56c67c972b48609 100644 --- a/ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml +++ b/ppcls/configs/reid/strong_baseline/softmax_triplet_with_center.yaml @@ -102,6 +102,7 @@ DataLoader: - ResizeImage: size: [128, 256] return_numpy: False + interpolation: 'bilinear' backend: "pil" - RandFlipImage: flip_code: 1 @@ -139,6 +140,7 @@ DataLoader: - ResizeImage: size: [128, 256] return_numpy: False + interpolation: 'bilinear' backend: "pil" - ToTensor: - 
Normalize: @@ -163,6 +165,7 @@ DataLoader: - ResizeImage: size: [128, 256] return_numpy: False + interpolation: 'bilinear' backend: "pil" - ToTensor: - Normalize:
diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py index 9722bfb85a9a93d007507174ec17b1b95738270c..9fc4d760be545ffa93652c80d285e17ad0c8ae57 100644 --- a/ppcls/data/__init__.py +++ b/ppcls/data/__init__.py @@ -28,12 +28,15 @@ from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild from ppcls.data.dataloader.logo_dataset import LogoDataset from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset from ppcls.data.dataloader.mix_dataset import MixDataset +from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset from ppcls.data.dataloader.person_dataset import Market1501, MSMT17 + # sampler from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler from ppcls.data.dataloader.pk_sampler import PKSampler from ppcls.data.dataloader.mix_sampler import MixSampler +from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler from ppcls.data import preprocess from ppcls.data.preprocess import transform
diff --git a/ppcls/data/dataloader/__init__.py b/ppcls/data/dataloader/__init__.py index 271a8f5cbfa164dbd6803312cf2d468f8c9bdc82..2b1d92b76bd202e36086f21a3a092c3673277690 100644 --- a/ppcls/data/dataloader/__init__.py +++ b/ppcls/data/dataloader/__init__.py @@ -5,6 +5,8 @@ from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild from ppcls.data.dataloader.logo_dataset import LogoDataset from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset from ppcls.data.dataloader.mix_dataset import MixDataset +from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset from ppcls.data.dataloader.mix_sampler import MixSampler +from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler from ppcls.data.dataloader.pk_sampler import PKSampler from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
diff --git a/ppcls/data/dataloader/dali.py b/ppcls/data/dataloader/dali.py index a15c231568a97fd607f2ada4f5f6e81fa084cc62..a340a946c921bedd475531eb3bd9172f49a99e1e 100644 --- a/ppcls/data/dataloader/dali.py +++ b/ppcls/data/dataloader/dali.py @@ -230,7 +230,7 @@ def dali_dataloader(config, mode, device, seed=None): lower = ratio[0] upper = ratio[1] - if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env: + if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and 'FLAGS_selected_gpus' in env: shard_id = int(env['PADDLE_TRAINER_ID']) num_shards = int(env['PADDLE_TRAINERS_NUM']) device_id = int(env['FLAGS_selected_gpus']) @@ -282,7 +282,7 @@ def dali_dataloader(config, mode, device, seed=None): else: resize_shorter = transforms["ResizeImage"].get("resize_short", 256) crop = transforms["CropImage"]["size"] - if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and sampler_name == "DistributedBatchSampler": + if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and 'FLAGS_selected_gpus' in env and sampler_name == "DistributedBatchSampler": shard_id = int(env['PADDLE_TRAINER_ID']) num_shards = int(env['PADDLE_TRAINERS_NUM']) device_id = int(env['FLAGS_selected_gpus'])
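Both `dali.py` hunks tighten the same guard: sharded reading now also requires `FLAGS_selected_gpus`, so a launch that exports the trainer variables but no GPU selection falls back to single-process reading instead of failing on the missing key. The guard in isolation (a sketch; `dali_shard_info` is a hypothetical helper, env names as in the hunks):

```python
import os

def dali_shard_info(env=os.environ):
    """Return (shard_id, num_shards, device_id) under paddle.distributed
    launches, or None so the caller falls back to unsharded reading."""
    keys = ('PADDLE_TRAINER_ID', 'PADDLE_TRAINERS_NUM', 'FLAGS_selected_gpus')
    if not all(k in env for k in keys):
        return None
    return (int(env['PADDLE_TRAINER_ID']),
            int(env['PADDLE_TRAINERS_NUM']),
            int(env['FLAGS_selected_gpus']))
```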
2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import os + +from paddle.io import Dataset +from paddle.vision import transforms +import cv2 +import warnings + +from ppcls.data import preprocess +from ppcls.data.preprocess import transform +from ppcls.data.preprocess.ops.operators import DecodeImage +from ppcls.utils import logger +from ppcls.data.dataloader.common_dataset import create_operators + + +class MultiScaleDataset(Dataset): + def __init__( + self, + image_root, + cls_label_path, + transform_ops=None, ): + self._img_root = image_root + self._cls_path = cls_label_path + self.transform_ops = transform_ops + self.images = [] + self.labels = [] + self._load_anno() + self.has_crop_flag = 1 + + def _load_anno(self, seed=None): + assert os.path.exists(self._cls_path) + assert os.path.exists(self._img_root) + self.images = [] + self.labels = [] + + with open(self._cls_path) as fd: + lines = fd.readlines() + if seed is not None: + np.random.RandomState(seed).shuffle(lines) + for l in lines: + l = l.strip().split(" ") + self.images.append(os.path.join(self._img_root, l[0])) + self.labels.append(np.int64(l[1])) + assert os.path.exists(self.images[-1]) + + def __getitem__(self, properties): + # properites is a tuple, contains (width, height, index) + img_width = properties[0] + img_height = properties[1] + index = properties[2] + has_crop = False + if self.transform_ops: + for i in range(len(self.transform_ops)): + op = self.transform_ops[i] + resize_op = ['RandCropImage', 'ResizeImage', 'CropImage'] + for resize in resize_op: + if resize in op: + if self.has_crop_flag: + logger.warning( + "Multi scale dataset will crop image according to the multi scale resolution" + ) + self.transform_ops[i][resize] = { + 'size': (img_width, img_height) + } + has_crop = True + self.has_crop_flag = 0 + if has_crop == False: + logger.error("Multi scale dateset requests RandCropImage") + raise RuntimeError("Multi scale dateset requests RandCropImage") + self._transform_ops = create_operators(self.transform_ops) + + try: + with open(self.images[index], 'rb') as f: + img = f.read() + if self._transform_ops: + img = transform(img, self._transform_ops) + img = img.transpose((2, 0, 1)) + return (img, self.labels[index]) + + except Exception as ex: + logger.error("Exception occured when parse line: {} with msg: {}". 
diff --git a/ppcls/data/dataloader/multi_scale_sampler.py b/ppcls/data/dataloader/multi_scale_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..57b42b307dfb223c2ab434a89fc6c56b4e1e4a5c
--- /dev/null
+++ b/ppcls/data/dataloader/multi_scale_sampler.py
@@ -0,0 +1,132 @@
+from paddle.io import Sampler
+import paddle.distributed as dist
+
+import math
+import random
+import numpy as np
+
+
+class MultiScaleSampler(Sampler):
+    def __init__(self,
+                 data_source,
+                 scales,
+                 first_bs,
+                 divided_factor=32,
+                 is_training=True,
+                 seed=None):
+        """
+        multi scale sampler
+        Args:
+            data_source(dataset)
+            scales(list): several scales for image resolution
+            first_bs(int): batch size for the first scale in scales
+            divided_factor(int): ImageNet models down-sample images by a factor; ensure that width and height dimensions are multiples of divided_factor.
+            is_training(boolean): whether the sampler is used for training
+        """
+        # min. and max. spatial dimensions
+        self.data_source = data_source
+        self.n_data_samples = len(self.data_source)
+
+        if isinstance(scales[0], tuple):
+            width_dims = [i[0] for i in scales]
+            height_dims = [i[1] for i in scales]
+        elif isinstance(scales[0], int):
+            width_dims = scales
+            height_dims = scales
+        base_im_w = width_dims[0]
+        base_im_h = height_dims[0]
+        base_batch_size = first_bs
+
+        # Get the GPU and node related information
+        num_replicas = dist.get_world_size()
+        rank = dist.get_rank()
+        # adjust the total samples to avoid batch dropping
+        num_samples_per_replica = int(
+            math.ceil(self.n_data_samples * 1.0 / num_replicas))
+        img_indices = [idx for idx in range(self.n_data_samples)]
+
+        self.shuffle = False
+        if is_training:
+            # compute the spatial dimensions and corresponding batch size
+            # ImageNet models down-sample images by a factor of 32.
+            # Ensure that width and height dimensions are multiples of 32.
+            width_dims = [
+                int((w // divided_factor) * divided_factor) for w in width_dims
+            ]
+            height_dims = [
+                int((h // divided_factor) * divided_factor)
+                for h in height_dims
+            ]
+
+            img_batch_pairs = list()
+            base_elements = base_im_w * base_im_h * base_batch_size
+            for (h, w) in zip(height_dims, width_dims):
+                batch_size = int(max(1, (base_elements / (h * w))))
+                img_batch_pairs.append((w, h, batch_size))
+            self.img_batch_pairs = img_batch_pairs
+            self.shuffle = True
+        else:
+            self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]
+
+        self.img_indices = img_indices
+        self.n_samples_per_replica = num_samples_per_replica
+        self.epoch = 0
+        self.rank = rank
+        self.num_replicas = num_replicas
+        self.seed = seed
+        self.batch_list = []
+        self.current = 0
+        indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
+                                          self.num_replicas]
+        while self.current < self.n_samples_per_replica:
+            curr_w, curr_h, curr_bsz = random.choice(self.img_batch_pairs)
+
+            end_index = min(self.current + curr_bsz,
+                            self.n_samples_per_replica)
+
+            batch_ids = indices_rank_i[self.current:end_index]
+            n_batch_samples = len(batch_ids)
+            if n_batch_samples != curr_bsz:
+                batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
+            self.current += curr_bsz
+
+            if len(batch_ids) > 0:
+                batch = [curr_w, curr_h, len(batch_ids)]
+                self.batch_list.append(batch)
+        self.length = len(self.batch_list)
+
+    def __iter__(self):
+        if self.shuffle:
+            if self.seed is not None:
+                random.seed(self.seed)
+            else:
+                random.seed(self.epoch)
+            random.shuffle(self.img_indices)
+            random.shuffle(self.img_batch_pairs)
+            indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
+                                              self.num_replicas]
+        else:
+            indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
+                                              self.num_replicas]
+
+        start_index = 0
+        for batch_tuple in self.batch_list:
+            curr_w, curr_h, curr_bsz = batch_tuple
+            end_index = min(start_index + curr_bsz,
+                            self.n_samples_per_replica)
+            batch_ids = indices_rank_i[start_index:end_index]
+            n_batch_samples = len(batch_ids)
+            if n_batch_samples != curr_bsz:
+                batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
+            start_index += curr_bsz
+
+            if len(batch_ids) > 0:
+                batch = [(curr_w, curr_h, b_id) for b_id in batch_ids]
+                yield batch
+
+    def set_epoch(self, epoch: int):
+        self.epoch = epoch
+
+    def __len__(self):
+        return self.length
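The sampler pre-computes a batch_list of (width, height, batch_size) triples and, at iteration time, yields lists of (width, height, index) tuples that line up with MultiScaleDataset.__getitem__. A minimal wiring sketch, assuming the standard paddle.io.DataLoader batch_sampler contract and reusing the dataset from the previous sketch; the scales and first_bs values are assumptions:

from paddle.io import DataLoader

sampler = MultiScaleSampler(
    data_source=dataset,
    scales=[(224, 224), (256, 256), (288, 288)],  # assumed scales
    first_bs=64)                                  # batch size at 224x224
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4)

for images, labels in loader:
    # every batch shares one resolution; larger areas get smaller batches
    ...

Note that per-batch size shrinks as the image area grows, keeping roughly base_im_w * base_im_h * first_bs pixels per batch (the base_elements computation above).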
diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index 5b5c4da8a6500ab90c31f33097075db5f8ee5f89..675e92a6d7fa6325629d3b0d94ccf3de314db2b8 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -452,6 +452,12 @@ class Engine(object):
                 self.config["Global"]["pretrained_model"])
         model.eval()
+
+        # fuse re-parameterizable branches before export (rep nets)
+        for layer in self.model.sublayers():
+            if hasattr(layer, "rep"):
+                layer.rep()
+
         save_path = os.path.join(self.config["Global"]["save_inference_dir"],
                                  "inference")
         if model.quanter:
diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py
index 093255379cd35875fbaf06282e391017bf7f14a3..4e27f12c1d4830f2f16580bfa976cf3ace78d934 100644
--- a/ppcls/utils/save_load.py
+++ b/ppcls/utils/save_load.py
@@ -116,9 +116,8 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
             load_distillation_model(net, pretrained_model)
         else:  # common load
             load_dygraph_pretrain(net, path=pretrained_model)
-            logger.info(
-                logger.coloring("Finish load pretrained model from {}".format(
-                    pretrained_model), "HEADER"))
+            logger.info("Finished loading pretrained model from {}".format(
+                pretrained_model))
 
 
 def save_model(net,
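The export hook added to engine.py assumes only that re-parameterizable sublayers expose a rep() method that fuses their training-time branches before the model is traced. An illustrative layer satisfying that contract; RepBlock is a toy example, not a class from this repository:

import paddle.nn as nn
import paddle.nn.functional as F

class RepBlock(nn.Layer):
    """Toy re-parameterizable block: 3x3 + 1x1 branches at train time,
    a single fused 3x3 convolution after rep() is called."""

    def __init__(self, channels):
        super().__init__()
        self.conv3x3 = nn.Conv2D(channels, channels, 3, padding=1)
        self.conv1x1 = nn.Conv2D(channels, channels, 1)
        self.fused = False

    def forward(self, x):
        if self.fused:
            return self.conv3x3(x)
        return self.conv3x3(x) + self.conv1x1(x)

    def rep(self):
        # Zero-pad the 1x1 kernel to 3x3 and fold it (and its bias)
        # into the 3x3 branch, so inference runs one convolution.
        w1 = F.pad(self.conv1x1.weight, [1, 1, 1, 1])
        self.conv3x3.weight.set_value(self.conv3x3.weight + w1)
        self.conv3x3.bias.set_value(self.conv3x3.bias + self.conv1x1.bias)
        self.fused = True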
diff --git a/test_tipc/config/PPHGNet/PPHGNet_small_train_infer_python.txt b/test_tipc/config/PPHGNet/PPHGNet_small_train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e787bb0521500ac257a94ed30e892eb4a016a738
--- /dev/null
+++ b/test_tipc/config/PPHGNet/PPHGNet_small_train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:PPHGNet_small
+python:python3.7
+gpu_list:0|0,1
+-o Global.device:gpu
+-o Global.auto_cast:null
+-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
+-o Global.output_dir:./output/
+-o DataLoader.Train.sampler.batch_size:8
+-o Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./dataset/ILSVRC2012/val
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml
+null:null
+##
+===========================infer_params==========================
+-o Global.save_inference_dir:./inference
+-o Global.pretrained_model:
+norm_export:tools/export_model.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml
+quant_export:null
+fpgm_export:null
+distill_export:null
+kl_quant:null
+export2:null
+pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_small_pretrained.pdparams
+infer_model:../inference/
+infer_export:True
+infer_quant:False
+inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=236
+-o Global.use_gpu:True|False
+-o Global.enable_mkldnn:True|False
+-o Global.cpu_num_threads:1|6
+-o Global.batch_size:1|16
+-o Global.use_tensorrt:True|False
+-o Global.use_fp16:True|False
+-o Global.inference_model_dir:../inference
+-o Global.infer_imgs:../dataset/ILSVRC2012/val
+-o Global.save_log_path:null
+-o Global.benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,224,224]}]
diff --git a/test_tipc/config/PPHGNet/PPHGNet_tiny_train_infer_python.txt b/test_tipc/config/PPHGNet/PPHGNet_tiny_train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..546b9fa1ef5de70730e9e4a6425c23bf729ef017
--- /dev/null
+++ b/test_tipc/config/PPHGNet/PPHGNet_tiny_train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:PPHGNet_tiny
+python:python3.7
+gpu_list:0|0,1
+-o Global.device:gpu
+-o Global.auto_cast:null
+-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
+-o Global.output_dir:./output/
+-o DataLoader.Train.sampler.batch_size:8
+-o Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./dataset/ILSVRC2012/val
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml
+null:null
+##
+===========================infer_params==========================
+-o Global.save_inference_dir:./inference
+-o Global.pretrained_model:
+norm_export:tools/export_model.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml
+quant_export:null
+fpgm_export:null
+distill_export:null
+kl_quant:null
+export2:null
+pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_tiny_pretrained.pdparams
+infer_model:../inference/
+infer_export:True
+infer_quant:False
+inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=232
+-o Global.use_gpu:True|False
+-o Global.enable_mkldnn:True|False
+-o Global.cpu_num_threads:1|6
+-o Global.batch_size:1|16
+-o Global.use_tensorrt:True|False
+-o Global.use_fp16:True|False
+-o Global.inference_model_dir:../inference
+-o Global.infer_imgs:../dataset/ILSVRC2012/val
+-o Global.save_log_path:null
+-o Global.benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,224,224]}]
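In the TIPC config files above, a value such as lite_train_lite_infer=2|whole_train_whole_infer=120 encodes per-mode settings separated by |. A hypothetical resolver showing how such a field can be read; parse_mode_value is not part of the test_tipc harness:

def parse_mode_value(value: str, mode: str) -> str:
    # Pick the variant matching the requested TIPC mode.
    for option in value.split('|'):
        key, sep, val = option.partition('=')
        if sep and key == mode:
            return val
    return value  # plain values apply to every mode

assert parse_mode_value(
    "lite_train_lite_infer=2|whole_train_whole_infer=120",
    "lite_train_lite_infer") == "2"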
diff --git a/test_tipc/config/PPLCNetV2/PPLCNetV2_base_train_infer_python.txt b/test_tipc/config/PPLCNetV2/PPLCNetV2_base_train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1c2806f27885e8fc3d31233b700ac9120fce6888
--- /dev/null
+++ b/test_tipc/config/PPLCNetV2/PPLCNetV2_base_train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:PPLCNetV2_base
+python:python3.7
+gpu_list:0|0,1
+-o Global.device:gpu
+-o Global.auto_cast:null
+-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
+-o Global.output_dir:./output/
+-o DataLoader.Train.sampler.first_bs:8
+-o Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./dataset/ILSVRC2012/val
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml -o Global.seed=1234 -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml
+null:null
+##
+===========================infer_params==========================
+-o Global.save_inference_dir:./inference
+-o Global.pretrained_model:
+norm_export:tools/export_model.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml
+quant_export:null
+fpgm_export:null
+distill_export:null
+kl_quant:null
+export2:null
+pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams
+infer_model:../inference/
+infer_export:True
+infer_quant:False
+inference:python/predict_cls.py -c configs/inference_cls.yaml
+-o Global.use_gpu:True|False
+-o Global.enable_mkldnn:True|False
+-o Global.cpu_num_threads:1|6
+-o Global.batch_size:1|16
+-o Global.use_tensorrt:True|False
+-o Global.use_fp16:True|False
+-o Global.inference_model_dir:../inference
+-o Global.infer_imgs:../dataset/ILSVRC2012/val
+-o Global.save_log_path:null
+-o Global.benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,224,224]}]