diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_distill.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_distill.yml index 46daeeb86d004772a6fb964d602369dcd53b3a01..d8d5135dd73ee438f76f5796b63e0dae4331402b 100644 --- a/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_distill.yml +++ b/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_distill.yml @@ -90,7 +90,7 @@ Optimizer: PostProcess: name: DistillationDBPostProcess - model_name: ["Student", "Student2"] + model_name: ["Student"] key: head_out thresh: 0.3 box_thresh: 0.6 diff --git a/ppocr/data/imaug/east_process.py b/ppocr/data/imaug/east_process.py index b1d7a5e51939af981dd62c269c930f4bf9ba4179..df08adfa1516c59229e95af193c172dfcdd5af08 100644 --- a/ppocr/data/imaug/east_process.py +++ b/ppocr/data/imaug/east_process.py @@ -11,7 +11,10 @@ #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #See the License for the specific language governing permissions and #limitations under the License. - +""" +This code is referred from: +https://github.com/songdejia/EAST/blob/master/data_utils.py +""" import math import cv2 import numpy as np @@ -24,10 +27,10 @@ __all__ = ['EASTProcessTrain'] class EASTProcessTrain(object): def __init__(self, - image_shape = [512, 512], - background_ratio = 0.125, - min_crop_side_ratio = 0.1, - min_text_size = 10, + image_shape=[512, 512], + background_ratio=0.125, + min_crop_side_ratio=0.1, + min_text_size=10, **kwargs): self.input_size = image_shape[1] self.random_scale = np.array([0.5, 1, 2.0, 3.0]) @@ -282,12 +285,7 @@ class EASTProcessTrain(object): 1.0 / max(min(poly_h, poly_w), 1.0) return score_map, geo_map, training_mask - def crop_area(self, - im, - polys, - tags, - crop_background=False, - max_tries=50): + def crop_area(self, im, polys, tags, crop_background=False, max_tries=50): """ make random crop from the input image :param im: @@ -435,5 +433,4 @@ class EASTProcessTrain(object): data['score_map'] = score_map data['geo_map'] = geo_map data['training_mask'] = training_mask - # 
print(im.shape, score_map.shape, geo_map.shape, training_mask.shape) - return data \ No newline at end of file + return data diff --git a/ppocr/data/imaug/sast_process.py b/ppocr/data/imaug/sast_process.py index 1536dceb8ee5999226cfe7cf455d70e39b449530..08d03b194dcfab92ab59329857d4a1326531218e 100644 --- a/ppocr/data/imaug/sast_process.py +++ b/ppocr/data/imaug/sast_process.py @@ -11,7 +11,10 @@ #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #See the License for the specific language governing permissions and #limitations under the License. - +""" +Part of this code is referred from: +https://github.com/songdejia/EAST/blob/master/data_utils.py +""" import math import cv2 import numpy as np diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index d9c9869dfcd35cb9b491db826f3bff5f766723f4..27b428ef2e73c9abf81d3881b23979343c8595b2 100755 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +""" +This code is referred from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -190,7 +193,8 @@ class DBPostProcess(object): class DistillationDBPostProcess(object): - def __init__(self, model_name=["student"], + def __init__(self, + model_name=["student"], key=None, thresh=0.3, box_thresh=0.6, @@ -201,12 +205,13 @@ class DistillationDBPostProcess(object): **kwargs): self.model_name = model_name self.key = key - self.post_process = DBPostProcess(thresh=thresh, - box_thresh=box_thresh, - max_candidates=max_candidates, - unclip_ratio=unclip_ratio, - use_dilation=use_dilation, - score_mode=score_mode) + self.post_process = DBPostProcess( + thresh=thresh, + box_thresh=box_thresh, + max_candidates=max_candidates, + unclip_ratio=unclip_ratio, + use_dilation=use_dilation, + score_mode=score_mode) def __call__(self, predicts, shape_list): results = {}