未验证 提交 fa6c5a11 编写于 作者: G Guanghua Yu 提交者: GitHub

fix num_classes in solov2 (#3193)

上级 9532fbdc
......@@ -462,7 +462,8 @@ class SOLOv2Head(nn.Layer):
# cate_labels & kernel_preds
cate_labels = inds[:, 1]
kernel_preds = paddle.gather(kernel_preds, index=inds[:, 0])
cate_score_idx = paddle.add(inds[:, 0] * 80, cate_labels)
cate_score_idx = paddle.add(inds[:, 0] * self.cate_out_channels,
cate_labels)
cate_scores = paddle.gather(cate_preds, index=cate_score_idx)
size_trans = np.power(self.seg_num_grids, 2)
......
......@@ -367,8 +367,8 @@ class SOLOv2Head(object):
# cate_labels & kernel_preds
cate_labels = inds[:, 1]
kernel_preds = fluid.layers.gather(kernel_preds, index=inds[:, 0])
cate_score_idx = fluid.layers.elementwise_add(inds[:, 0] * 80,
cate_labels)
cate_score_idx = fluid.layers.elementwise_add(
inds[:, 0] * self.cate_out_channels, cate_labels)
cate_scores = fluid.layers.gather(cate_preds, index=cate_score_idx)
size_trans = np.power(self.seg_num_grids, 2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册