未验证 提交 0afa869f 编写于 作者: J jerrywgz 提交者: GitHub

fix overflow error in rcnn (#2238)

上级 4996edbc
......@@ -76,7 +76,7 @@ class RCNN(object):
[-1, 3], [-1, 1]]
lod_levels = [0, 1, 1, 1, 0, 0]
dtypes = [
'float32', 'float32', 'int32', 'int32', 'float32', 'int32'
'float32', 'float32', 'int32', 'int32', 'float32', 'int64'
]
if cfg.MASK_ON:
in_shapes.append([-1, 2])
......@@ -109,7 +109,7 @@ class RCNN(object):
self.im_info = fluid.layers.data(
name='im_info', shape=[3], dtype='float32')
self.im_id = fluid.layers.data(
name='im_id', shape=[1], dtype='int32')
name='im_id', shape=[1], dtype='int64')
if cfg.MASK_ON:
self.gt_masks = fluid.layers.data(
name='gt_masks', shape=[2], dtype='float32', lod_level=3)
......
......@@ -153,14 +153,14 @@ class JsonDataset(object):
num_valid_objs = len(valid_objs)
gt_boxes = np.zeros((num_valid_objs, 4), dtype=entry['gt_boxes'].dtype)
gt_id = np.zeros((num_valid_objs), dtype=np.int32)
gt_id = np.zeros((num_valid_objs), dtype=np.int64)
gt_classes = np.zeros((num_valid_objs), dtype=entry['gt_classes'].dtype)
is_crowd = np.zeros((num_valid_objs), dtype=entry['is_crowd'].dtype)
for ix, obj in enumerate(valid_objs):
cls = self.json_category_id_to_contiguous_id[obj['category_id']]
gt_boxes[ix, :] = obj['clean_bbox']
gt_classes[ix] = cls
gt_id[ix] = np.int32(obj['id'])
gt_id[ix] = np.int64(obj['id'])
is_crowd[ix] = obj['iscrowd']
entry['gt_boxes'] = np.append(entry['gt_boxes'], gt_boxes, axis=0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册