提交 185d1e1f 编写于 作者: L LDOUBLEV

fix bug

上级 a91bbd74
...@@ -20,7 +20,7 @@ Architecture: ...@@ -20,7 +20,7 @@ Architecture:
algorithm: Distillation algorithm: Distillation
Models: Models:
Student: Student:
pretrained: pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false freeze_params: false
return_all_feats: false return_all_feats: false
model_type: det model_type: det
...@@ -37,7 +37,7 @@ Architecture: ...@@ -37,7 +37,7 @@ Architecture:
name: DBHead name: DBHead
k: 50 k: 50
Student2: Student2:
pretrained: pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false freeze_params: false
return_all_feats: false return_all_feats: false
model_type: det model_type: det
...@@ -55,6 +55,9 @@ Architecture: ...@@ -55,6 +55,9 @@ Architecture:
name: DBHead name: DBHead
k: 50 k: 50
Teacher: Teacher:
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params: true
return_all_feats: false
model_type: det model_type: det
algorithm: DB algorithm: DB
Transform: Transform:
...@@ -73,7 +76,9 @@ Loss: ...@@ -73,7 +76,9 @@ Loss:
loss_config_list: loss_config_list:
- DistillationDilaDBLoss: - DistillationDilaDBLoss:
weight: 1.0 weight: 1.0
model_name_list: ["Student", "Student2", "Teacher"] model_name_pairs:
- ["Student", "Teacher"]
- ["Student2", "Teacher"]
key: maps key: maps
balance_loss: true balance_loss: true
main_loss_type: DiceLoss main_loss_type: DiceLoss
...@@ -81,13 +86,16 @@ Loss: ...@@ -81,13 +86,16 @@ Loss:
beta: 10 beta: 10
ohem_ratio: 3 ohem_ratio: 3
- DistillationDMLLoss: - DistillationDMLLoss:
model_name_pairs:
- ["Student", "Student2"]
maps_name: ["thrink_maps"] maps_name: ["thrink_maps"]
weight: 1.0 weight: 1.0
act: "softmax" act: "softmax"
model_name_pairs: ["Student", "Student2"] model_name_pairs: ["Student", "Student2"]
key: maps key: maps
- DistillationDBLoss: - DistillationDBLoss:
model_name_list: ["Student", "Teacher"] weight: 1.0
model_name_list: ["Student", "Student2"]
key: maps key: maps
name: DBLoss name: DBLoss
balance_loss: true balance_loss: true
...@@ -110,7 +118,7 @@ Optimizer: ...@@ -110,7 +118,7 @@ Optimizer:
factor: 0 factor: 0
PostProcess: PostProcess:
name: DistillationCTDBPostProcessCLabelDecode name: DistillationDBPostProcess
model_name: ["Student", "Student2"] model_name: ["Student", "Student2"]
key: head_out key: head_out
thresh: 0.3 thresh: 0.3
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import numpy as np
import cv2
from .rec_ctc_loss import CTCLoss from .rec_ctc_loss import CTCLoss
from .basic_loss import DMLLoss from .basic_loss import DMLLoss
...@@ -22,6 +24,7 @@ from .det_db_loss import DBLoss ...@@ -22,6 +24,7 @@ from .det_db_loss import DBLoss
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
def _sum_loss(loss_dict): def _sum_loss(loss_dict):
if "loss" in loss_dict.keys(): if "loss" in loss_dict.keys():
return loss_dict return loss_dict
...@@ -50,7 +53,7 @@ class DistillationDMLLoss(DMLLoss): ...@@ -50,7 +53,7 @@ class DistillationDMLLoss(DMLLoss):
self.key = key self.key = key
self.model_name_pairs = model_name_pairs self.model_name_pairs = model_name_pairs
self.name = name self.name = name
self.maps_name = self.maps_name self.maps_name = maps_name
def _check_maps_name(self, maps_name): def _check_maps_name(self, maps_name):
if maps_name is None: if maps_name is None:
...@@ -172,6 +175,7 @@ class DistillationDBLoss(DBLoss): ...@@ -172,6 +175,7 @@ class DistillationDBLoss(DBLoss):
class DistillationDilaDBLoss(DBLoss): class DistillationDilaDBLoss(DBLoss):
def __init__(self, def __init__(self,
model_name_pairs=[], model_name_pairs=[],
key=None,
balance_loss=True, balance_loss=True,
main_loss_type='DiceLoss', main_loss_type='DiceLoss',
alpha=5, alpha=5,
...@@ -182,6 +186,7 @@ class DistillationDilaDBLoss(DBLoss): ...@@ -182,6 +186,7 @@ class DistillationDilaDBLoss(DBLoss):
super().__init__() super().__init__()
self.model_name_pairs = model_name_pairs self.model_name_pairs = model_name_pairs
self.name = name self.name = name
self.key = key
def forward(self, predicts, batch): def forward(self, predicts, batch):
loss_dict = dict() loss_dict = dict()
...@@ -219,7 +224,7 @@ class DistillationDilaDBLoss(DBLoss): ...@@ -219,7 +224,7 @@ class DistillationDilaDBLoss(DBLoss):
loss_dict[k] = bce_loss + loss_binary_maps loss_dict[k] = bce_loss + loss_binary_maps
loss_dict = _sum_loss(loss_dict) loss_dict = _sum_loss(loss_dict)
return loss return loss_dict
class DistillationDistanceLoss(DistanceLoss): class DistillationDistanceLoss(DistanceLoss):
......
...@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone ...@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head from ppocr.modeling.heads import build_head
from .base_model import BaseModel from .base_model import BaseModel
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import init_model, load_pretrained_params
__all__ = ['DistillationModel'] __all__ = ['DistillationModel']
...@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer): ...@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
pretrained = model_config.pop("pretrained") pretrained = model_config.pop("pretrained")
model = BaseModel(model_config) model = BaseModel(model_config)
if pretrained is not None: if pretrained is not None:
init_model(model, path=pretrained) load_pretrained_params(model, pretrained)
if freeze_params: if freeze_params:
for param in model.parameters(): for param in model.parameters():
param.trainable = False param.trainable = False
......
...@@ -21,7 +21,7 @@ import copy ...@@ -21,7 +21,7 @@ import copy
__all__ = ['build_post_process'] __all__ = ['build_post_process']
from .db_postprocess import DBPostProcess from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
...@@ -34,7 +34,7 @@ def build_post_process(config, global_config=None): ...@@ -34,7 +34,7 @@ def build_post_process(config, global_config=None):
support_dict = [ support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode' 'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
...@@ -187,3 +187,44 @@ class DBPostProcess(object): ...@@ -187,3 +187,44 @@ class DBPostProcess(object):
boxes_batch.append({'points': boxes}) boxes_batch.append({'points': boxes})
return boxes_batch return boxes_batch
class DistillationDBPostProcess(DBPostProcess):
def __init__(self,
model_name=["student"],
key=None,
thresh=0.3,
box_thresh=0.7,
max_candidates=1000,
unclip_ratio=2.0,
use_dilation=False,
score_mode="fast",
**kwargs):
super(DistillationDBPostProcess, self).__init__(thresh,
box_thresh,
max_candidates,
unclip_ratio,
use_dilation,
score_mode)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
self.key = key
def forward(self, predicts, shape_list):
results = {}
for name in self.model_name:
pred = predicts[name]
if self.key is not None:
pred = pred[self.key]
results[name] = super().__call__(pred, shape_list=label)
return results
...@@ -116,6 +116,26 @@ def load_dygraph_params(config, model, logger, optimizer): ...@@ -116,6 +116,26 @@ def load_dygraph_params(config, model, logger, optimizer):
logger.info(f"loaded pretrained_model successful from {pm}") logger.info(f"loaded pretrained_model successful from {pm}")
return {} return {}
def load_pretrained_params(model, path):
if path is None:
return False
if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
print(f"The pretrained_model {path} does not exists!")
return False
path = path if path.endswith('.pdparams') else path + '.pdparams'
params = paddle.load(path)
state_dict = model.state_dict()
new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
print(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
model.set_state_dict(new_state_dict)
return True
def save_model(model, def save_model(model,
optimizer, optimizer,
......
...@@ -186,7 +186,10 @@ def train(config, ...@@ -186,7 +186,10 @@ def train(config,
model.train() model.train()
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
model_type = config['Architecture']['model_type'] try:
model_type = config['Architecture']['model_type']
except:
model_type = None
if 'start_epoch' in best_model_dict: if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch'] start_epoch = best_model_dict['start_epoch']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册