diff --git a/utils/utils.py b/utils/utils.py index d92e7056de21ab6d6b05f35f9aa5d33299e477e0..e67dcc3f89a6c47707a2e27e6c00a0fee1f59d90 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -49,10 +49,10 @@ class DecodeBox(nn.Module): FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor - # 生成网格,先验框中心,网格左上角 - grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_width, 1).repeat( + # 生成网格,先验框中心,网格左上角 batch_size,3,13,13 + grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_height, 1).repeat( batch_size * self.num_anchors, 1, 1).view(x.shape).type(FloatTensor) - grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_height, 1).t().repeat( + grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_width, 1).t().repeat( batch_size * self.num_anchors, 1, 1).view(y.shape).type(FloatTensor) # 生成先验框的宽高