From f4094b9700c4028afbb1d88110361156e8c41643 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Tue, 16 Jun 2020 14:55:52 +0800 Subject: [PATCH] fix issue #206 --- ppocr/modeling/architectures/det_model.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/ppocr/modeling/architectures/det_model.py b/ppocr/modeling/architectures/det_model.py index 5016546a..516413e9 100755 --- a/ppocr/modeling/architectures/det_model.py +++ b/ppocr/modeling/architectures/det_model.py @@ -32,6 +32,7 @@ class DetModel(object): params (dict): the super parameters for detection module. """ global_params = params['Global'] + self.global_params = global_params self.algorithm = global_params['algorithm'] backbone_params = deepcopy(params["Backbone"]) @@ -64,11 +65,23 @@ class DetModel(object): if mode == "train": if self.algorithm == "EAST": score = fluid.layers.data( - name='score', shape=[1, 128, 128], dtype='float32') + name='score', + shape=[ + 1, int(image_shape[1] // 4), int(image_shape[2] // 4) + ], + dtype='float32') geo = fluid.layers.data( - name='geo', shape=[9, 128, 128], dtype='float32') + name='geo', + shape=[ + 9, int(image_shape[1] // 4), int(image_shape[2] // 4) + ], + dtype='float32') mask = fluid.layers.data( - name='mask', shape=[1, 128, 128], dtype='float32') + name='mask', + shape=[ + 1, int(image_shape[1] // 4), int(image_shape[2] // 4) + ], + dtype='float32') feed_list = [image, score, geo, mask] labels = {'score': score, 'geo': geo, 'mask': mask} elif self.algorithm == "DB": -- GitLab