diff --git a/ppocr/modeling/architectures/det_model.py b/ppocr/modeling/architectures/det_model.py index 5016546ac08d06d361d0352498ac853084e61bca..516413e9920b94d3b53f078fe4a73188a9435068 100755 --- a/ppocr/modeling/architectures/det_model.py +++ b/ppocr/modeling/architectures/det_model.py @@ -32,6 +32,7 @@ class DetModel(object): params (dict): the super parameters for detection module. """ global_params = params['Global'] + self.global_params = global_params self.algorithm = global_params['algorithm'] backbone_params = deepcopy(params["Backbone"]) @@ -64,11 +65,23 @@ class DetModel(object): if mode == "train": if self.algorithm == "EAST": score = fluid.layers.data( - name='score', shape=[1, 128, 128], dtype='float32') + name='score', + shape=[ + 1, int(image_shape[1] // 4), int(image_shape[2] // 4) + ], + dtype='float32') geo = fluid.layers.data( - name='geo', shape=[9, 128, 128], dtype='float32') + name='geo', + shape=[ + 9, int(image_shape[1] // 4), int(image_shape[2] // 4) + ], + dtype='float32') mask = fluid.layers.data( - name='mask', shape=[1, 128, 128], dtype='float32') + name='mask', + shape=[ + 1, int(image_shape[1] // 4), int(image_shape[2] // 4) + ], + dtype='float32') feed_list = [image, score, geo, mask] labels = {'score': score, 'geo': geo, 'mask': mask} elif self.algorithm == "DB":