提交 21d083e2 编写于 作者: C chenguowei01

fix dice channel wrong

上级 2ce4a558
...@@ -89,6 +89,8 @@ class HRNet(object): ...@@ -89,6 +89,8 @@ class HRNet(object):
self.stage4_num_channels = stage4_num_channels self.stage4_num_channels = stage4_num_channels
def build_net(self, inputs): def build_net(self, inputs):
if self.use_dice_loss or self.use_bce_loss:
self.num_classes = 1
image = inputs['image'] image = inputs['image']
logit = self._high_resolution_net(image, self.num_classes) logit = self._high_resolution_net(image, self.num_classes)
if self.num_classes == 1: if self.num_classes == 1:
......
...@@ -81,7 +81,7 @@ class ShuffleSeg(object): ...@@ -81,7 +81,7 @@ class ShuffleSeg(object):
dtype='int32', shape=[None, 1, None, None], name='label') dtype='int32', shape=[None, 1, None, None], name='label')
return inputs return inputs
def build_net(self, inputs, class_dim=2): def build_net(self, inputs):
if self.use_dice_loss or self.use_bce_loss: if self.use_dice_loss or self.use_bce_loss:
self.num_classes = 1 self.num_classes = 1
image = inputs['image'] image = inputs['image']
...@@ -124,7 +124,7 @@ class ShuffleSeg(object): ...@@ -124,7 +124,7 @@ class ShuffleSeg(object):
conv_b = fluid.layers.resize_bilinear(conv, shortcut_shape) conv_b = fluid.layers.resize_bilinear(conv, shortcut_shape)
concat = fluid.layers.concat([shortcut, conv_b], axis=1) concat = fluid.layers.concat([shortcut, conv_b], axis=1)
decode_conv = self.depthwise_separable(concat, 3, 64, 1) decode_conv = self.depthwise_separable(concat, 3, 64, 1)
logit = self.output_layer(decode_conv, class_dim) logit = self.output_layer(decode_conv, self.num_classes)
if self.num_classes == 1: if self.num_classes == 1:
out = sigmoid_to_softmax(logit) out = sigmoid_to_softmax(logit)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册