未验证 提交 0afa869f 编写于 作者: J jerrywgz 提交者: GitHub

fix overflow error in rcnn (#2238)

上级 4996edbc
...@@ -76,7 +76,7 @@ class RCNN(object): ...@@ -76,7 +76,7 @@ class RCNN(object):
[-1, 3], [-1, 1]] [-1, 3], [-1, 1]]
lod_levels = [0, 1, 1, 1, 0, 0] lod_levels = [0, 1, 1, 1, 0, 0]
dtypes = [ dtypes = [
'float32', 'float32', 'int32', 'int32', 'float32', 'int32' 'float32', 'float32', 'int32', 'int32', 'float32', 'int64'
] ]
if cfg.MASK_ON: if cfg.MASK_ON:
in_shapes.append([-1, 2]) in_shapes.append([-1, 2])
...@@ -109,7 +109,7 @@ class RCNN(object): ...@@ -109,7 +109,7 @@ class RCNN(object):
self.im_info = fluid.layers.data( self.im_info = fluid.layers.data(
name='im_info', shape=[3], dtype='float32') name='im_info', shape=[3], dtype='float32')
self.im_id = fluid.layers.data( self.im_id = fluid.layers.data(
name='im_id', shape=[1], dtype='int32') name='im_id', shape=[1], dtype='int64')
if cfg.MASK_ON: if cfg.MASK_ON:
self.gt_masks = fluid.layers.data( self.gt_masks = fluid.layers.data(
name='gt_masks', shape=[2], dtype='float32', lod_level=3) name='gt_masks', shape=[2], dtype='float32', lod_level=3)
......
...@@ -153,14 +153,14 @@ class JsonDataset(object): ...@@ -153,14 +153,14 @@ class JsonDataset(object):
num_valid_objs = len(valid_objs) num_valid_objs = len(valid_objs)
gt_boxes = np.zeros((num_valid_objs, 4), dtype=entry['gt_boxes'].dtype) gt_boxes = np.zeros((num_valid_objs, 4), dtype=entry['gt_boxes'].dtype)
gt_id = np.zeros((num_valid_objs), dtype=np.int32) gt_id = np.zeros((num_valid_objs), dtype=np.int64)
gt_classes = np.zeros((num_valid_objs), dtype=entry['gt_classes'].dtype) gt_classes = np.zeros((num_valid_objs), dtype=entry['gt_classes'].dtype)
is_crowd = np.zeros((num_valid_objs), dtype=entry['is_crowd'].dtype) is_crowd = np.zeros((num_valid_objs), dtype=entry['is_crowd'].dtype)
for ix, obj in enumerate(valid_objs): for ix, obj in enumerate(valid_objs):
cls = self.json_category_id_to_contiguous_id[obj['category_id']] cls = self.json_category_id_to_contiguous_id[obj['category_id']]
gt_boxes[ix, :] = obj['clean_bbox'] gt_boxes[ix, :] = obj['clean_bbox']
gt_classes[ix] = cls gt_classes[ix] = cls
gt_id[ix] = np.int32(obj['id']) gt_id[ix] = np.int64(obj['id'])
is_crowd[ix] = obj['iscrowd'] is_crowd[ix] = obj['iscrowd']
entry['gt_boxes'] = np.append(entry['gt_boxes'], gt_boxes, axis=0) entry['gt_boxes'] = np.append(entry['gt_boxes'], gt_boxes, axis=0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册