未验证 提交 079cdf98 编写于 作者: Z zhoujun 提交者: GitHub

update save model func (#6693)

* save latest metric

* save latest metric

* add boader judge

* add boader judge
上级 59d854c6
...@@ -48,6 +48,7 @@ class Shape(object): ...@@ -48,6 +48,7 @@ class Shape(object):
def __init__(self, label=None, line_color=None, difficult=False, key_cls="None", paintLabel=False): def __init__(self, label=None, line_color=None, difficult=False, key_cls="None", paintLabel=False):
self.label = label self.label = label
self.idx = 0
self.points = [] self.points = []
self.fill = False self.fill = False
self.selected = False self.selected = False
......
...@@ -311,7 +311,6 @@ python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o ...@@ -311,7 +311,6 @@ python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o
在上述命令中,通过`-o`的方式修改了配置文件中的参数。 在上述命令中,通过`-o`的方式修改了配置文件中的参数。
训练好的模型地址为: [det_ppocr_v3_finetune.tar](https://paddleocr.bj.bcebos.com/fanliku/license_plate_recognition/det_ppocr_v3_finetune.tar)
**评估** **评估**
...@@ -354,8 +353,6 @@ python3.7 deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv3/ch_PP-OCR ...@@ -354,8 +353,6 @@ python3.7 deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv3/ch_PP-OCR
Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/det.txt] Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/det.txt]
``` ```
训练好的模型地址为: [det_ppocr_v3_quant.tar](https://paddleocr.bj.bcebos.com/fanliku/license_plate_recognition/det_ppocr_v3_quant.tar)
量化后指标对比如下 量化后指标对比如下
|方案|hmeans| 模型大小 | 预测速度(lite) | |方案|hmeans| 模型大小 | 预测速度(lite) |
...@@ -436,6 +433,12 @@ python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o \ ...@@ -436,6 +433,12 @@ python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o \
Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/rec.txt] Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/rec.txt]
``` ```
如需获取已训练模型,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁
<div align="left">
<img src="https://ai-studio-static-online.cdn.bcebos.com/dd721099bd50478f9d5fb13d8dd00fad69c22d6848244fd3a1d3980d7fefc63e" width = "150" height = "150" />
</div>
评估部分日志如下: 评估部分日志如下:
```bash ```bash
[2022/05/12 19:52:02] ppocr INFO: load pretrain successful from models/ch_PP-OCRv3_rec_train/best_accuracy [2022/05/12 19:52:02] ppocr INFO: load pretrain successful from models/ch_PP-OCRv3_rec_train/best_accuracy
...@@ -528,7 +531,6 @@ python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o \ ...@@ -528,7 +531,6 @@ python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o \
Eval.dataset.data_dir=/home/aistudio/data/CCPD2020/PPOCR \ Eval.dataset.data_dir=/home/aistudio/data/CCPD2020/PPOCR \
Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/rec.txt] Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/rec.txt]
``` ```
训练好的模型地址为: [rec_ppocr_v3_finetune.tar](https://paddleocr.bj.bcebos.com/fanliku/license_plate_recognition/rec_ppocr_v3_finetune.tar)
**评估** **评估**
...@@ -570,7 +572,6 @@ python3.7 deploy/slim/quantization/quant.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_ ...@@ -570,7 +572,6 @@ python3.7 deploy/slim/quantization/quant.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_
Eval.dataset.data_dir=/home/aistudio/data/CCPD2020/PPOCR \ Eval.dataset.data_dir=/home/aistudio/data/CCPD2020/PPOCR \
Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/rec.txt] Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/rec.txt]
``` ```
训练好的模型地址为: [rec_ppocr_v3_quant.tar](https://paddleocr.bj.bcebos.com/fanliku/license_plate_recognition/rec_ppocr_v3_quant.tar)
量化后指标对比如下 量化后指标对比如下
......
...@@ -107,17 +107,20 @@ class FCENetTargets: ...@@ -107,17 +107,20 @@ class FCENetTargets:
for i in range(1, n): for i in range(1, n):
current_line_len = i * delta_length current_line_len = i * delta_length
while current_line_len >= length_cumsum[current_edge_ind + 1]: while current_edge_ind + 1 < len(length_cumsum) and current_line_len >= length_cumsum[current_edge_ind + 1]:
current_edge_ind += 1 current_edge_ind += 1
current_edge_end_shift = current_line_len - length_cumsum[ current_edge_end_shift = current_line_len - length_cumsum[
current_edge_ind] current_edge_ind]
if current_edge_ind >= len(length_list):
break
end_shift_ratio = current_edge_end_shift / length_list[ end_shift_ratio = current_edge_end_shift / length_list[
current_edge_ind] current_edge_ind]
current_point = line[current_edge_ind] + (line[current_edge_ind + 1] current_point = line[current_edge_ind] + (line[current_edge_ind + 1]
- line[current_edge_ind] - line[current_edge_ind]
) * end_shift_ratio ) * end_shift_ratio
resampled_line.append(current_point) resampled_line.append(current_point)
resampled_line.append(line[-1]) resampled_line.append(line[-1])
resampled_line = np.array(resampled_line) resampled_line = np.array(resampled_line)
...@@ -328,6 +331,8 @@ class FCENetTargets: ...@@ -328,6 +331,8 @@ class FCENetTargets:
resampled_top_line, resampled_bot_line = self.resample_sidelines( resampled_top_line, resampled_bot_line = self.resample_sidelines(
top_line, bot_line, self.resample_step) top_line, bot_line, self.resample_step)
resampled_bot_line = resampled_bot_line[::-1] resampled_bot_line = resampled_bot_line[::-1]
if len(resampled_top_line) != len(resampled_bot_line):
continue
center_line = (resampled_top_line + resampled_bot_line) / 2 center_line = (resampled_top_line + resampled_bot_line) / 2
line_head_shrink_len = norm(resampled_top_line[0] - line_head_shrink_len = norm(resampled_top_line[0] -
......
...@@ -177,9 +177,9 @@ def save_model(model, ...@@ -177,9 +177,9 @@ def save_model(model,
model.backbone.model.save_pretrained(model_prefix) model.backbone.model.save_pretrained(model_prefix)
metric_prefix = os.path.join(model_prefix, 'metric') metric_prefix = os.path.join(model_prefix, 'metric')
# save metric and config # save metric and config
with open(metric_prefix + '.states', 'wb') as f:
pickle.dump(kwargs, f, protocol=2)
if is_best: if is_best:
with open(metric_prefix + '.states', 'wb') as f:
pickle.dump(kwargs, f, protocol=2)
logger.info('save best model is to {}'.format(model_prefix)) logger.info('save best model is to {}'.format(model_prefix))
else: else:
logger.info("save model in {}".format(model_prefix)) logger.info("save model in {}".format(model_prefix))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册