提交 f4094b97 编写于 作者: L LDOUBLEV

fix issue #206

上级 1abb0e4a
...@@ -32,6 +32,7 @@ class DetModel(object): ...@@ -32,6 +32,7 @@ class DetModel(object):
params (dict): the super parameters for detection module. params (dict): the super parameters for detection module.
""" """
global_params = params['Global'] global_params = params['Global']
self.global_params = global_params
self.algorithm = global_params['algorithm'] self.algorithm = global_params['algorithm']
backbone_params = deepcopy(params["Backbone"]) backbone_params = deepcopy(params["Backbone"])
...@@ -64,11 +65,23 @@ class DetModel(object): ...@@ -64,11 +65,23 @@ class DetModel(object):
if mode == "train": if mode == "train":
if self.algorithm == "EAST": if self.algorithm == "EAST":
score = fluid.layers.data( score = fluid.layers.data(
name='score', shape=[1, 128, 128], dtype='float32') name='score',
shape=[
1, int(image_shape[1] // 4), int(image_shape[2] // 4)
],
dtype='float32')
geo = fluid.layers.data( geo = fluid.layers.data(
name='geo', shape=[9, 128, 128], dtype='float32') name='geo',
shape=[
9, int(image_shape[1] // 4), int(image_shape[2] // 4)
],
dtype='float32')
mask = fluid.layers.data( mask = fluid.layers.data(
name='mask', shape=[1, 128, 128], dtype='float32') name='mask',
shape=[
1, int(image_shape[1] // 4), int(image_shape[2] // 4)
],
dtype='float32')
feed_list = [image, score, geo, mask] feed_list = [image, score, geo, mask]
labels = {'score': score, 'geo': geo, 'mask': mask} labels = {'score': score, 'geo': geo, 'mask': mask}
elif self.algorithm == "DB": elif self.algorithm == "DB":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册