未验证 提交 10669d37 编写于 作者: W wuyefeilin 提交者: GitHub

fix channel wrong of dice loss

fix channel wrong of dice loss 
......@@ -89,6 +89,8 @@ class HRNet(object):
self.stage4_num_channels = stage4_num_channels
def build_net(self, inputs):
if self.use_dice_loss or self.use_bce_loss:
self.num_classes = 1
image = inputs['image']
logit = self._high_resolution_net(image, self.num_classes)
if self.num_classes == 1:
......
......@@ -81,7 +81,7 @@ class ShuffleSeg(object):
dtype='int32', shape=[None, 1, None, None], name='label')
return inputs
def build_net(self, inputs, class_dim=2):
def build_net(self, inputs):
if self.use_dice_loss or self.use_bce_loss:
self.num_classes = 1
image = inputs['image']
......@@ -124,7 +124,7 @@ class ShuffleSeg(object):
conv_b = fluid.layers.resize_bilinear(conv, shortcut_shape)
concat = fluid.layers.concat([shortcut, conv_b], axis=1)
decode_conv = self.depthwise_separable(concat, 3, 64, 1)
logit = self.output_layer(decode_conv, class_dim)
logit = self.output_layer(decode_conv, self.num_classes)
if self.num_classes == 1:
out = sigmoid_to_softmax(logit)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册