diff --git a/utils/utils.py b/utils/utils.py index 495bb1b3428b5af4bb1d654e4317d00442203f38..f85185a3cd3d44c76b793f0f4af792685f075f69 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -60,9 +60,9 @@ class DecodeBox(nn.Module): LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor # 生成网格,先验框中心,网格左上角 batch_size,3,13,13 - grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_width, 1).repeat( + grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_height, 1).repeat( batch_size * self.num_anchors, 1, 1).view(x.shape).type(FloatTensor) - grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_height, 1).t().repeat( + grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_width, 1).t().repeat( batch_size * self.num_anchors, 1, 1).view(y.shape).type(FloatTensor) # 生成先验框的宽高