提交 d652a1bd 编写于 作者: M michaelowenliu

fix a type in aspp

上级 d2b8644c
......@@ -79,13 +79,15 @@ class ASPPModule(nn.Layer):
outputs = []
for block in self.aspp_blocks:
outputs.append(block(x))
y = block(x)
y = F.resize_bilinear(y, out_shape=x.shape[2:])
outputs.append(y)
if self.image_pooling:
img_avg = self.global_avg_pool(x)
img_avg = F.resize_bilinear(img_avg, out_shape=x.shape[2:])
outputs.append(img_avg)
x = paddle.concat(outputs, axis=1)
x = self.conv_bn_relu(x)
x = self.dropout(x)
......
......@@ -197,7 +197,7 @@ class DeepLabV3Head(nn.Layer):
image_pooling=True)
self.cls = nn.Conv2d(
in_channels=backbone_channels[0],
in_channels=aspp_out_channels,
out_channels=num_classes,
kernel_size=1)
......@@ -209,6 +209,7 @@ class DeepLabV3Head(nn.Layer):
logit_list = []
x = feat_list[self.backbone_indices[0]]
x = self.aspp(x)
logit = self.cls(x)
logit_list.append(logit)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册