未验证 提交 c22b29d6 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #4558 from LDOUBLEV/release/2.3

[fix infer] cp 
...@@ -11,7 +11,10 @@ ...@@ -11,7 +11,10 @@
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and #See the License for the specific language governing permissions and
#limitations under the License. #limitations under the License.
"""
This code is refered from:
https://github.com/songdejia/EAST/blob/master/data_utils.py
"""
import math import math
import cv2 import cv2
import numpy as np import numpy as np
...@@ -24,10 +27,10 @@ __all__ = ['EASTProcessTrain'] ...@@ -24,10 +27,10 @@ __all__ = ['EASTProcessTrain']
class EASTProcessTrain(object): class EASTProcessTrain(object):
def __init__(self, def __init__(self,
image_shape = [512, 512], image_shape=[512, 512],
background_ratio = 0.125, background_ratio=0.125,
min_crop_side_ratio = 0.1, min_crop_side_ratio=0.1,
min_text_size = 10, min_text_size=10,
**kwargs): **kwargs):
self.input_size = image_shape[1] self.input_size = image_shape[1]
self.random_scale = np.array([0.5, 1, 2.0, 3.0]) self.random_scale = np.array([0.5, 1, 2.0, 3.0])
...@@ -282,12 +285,7 @@ class EASTProcessTrain(object): ...@@ -282,12 +285,7 @@ class EASTProcessTrain(object):
1.0 / max(min(poly_h, poly_w), 1.0) 1.0 / max(min(poly_h, poly_w), 1.0)
return score_map, geo_map, training_mask return score_map, geo_map, training_mask
def crop_area(self, def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
im,
polys,
tags,
crop_background=False,
max_tries=50):
""" """
make random crop from the input image make random crop from the input image
:param im: :param im:
...@@ -436,4 +434,4 @@ class EASTProcessTrain(object): ...@@ -436,4 +434,4 @@ class EASTProcessTrain(object):
data['geo_map'] = geo_map data['geo_map'] = geo_map
data['training_mask'] = training_mask data['training_mask'] = training_mask
# print(im.shape, score_map.shape, geo_map.shape, training_mask.shape) # print(im.shape, score_map.shape, geo_map.shape, training_mask.shape)
return data return data
\ No newline at end of file
...@@ -11,7 +11,10 @@ ...@@ -11,7 +11,10 @@
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and #See the License for the specific language governing permissions and
#limitations under the License. #limitations under the License.
"""
This part code is refered from:
https://github.com/songdejia/EAST/blob/master/data_utils.py
"""
import math import math
import cv2 import cv2
import numpy as np import numpy as np
......
...@@ -11,7 +11,10 @@ ...@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
This code is refered from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
...@@ -190,7 +193,8 @@ class DBPostProcess(object): ...@@ -190,7 +193,8 @@ class DBPostProcess(object):
class DistillationDBPostProcess(object): class DistillationDBPostProcess(object):
def __init__(self, model_name=["student"], def __init__(self,
model_name=["student"],
key=None, key=None,
thresh=0.3, thresh=0.3,
box_thresh=0.6, box_thresh=0.6,
...@@ -201,12 +205,13 @@ class DistillationDBPostProcess(object): ...@@ -201,12 +205,13 @@ class DistillationDBPostProcess(object):
**kwargs): **kwargs):
self.model_name = model_name self.model_name = model_name
self.key = key self.key = key
self.post_process = DBPostProcess(thresh=thresh, self.post_process = DBPostProcess(
box_thresh=box_thresh, thresh=thresh,
max_candidates=max_candidates, box_thresh=box_thresh,
unclip_ratio=unclip_ratio, max_candidates=max_candidates,
use_dilation=use_dilation, unclip_ratio=unclip_ratio,
score_mode=score_mode) use_dilation=use_dilation,
score_mode=score_mode)
def __call__(self, predicts, shape_list): def __call__(self, predicts, shape_list):
results = {} results = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册