提交 185d1e1f 编写于 作者: L LDOUBLEV

fix bug

上级 a91bbd74
......@@ -20,7 +20,7 @@ Architecture:
algorithm: Distillation
Models:
Student:
pretrained:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
......@@ -37,7 +37,7 @@ Architecture:
name: DBHead
k: 50
Student2:
pretrained:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
......@@ -55,6 +55,9 @@ Architecture:
name: DBHead
k: 50
Teacher:
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params: true
return_all_feats: false
model_type: det
algorithm: DB
Transform:
......@@ -73,7 +76,9 @@ Loss:
loss_config_list:
- DistillationDilaDBLoss:
weight: 1.0
model_name_list: ["Student", "Student2", "Teacher"]
model_name_pairs:
- ["Student", "Teacher"]
- ["Student2", "Teacher"]
key: maps
balance_loss: true
main_loss_type: DiceLoss
......@@ -81,13 +86,16 @@ Loss:
beta: 10
ohem_ratio: 3
- DistillationDMLLoss:
model_name_pairs:
- ["Student", "Student2"]
maps_name: ["thrink_maps"]
weight: 1.0
act: "softmax"
model_name_pairs: ["Student", "Student2"]
key: maps
- DistillationDBLoss:
model_name_list: ["Student", "Teacher"]
weight: 1.0
model_name_list: ["Student", "Student2"]
key: maps
name: DBLoss
balance_loss: true
......@@ -110,7 +118,7 @@ Optimizer:
factor: 0
PostProcess:
name: DistillationCTDBPostProcessCLabelDecode
name: DistillationDBPostProcess
model_name: ["Student", "Student2"]
key: head_out
thresh: 0.3
......
......@@ -14,6 +14,8 @@
import paddle
import paddle.nn as nn
import numpy as np
import cv2
from .rec_ctc_loss import CTCLoss
from .basic_loss import DMLLoss
......@@ -22,6 +24,7 @@ from .det_db_loss import DBLoss
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
def _sum_loss(loss_dict):
if "loss" in loss_dict.keys():
return loss_dict
......@@ -50,7 +53,7 @@ class DistillationDMLLoss(DMLLoss):
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name
self.maps_name = self.maps_name
self.maps_name = maps_name
def _check_maps_name(self, maps_name):
if maps_name is None:
......@@ -172,6 +175,7 @@ class DistillationDBLoss(DBLoss):
class DistillationDilaDBLoss(DBLoss):
def __init__(self,
model_name_pairs=[],
key=None,
balance_loss=True,
main_loss_type='DiceLoss',
alpha=5,
......@@ -182,6 +186,7 @@ class DistillationDilaDBLoss(DBLoss):
super().__init__()
self.model_name_pairs = model_name_pairs
self.name = name
self.key = key
def forward(self, predicts, batch):
loss_dict = dict()
......@@ -219,7 +224,7 @@ class DistillationDilaDBLoss(DBLoss):
loss_dict[k] = bce_loss + loss_binary_maps
loss_dict = _sum_loss(loss_dict)
return loss
return loss_dict
class DistillationDistanceLoss(DistanceLoss):
......
......@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
from .base_model import BaseModel
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import init_model, load_pretrained_params
__all__ = ['DistillationModel']
......@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
pretrained = model_config.pop("pretrained")
model = BaseModel(model_config)
if pretrained is not None:
init_model(model, path=pretrained)
load_pretrained_params(model, pretrained)
if freeze_params:
for param in model.parameters():
param.trainable = False
......
......@@ -21,7 +21,7 @@ import copy
__all__ = ['build_post_process']
from .db_postprocess import DBPostProcess
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
......@@ -34,7 +34,7 @@ def build_post_process(config, global_config=None):
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode'
'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess'
]
config = copy.deepcopy(config)
......
......@@ -187,3 +187,44 @@ class DBPostProcess(object):
boxes_batch.append({'points': boxes})
return boxes_batch
class DistillationDBPostProcess(DBPostProcess):
def __init__(self,
model_name=["student"],
key=None,
thresh=0.3,
box_thresh=0.7,
max_candidates=1000,
unclip_ratio=2.0,
use_dilation=False,
score_mode="fast",
**kwargs):
super(DistillationDBPostProcess, self).__init__(thresh,
box_thresh,
max_candidates,
unclip_ratio,
use_dilation,
score_mode)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
self.key = key
def forward(self, predicts, shape_list):
results = {}
for name in self.model_name:
pred = predicts[name]
if self.key is not None:
pred = pred[self.key]
results[name] = super().__call__(pred, shape_list=label)
return results
......@@ -116,6 +116,26 @@ def load_dygraph_params(config, model, logger, optimizer):
logger.info(f"loaded pretrained_model successful from {pm}")
return {}
def load_pretrained_params(model, path):
if path is None:
return False
if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
print(f"The pretrained_model {path} does not exists!")
return False
path = path if path.endswith('.pdparams') else path + '.pdparams'
params = paddle.load(path)
state_dict = model.state_dict()
new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
print(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
model.set_state_dict(new_state_dict)
return True
def save_model(model,
optimizer,
......
......@@ -186,7 +186,10 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
model_type = config['Architecture']['model_type']
try:
model_type = config['Architecture']['model_type']
except:
model_type = None
if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册