diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index d9c9869dfcd35cb9b491db826f3bff5f766723f4..c118a378b7dcd87a3d18a18f97de1a8ba175c3d0 100755 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -190,7 +193,8 @@ class DBPostProcess(object): class DistillationDBPostProcess(object): - def __init__(self, model_name=["student"], + def __init__(self, + model_name=["student"], key=None, thresh=0.3, box_thresh=0.6, @@ -201,12 +205,13 @@ class DistillationDBPostProcess(object): **kwargs): self.model_name = model_name self.key = key - self.post_process = DBPostProcess(thresh=thresh, - box_thresh=box_thresh, - max_candidates=max_candidates, - unclip_ratio=unclip_ratio, - use_dilation=use_dilation, - score_mode=score_mode) + self.post_process = DBPostProcess( + thresh=thresh, + box_thresh=box_thresh, + max_candidates=max_candidates, + unclip_ratio=unclip_ratio, + use_dilation=use_dilation, + score_mode=score_mode) def __call__(self, predicts, shape_list): results = {}