未验证 提交 1046c239 编写于 作者: W wanghuancoder 提交者: GitHub

fix numpy speed (#10773)

上级 b3912fcf
...@@ -745,6 +745,8 @@ class DistillationDilaDBLoss(DBLoss): ...@@ -745,6 +745,8 @@ class DistillationDilaDBLoss(DBLoss):
# dilation to teacher prediction # dilation to teacher prediction
dilation_w = np.array([[1, 1], [1, 1]]) dilation_w = np.array([[1, 1], [1, 1]])
th_shrink_maps = tch_preds[:, 0, :, :] th_shrink_maps = tch_preds[:, 0, :, :]
if hasattr(paddle.Tensor, "contiguous"):
th_shrink_maps = th_shrink_maps.contiguous()
th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3 th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32) dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
for i in range(th_shrink_maps.shape[0]): for i in range(th_shrink_maps.shape[0]):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册