From 18e1cf040bb2cb27c0fe367217cba9ee6aa50663 Mon Sep 17 00:00:00 2001 From: dongshuilong Date: Thu, 16 Sep 2021 12:17:02 +0000 Subject: [PATCH] fix pact bug for circlemargin arcmargin cosmargin --- .gitignore | 3 +- ppcls/arch/gears/arcmargin.py | 19 +- ppcls/arch/gears/circlemargin.py | 17 +- ppcls/arch/gears/cosmargin.py | 18 +- ppcls/configs/Vehicle/ResNet50.yaml | 7 +- ppcls/configs/Vehicle/ResNet50_ReID.yaml | 2 +- .../slim/ResNet50_vehicle_cls_prune.yaml | 136 +++++++++++++++ .../ResNet50_vehicle_cls_quantization.yaml | 135 +++++++++++++++ .../slim/ResNet50_vehicle_reid_prune.yaml | 2 +- .../ResNet50_vehicle_reid_quantization.yaml | 162 ++++++++++++++++++ 10 files changed, 461 insertions(+), 40 deletions(-) create mode 100644 ppcls/configs/slim/ResNet50_vehicle_cls_prune.yaml create mode 100644 ppcls/configs/slim/ResNet50_vehicle_cls_quantization.yaml create mode 100644 ppcls/configs/slim/ResNet50_vehicle_reid_quantization.yaml diff --git a/.gitignore b/.gitignore index 8f00d034..f56c23c0 100644 --- a/.gitignore +++ b/.gitignore @@ -3,10 +3,11 @@ __pycache__/ *.sw* */workerlog* checkpoints/ -output/ +output*/ pretrained/ .ipynb_checkpoints/ *.ipynb* _build/ build/ +log/ nohup.out diff --git a/ppcls/arch/gears/arcmargin.py b/ppcls/arch/gears/arcmargin.py index bab7a356..22cc76e1 100644 --- a/ppcls/arch/gears/arcmargin.py +++ b/ppcls/arch/gears/arcmargin.py @@ -24,30 +24,25 @@ class ArcMargin(nn.Layer): margin=0.5, scale=80.0, easy_margin=False): - super(ArcMargin, self).__init__() + super().__init__() self.embedding_size = embedding_size self.class_num = class_num self.margin = margin self.scale = scale self.easy_margin = easy_margin - - weight_attr = paddle.ParamAttr( - initializer=paddle.nn.initializer.XavierNormal()) - self.fc = nn.Linear( - self.embedding_size, - self.class_num, - weight_attr=weight_attr, - bias_attr=False) + self.weight = self.create_parameter( + shape=[self.embedding_size, self.class_num], + is_bias=False, + default_initializer=paddle.nn.initializer.XavierNormal()) def forward(self, input, label=None): input_norm = paddle.sqrt( paddle.sum(paddle.square(input), axis=1, keepdim=True)) input = paddle.divide(input, input_norm) - weight = self.fc.weight weight_norm = paddle.sqrt( - paddle.sum(paddle.square(weight), axis=0, keepdim=True)) - weight = paddle.divide(weight, weight_norm) + paddle.sum(paddle.square(self.weight), axis=0, keepdim=True)) + weight = paddle.divide(self.weight, weight_norm) cos = paddle.matmul(input, weight) if not self.training or label is None: diff --git a/ppcls/arch/gears/circlemargin.py b/ppcls/arch/gears/circlemargin.py index 87baee83..d1bce83c 100644 --- a/ppcls/arch/gears/circlemargin.py +++ b/ppcls/arch/gears/circlemargin.py @@ -26,20 +26,19 @@ class CircleMargin(nn.Layer): self.embedding_size = embedding_size self.class_num = class_num - weight_attr = paddle.ParamAttr( - initializer=paddle.nn.initializer.XavierNormal()) - self.fc = paddle.nn.Linear( - self.embedding_size, self.class_num, weight_attr=weight_attr) + self.weight = self.create_parameter( + shape=[self.embedding_size, self.class_num], + is_bias=False, + default_initializer=paddle.nn.initializer.XavierNormal()) def forward(self, input, label): feat_norm = paddle.sqrt( paddle.sum(paddle.square(input), axis=1, keepdim=True)) input = paddle.divide(input, feat_norm) - weight = self.fc.weight weight_norm = paddle.sqrt( - paddle.sum(paddle.square(weight), axis=0, keepdim=True)) - weight = paddle.divide(weight, weight_norm) + paddle.sum(paddle.square(self.weight), axis=0, keepdim=True)) + weight = paddle.divide(self.weight, weight_norm) logits = paddle.matmul(input, weight) if not self.training or label is None: @@ -49,9 +48,9 @@ class CircleMargin(nn.Layer): alpha_n = paddle.clip(logits.detach() + self.margin, min=0.) delta_p = 1 - self.margin delta_n = self.margin - + m_hot = F.one_hot(label.reshape([-1]), num_classes=logits.shape[1]) - + logits_p = alpha_p * (logits - delta_p) logits_n = alpha_n * (logits - delta_n) pre_logits = logits_p * m_hot + logits_n * (1 - m_hot) diff --git a/ppcls/arch/gears/cosmargin.py b/ppcls/arch/gears/cosmargin.py index 378e102a..578b64c2 100644 --- a/ppcls/arch/gears/cosmargin.py +++ b/ppcls/arch/gears/cosmargin.py @@ -25,13 +25,10 @@ class CosMargin(paddle.nn.Layer): self.embedding_size = embedding_size self.class_num = class_num - weight_attr = paddle.ParamAttr( - initializer=paddle.nn.initializer.XavierNormal()) - self.fc = nn.Linear( - self.embedding_size, - self.class_num, - weight_attr=weight_attr, - bias_attr=False) + self.weight = self.create_parameter( + shape=[self.embedding_size, self.class_num], + is_bias=False, + default_initializer=paddle.nn.initializer.XavierNormal()) def forward(self, input, label): label.stop_gradient = True @@ -40,15 +37,14 @@ class CosMargin(paddle.nn.Layer): paddle.sum(paddle.square(input), axis=1, keepdim=True)) input = paddle.divide(input, input_norm) - weight = self.fc.weight weight_norm = paddle.sqrt( - paddle.sum(paddle.square(weight), axis=0, keepdim=True)) - weight = paddle.divide(weight, weight_norm) + paddle.sum(paddle.square(self.weight), axis=0, keepdim=True)) + weight = paddle.divide(self.weight, weight_norm) cos = paddle.matmul(input, weight) if not self.training or label is None: return cos - + cos_m = cos - self.margin one_hot = paddle.nn.functional.one_hot(label, self.class_num) diff --git a/ppcls/configs/Vehicle/ResNet50.yaml b/ppcls/configs/Vehicle/ResNet50.yaml index 335222ed..ba2a0b90 100644 --- a/ppcls/configs/Vehicle/ResNet50.yaml +++ b/ppcls/configs/Vehicle/ResNet50.yaml @@ -2,7 +2,7 @@ Global: checkpoints: null pretrained_model: null - output_dir: "./output/" + output_dir: "./output_vehicle_cls/" device: "gpu" save_interval: 1 eval_during_train: True @@ -51,11 +51,8 @@ Optimizer: name: Momentum momentum: 0.9 lr: - name: MultiStepDecay + name: Cosine learning_rate: 0.01 - milestones: [30, 60, 70, 80, 90, 100, 120, 140] - gamma: 0.5 - verbose: False last_epoch: -1 regularizer: name: 'L2' diff --git a/ppcls/configs/Vehicle/ResNet50_ReID.yaml b/ppcls/configs/Vehicle/ResNet50_ReID.yaml index 333b6a24..09a04c13 100644 --- a/ppcls/configs/Vehicle/ResNet50_ReID.yaml +++ b/ppcls/configs/Vehicle/ResNet50_ReID.yaml @@ -2,7 +2,7 @@ Global: checkpoints: null pretrained_model: null - output_dir: "./output/" + output_dir: "./output_vehicle_reid/" device: "gpu" save_interval: 1 eval_during_train: True diff --git a/ppcls/configs/slim/ResNet50_vehicle_cls_prune.yaml b/ppcls/configs/slim/ResNet50_vehicle_cls_prune.yaml new file mode 100644 index 00000000..788a49b4 --- /dev/null +++ b/ppcls/configs/slim/ResNet50_vehicle_cls_prune.yaml @@ -0,0 +1,136 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: "./output_vehicle_cls_prune/" + device: "gpu" + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 160 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: "./inference" + +Slim: + prune: + name: fpgm + pruned_ratio: 0.3 + +# model architecture +Arch: + name: "RecModel" + infer_output_key: "features" + infer_add_softmax: False + Backbone: + name: "ResNet50_last_stage_stride1" + pretrained: True + BackboneStopLayer: + name: "adaptive_avg_pool2d_0" + Neck: + name: "VehicleNeck" + in_channels: 2048 + out_channels: 512 + Head: + name: "ArcMargin" + embedding_size: 512 + class_num: 431 + margin: 0.15 + scale: 32 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + - SupConLoss: + weight: 1.0 + views: 2 + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.01 + last_epoch: -1 + regularizer: + name: 'L2' + coeff: 0.0005 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: "CompCars" + image_root: "./dataset/CompCars/image/" + label_root: "./dataset/CompCars/label/" + bbox_crop: True + cls_label_path: "./dataset/CompCars/train_test_split/classification/train_label.txt" + transform_ops: + - ResizeImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - AugMix: + prob: 0.5 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.5 + sl: 0.02 + sh: 0.4 + r1: 0.3 + mean: [0., 0., 0.] + + sampler: + name: DistributedRandomIdentitySampler + batch_size: 128 + num_instances: 2 + drop_last: False + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: "CompCars" + image_root: "./dataset/CompCars/image/" + label_root: "./dataset/CompCars/label/" + cls_label_path: "./dataset/CompCars/train_test_split/classification/test_label.txt" + bbox_crop: True + transform_ops: + - ResizeImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 8 + use_shared_memory: True + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5] + diff --git a/ppcls/configs/slim/ResNet50_vehicle_cls_quantization.yaml b/ppcls/configs/slim/ResNet50_vehicle_cls_quantization.yaml new file mode 100644 index 00000000..14905148 --- /dev/null +++ b/ppcls/configs/slim/ResNet50_vehicle_cls_quantization.yaml @@ -0,0 +1,135 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: "./output_vehicle_cls_pact/" + device: "gpu" + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 80 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: "./inference" + +Slim: + quant: + name: pact + +# model architecture +Arch: + name: "RecModel" + infer_output_key: "features" + infer_add_softmax: False + Backbone: + name: "ResNet50_last_stage_stride1" + pretrained: True + BackboneStopLayer: + name: "adaptive_avg_pool2d_0" + Neck: + name: "VehicleNeck" + in_channels: 2048 + out_channels: 512 + Head: + name: "ArcMargin" + embedding_size: 512 + class_num: 431 + margin: 0.15 + scale: 32 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + - SupConLoss: + weight: 1.0 + views: 2 + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.001 + last_epoch: -1 + regularizer: + name: 'L2' + coeff: 0.0005 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: "CompCars" + image_root: "./dataset/CompCars/image/" + label_root: "./dataset/CompCars/label/" + bbox_crop: True + cls_label_path: "./dataset/CompCars/train_test_split/classification/train_label.txt" + transform_ops: + - ResizeImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - AugMix: + prob: 0.5 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.5 + sl: 0.02 + sh: 0.4 + r1: 0.3 + mean: [0., 0., 0.] + + sampler: + name: DistributedRandomIdentitySampler + batch_size: 128 + num_instances: 2 + drop_last: False + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: "CompCars" + image_root: "./dataset/CompCars/image/" + label_root: "./dataset/CompCars/label/" + cls_label_path: "./dataset/CompCars/train_test_split/classification/test_label.txt" + bbox_crop: True + transform_ops: + - ResizeImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 8 + use_shared_memory: True + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5] + diff --git a/ppcls/configs/slim/ResNet50_vehicle_reid_prune.yaml b/ppcls/configs/slim/ResNet50_vehicle_reid_prune.yaml index 683e0145..a96ffb33 100644 --- a/ppcls/configs/slim/ResNet50_vehicle_reid_prune.yaml +++ b/ppcls/configs/slim/ResNet50_vehicle_reid_prune.yaml @@ -2,7 +2,7 @@ Global: checkpoints: null pretrained_model: null - output_dir: "./output/" + output_dir: "./output_fpgm/" device: "gpu" save_interval: 1 eval_during_train: True diff --git a/ppcls/configs/slim/ResNet50_vehicle_reid_quantization.yaml b/ppcls/configs/slim/ResNet50_vehicle_reid_quantization.yaml new file mode 100644 index 00000000..712a3fca --- /dev/null +++ b/ppcls/configs/slim/ResNet50_vehicle_reid_quantization.yaml @@ -0,0 +1,162 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: "./output_vehicle_reid_pact/" + device: "gpu" + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 40 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: "./inference" + eval_mode: "retrieval" + +# for quantizaiton or prune model +Slim: + ## for prune + quant: + name: pact + +# model architecture +Arch: + name: "RecModel" + infer_output_key: "features" + infer_add_softmax: False + Backbone: + name: "ResNet50_last_stage_stride1" + pretrained: True + BackboneStopLayer: + name: "adaptive_avg_pool2d_0" + Neck: + name: "VehicleNeck" + in_channels: 2048 + out_channels: 512 + Head: + name: "ArcMargin" + embedding_size: 512 + class_num: 30671 + margin: 0.15 + scale: 32 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + - SupConLoss: + weight: 1.0 + views: 2 + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.001 + last_epoch: -1 + regularizer: + name: 'L2' + coeff: 0.0005 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: "VeriWild" + image_root: "./dataset/VeRI-Wild/images/" + cls_label_path: "./dataset/VeRI-Wild/train_test_split/train_list_start0.txt" + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - AugMix: + prob: 0.5 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.5 + sl: 0.02 + sh: 0.4 + r1: 0.3 + mean: [0., 0., 0.] + + sampler: + name: DistributedRandomIdentitySampler + batch_size: 64 + num_instances: 2 + drop_last: False + shuffle: True + loader: + num_workers: 6 + use_shared_memory: True + Eval: + Query: + dataset: + name: "VeriWild" + image_root: "./dataset/VeRI-Wild/images" + cls_label_path: "./dataset/VeRI-Wild/train_test_split/test_3000_id_query.txt" + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 6 + use_shared_memory: True + + Gallery: + dataset: + name: "VeriWild" + image_root: "./dataset/VeRI-Wild/images" + cls_label_path: "./dataset/VeRI-Wild/train_test_split/test_3000_id.txt" + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 6 + use_shared_memory: True + +Metric: + Eval: + - Recallk: + topk: [1, 5] + - mAP: {} + -- GitLab