提交 d22e3e5e 编写于 作者: C chenguowei01

use softmax_with_cross_entropy

上级 bfec9c85
......@@ -15,11 +15,11 @@
import paddle.fluid as fluid
def constant_init(param, value=0.0):
initializer = fluid.initializer.Constant(value)
def constant_init(param, **kwargs):
initializer = fluid.initializer.Constant(**kwargs)
initializer(param, param.block)
def normal_init(param, loc=0.0, scale=1.0, seed=0):
initializer = fluid.initializer.Normal(loc=loc, scale=scale, seed=seed)
def normal_init(param, **kwargs):
initializer = fluid.initializer.Normal(**kwargs)
initializer(param, param.block)
......@@ -146,7 +146,8 @@ class HRNet(fluid.dygraph.Layer):
has_se=self.has_se,
name="st4")
self.init_weight(backbone_pretrained)
if self.training:
self.init_weight(backbone_pretrained)
def forward(self, x, label=None, mode='train'):
input_shape = x.shape[2:]
......
......@@ -86,7 +86,8 @@ class FCN(fluid.dygraph.Layer):
filter_size=1,
stride=1,
padding=0)
self.init_weight(model_pretrained)
if self.training:
self.init_weight(model_pretrained)
def forward(self, x):
input_shape = x.shape[2:]
......@@ -132,36 +133,6 @@ class FCN(fluid.dygraph.Layer):
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
# def _get_loss(self, logit, label):
# """
# compute forward loss of the model
# Args:
# logit (tensor): the logit of model output
# label (tensor): ground truth
# Returns:
# avg_loss (tensor): forward loss
# """
# logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
# label = fluid.layers.transpose(label, [0, 2, 3, 1])
# mask = label != self.ignore_index
# mask = fluid.layers.cast(mask, 'float32')
# loss, probs = fluid.layers.softmax_with_cross_entropy(
# logit,
# label,
# ignore_index=self.ignore_index,
# return_softmax=True,
# axis=-1)
# loss = loss * mask
# avg_loss = fluid.layers.mean(loss) / (
# fluid.layers.mean(mask) + self.EPS)
# label.stop_gradient = True
# mask.stop_gradient = True
# return avg_loss
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
......
......@@ -17,8 +17,7 @@ from paddle import nn
import paddle.nn.functional as F
from dygraph.cvlibs import manager
'''
@manager.LOSSES.add_component
class CrossEntropyLoss(nn.CrossEntropyLoss):
"""
......@@ -40,8 +39,9 @@ class CrossEntropyLoss(nn.CrossEntropyLoss):
"""
def __init__(self, weight=None, ignore_index=255, reduction='mean'):
super(CrossEntropyLoss, self).__init__(
weight=weight, ignore_index=ignore_index, reduction=reduction)
self.weight = weight
self.ignore_index = ignore_index
self.reduction = reduction
self.EPS = 1e-5
if self.reduction not in ['sum', 'mean', 'none']:
raise ValueError(
......@@ -71,6 +71,49 @@ class CrossEntropyLoss(nn.CrossEntropyLoss):
mask = paddle.cast(mask, 'float32')
avg_loss = loss / (paddle.mean(mask) + self.EPS)
label.stop_gradient = True
mask.stop_gradient = True
return avg_loss
'''
@manager.LOSSES.add_component
class CrossEntropyLoss(nn.Layer):
"""
Implements the cross entropy loss function.
Args:
ignore_index (int64): Specifies a target value that is ignored
and does not contribute to the input gradient. Default ``255``.
"""
def __init__(self, ignore_index=255):
super(CrossEntropyLoss, self).__init__()
self.ignore_index = ignore_index
self.EPS = 1e-5
def forward(self, logit, label):
"""
Forward computation.
Args:
logit (Tensor): logit tensor, the data type is float32, float64. Shape is
(N, C), where C is number of classes, and if shape is more than 2D, this
is (N, C, D1, D2,..., Dk), k >= 1.
label (Variable): label tensor, the data type is int64. Shape is (N), where each
value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
(N, D1, D2,..., Dk), k >= 1.
"""
if len(label.shape) != len(logit.shape):
label = paddle.unsqueeze(label, 1)
loss = F.softmax_with_cross_entropy(
logit, label, ignore_index=self.ignore_index, axis=1)
loss = paddle.reduce_mean(loss)
mask = label != self.ignore_index
mask = paddle.cast(mask, 'float32')
avg_loss = loss / (paddle.mean(mask) + self.EPS)
label.stop_gradient = True
mask.stop_gradient = True
return avg_loss
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册