提交 4d33efee 编写于 作者: D dongshuilong

fix slim bugs

上级 70a1fb9d
......@@ -18,6 +18,7 @@ import os
import sys
import paddle
from paddle import nn
import numpy as np
import paddleslim
from paddle.jit import to_static
......@@ -37,7 +38,8 @@ from ppcls.utils.config import print_config
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.data import build_dataloader
from ppcls.arch import apply_to_static
from ppcls.arch import build_model
from ppcls.arch import build_model, RecModel, DistillationModel
from ppcls.arch.gears.identity_head import IdentityHead
quant_config = {
# weight preprocess type, default is None and no preprocessing is performed.
......@@ -63,6 +65,49 @@ quant_config = {
}
class ExportModel(nn.Layer):
    """
    Inference-export wrapper around a built model.

    The wrapper runs the wrapped model, narrows its output down to a single
    value (first list element, then an optional sub-model entry, then an
    optional output key) and finally applies a softmax unless disabled.

    Args:
        config: the ``Arch`` section of the configuration. The keys read here
            are ``infer_model_name`` (required when ``model`` is a
            ``DistillationModel``), ``infer_output_key`` (optional) and
            ``infer_add_softmax`` (optional, defaults to ``True``).
        model: the model to be exported.
    """

    def __init__(self, config, model):
        super().__init__()
        self.base_model = model

        # A distillation model bundles several networks; the configuration
        # must name the one whose output is exported.
        self.infer_model_name = (config["infer_model_name"] if isinstance(
            model, DistillationModel) else None)

        self.infer_output_key = config.get("infer_output_key", None)

        # When a RecModel is asked for its "features", replace the head with
        # IdentityHead before export.
        export_features = self.infer_output_key == "features"
        if export_features and isinstance(model, RecModel):
            self.base_model.head = IdentityHead()

        # Softmax over the last axis is on by default and can be switched off
        # with ``infer_add_softmax: False``.
        self.softmax = (nn.Softmax(axis=-1)
                        if config.get("infer_add_softmax", True) else None)

    def eval(self):
        """Put this wrapper and each of its sublayers into evaluation mode."""
        self.training = False
        for sub_layer in self.sublayers():
            sub_layer.training = False
            sub_layer.eval()

    def forward(self, x):
        """
        Run the wrapped model and return the selected (optionally softmaxed)
        output.
        """
        out = self.base_model(x)

        # A list output is reduced to its first element.
        if isinstance(out, list):
            out = out[0]

        # Select the configured sub-model output, if one was named.
        if self.infer_model_name is not None:
            out = out[self.infer_model_name]

        # Select the configured output key, if one was given.
        if self.infer_output_key is not None:
            out = out[self.infer_output_key]

        if self.softmax is not None:
            out = self.softmax(out)
        return out
class Trainer_slim(Trainer):
def __init__(self, config, mode="train"):
......@@ -195,11 +240,12 @@ class Trainer_slim(Trainer):
raise RuntimeError(
"The best_model or pretraine_model should exist to generate inference model"
)
model = ExportModel(self.config["Arch"], self.model)
save_path = os.path.join(self.config["Global"]["save_inference_dir"],
"inference")
if self.quanter:
self.quanter.save_quantized_model(
self.model,
model,
save_path,
input_spec=[
paddle.static.InputSpec(
......@@ -208,7 +254,7 @@ class Trainer_slim(Trainer):
])
else:
model = to_static(
self.model,
model,
input_spec=[
paddle.static.InputSpec(
shape=[None] + self.config["Global"]["image_shape"],
......
......@@ -14,7 +14,7 @@ Global:
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# for quantalization or prune model
# for quantization or prune model
Slim:
## for prune
prune:
......
......@@ -16,7 +16,7 @@ Global:
# for quantization or prune model
Slim:
## for quantalization
## for quantization
quant:
name: pact
......
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 160
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: "./inference"
eval_mode: "retrieval"
# for quantization or prune model
Slim:
## for prune
prune:
name: fpgm
pruned_ratio: 0.3
# model architecture
Arch:
name: "RecModel"
infer_output_key: "features"
infer_add_softmax: False
Backbone:
name: "ResNet50_last_stage_stride1"
pretrained: True
BackboneStopLayer:
name: "adaptive_avg_pool2d_0"
Neck:
name: "VehicleNeck"
in_channels: 2048
out_channels: 512
Head:
name: "ArcMargin"
embedding_size: 512
class_num: 30671
margin: 0.15
scale: 32
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
- SupConLoss:
weight: 1.0
views: 2
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.01
last_epoch: -1
regularizer:
name: 'L2'
coeff: 0.0005
# data loader for train and eval
DataLoader:
Train:
dataset:
name: "VeriWild"
image_root: "./dataset/VeRI-Wild/images/"
cls_label_path: "./dataset/VeRI-Wild/train_test_split/train_list_start0.txt"
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: 224
- RandFlipImage:
flip_code: 1
- AugMix:
prob: 0.5
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.5
sl: 0.02
sh: 0.4
r1: 0.3
mean: [0., 0., 0.]
sampler:
name: DistributedRandomIdentitySampler
batch_size: 128
num_instances: 2
drop_last: False
shuffle: True
loader:
num_workers: 6
use_shared_memory: True
Eval:
Query:
dataset:
name: "VeriWild"
image_root: "./dataset/VeRI-Wild/images"
cls_label_path: "./dataset/VeRI-Wild/train_test_split/test_3000_id_query.txt"
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 6
use_shared_memory: True
Gallery:
dataset:
name: "VeriWild"
image_root: "./dataset/VeRI-Wild/images"
cls_label_path: "./dataset/VeRI-Wild/train_test_split/test_3000_id.txt"
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 6
use_shared_memory: True
Metric:
Eval:
- Recallk:
topk: [1, 5]
- mAP: {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册