Commit b9d60c56 authored by mindspore-ci-bot, committed by Gitee

!3746 add KLDivLoss Python primitive

Merge pull request !3746 from baihuawei/loss
......@@ -878,6 +878,17 @@ def get_bprop_binary_cross_entropy(self):
    return bprop


@bprop_getters.register(P.KLDivLoss)
def get_bprop_kl_div_loss(self):
    """Grad definition for `KLDivLoss` operation."""
    grad = G.KLDivLossGrad(self.reduction)

    def bprop(x, y, out, dout):
        dx, dy = grad(x, y, dout)
        return dx, dy

    return bprop


@bprop_getters.register(P.Dropout)
def get_bprop_dropout(self):
......
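For reference, the backward pass registered above can be checked against a plain NumPy sketch derived by hand from the loss definition :math:`l_n = y_n \cdot (\log y_n - x_n)`: the gradient with respect to `x` is `-y` (scaled by `1/N` under 'mean' reduction), and the gradient with respect to `y` is `log(y) + 1 - x`. The helper name `kl_div_loss_grad_reference`, the `1/N` scaling of both outputs, and the zero-target handling are assumptions of this sketch, not behavior copied from the `KLDivLossGrad` kernel.

```python
import numpy as np

def kl_div_loss_grad_reference(x, y, dout, reduction='mean'):
    """Hand-derived gradients for l_n = y_n * (log(y_n) - x_n). Illustrative only."""
    scale = 1.0 / x.size if reduction == 'mean' else 1.0
    dx = -y * dout * scale                                        # dl_n/dx_n = -y_n
    # dl_n/dy_n = log(y_n) + 1 - x_n; zero targets contribute zero here (assumption).
    dy = np.where(y > 0, np.log(np.where(y > 0, y, 1.0)) + 1.0 - x, 0.0) * dout * scale
    return dx, dy

x = np.array([0.2, 0.7, 0.1], dtype=np.float32)
y = np.array([0.0, 1.0, 0.0], dtype=np.float32)
dx, dy = kl_div_loss_grad_reference(x, y, dout=1.0)               # dx -> [0., -0.3333, 0.]
```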
......@@ -73,7 +73,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
                     SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss,
                     SoftmaxCrossEntropyWithLogits, ROIAlign,
                     SparseSoftmaxCrossEntropyWithLogits, Tanh,
                     TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
                     TopK, BinaryCrossEntropy, KLDivLoss, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
                     ApplyProximalAdagrad, SparseApplyProximalAdagrad, SparseApplyAdagradV2, SparseApplyFtrlV2,
                     FusedSparseFtrl, FusedSparseProximalAdagrad,
                     ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
......@@ -307,6 +307,7 @@ __all__ = [
"LSTM",
"Abs",
"BinaryCrossEntropy",
"KLDivLoss",
"SparseApplyAdagrad",
"SparseApplyAdagradV2",
"SpaceToDepth",
......
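With the operator imported and listed in `__all__` above, it is reachable from the public `mindspore.ops.operations` namespace. A minimal smoke-test sketch, assuming a MindSpore build that includes this merge; the local names `kldiv`, `logits`, and `labels` are just for illustration:

```python
import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.ops import operations as P

# Call the primitive eagerly; for graph mode, wrap it in an nn.Cell instead.
context.set_context(mode=context.PYNATIVE_MODE)

kldiv = P.KLDivLoss(reduction='mean')
logits = Tensor(np.array([0.2, 0.7, 0.1]), mindspore.float32)
labels = Tensor(np.array([0.0, 1.0, 0.0]), mindspore.float32)
loss = kldiv(logits, labels)   # scalar Tensor, because reduction='mean'
```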
......@@ -144,6 +144,23 @@ class BiasAddGrad(Primitive):
        raise NotImplementedError


class KLDivLossGrad(PrimitiveWithInfer):
    """Computes gradients for `KLDivLoss` operation."""

    @prim_attr_register
    def __init__(self, reduction='mean'):
        self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)

    def infer_shape(self, x_shape, y_shape, doutput_shape):
        validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
        return x_shape, y_shape

    def infer_dtype(self, x_type, y_type, doutput_type):
        args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        return x_type, y_type


class BinaryCrossEntropyGrad(PrimitiveWithInfer):
    """Computes gradients for `BinaryCrossEntropy` operation."""
......@@ -405,6 +422,7 @@ class FusedBatchNormGrad(Primitive):
    def __call__(self, dy, x, scale, save_mean, save_inv_variance):
        raise NotImplementedError


class BNTrainingReduceGrad(PrimitiveWithInfer):
    """Gradients of FusedBatchNorm operation."""
......@@ -419,6 +437,7 @@ class BNTrainingReduceGrad(PrimitiveWithInfer):
    def infer_dtype(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance):
        return grads


class BNTrainingUpdateGrad(PrimitiveWithInfer):
    """Gradients of FusedBatchNorm operation."""
......@@ -433,6 +452,7 @@ class BNTrainingUpdateGrad(PrimitiveWithInfer):
    def infer_dtype(self, grads, x, batch_mean, batch_variance):
        return (batch_mean, batch_variance)


class GeluGrad(PrimitiveWithInfer):
    """Gradients of Gelu operation."""
......@@ -1336,6 +1356,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
    This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
    this primitive is implemented by StridedSlice --> _HostAllGather --> Concat. This primitive runs on host.
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
......@@ -1536,6 +1557,7 @@ class InvGrad(PrimitiveWithInfer):
class LRNGrad(PrimitiveWithInfer):
    """Computes gradients for LRN operation."""

    @prim_attr_register
    def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5):
        self.init_prim_io_names(inputs=['grads', 'x', 'y'], outputs=['z'])
......
......@@ -3367,6 +3367,76 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
        validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
        return var_dtype, accum_dtype
class KLDivLoss(PrimitiveWithInfer):
    r"""
    Computes the Kullback-Leibler divergence between the target and the output.

    Note:
        Set the input as :math:`x`, the input label as :math:`y` and the output as :math:`\ell(x, y)`.
        Let

        .. math::
            L = \{l_1,\dots,l_N\}^\top, \quad
            l_n = y_n \cdot (\log y_n - x_n)

        Then,

        .. math::
            \ell(x, y) = \begin{cases}
            L, & \text{if reduction} = \text{'none';}\\
            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
            \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
            \end{cases}

    Args:
        reduction (str): Specifies the reduction to apply to the output.
            Its value should be one of 'none', 'mean' or 'sum'. Default: 'mean'.

    Inputs:
        - **input_x** (Tensor) - The input Tensor. The data type must be float32.
        - **input_y** (Tensor) - The label Tensor, with the same shape as `input_x`. The data type must be float32.

    Outputs:
        Tensor or Scalar. If `reduction` is 'none', the output is a tensor with the same shape as `input_x`.
        Otherwise, it is a scalar.

    Examples:
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops import operations as P
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.kldiv_loss = P.KLDivLoss()
        >>>     def construct(self, x, y):
        >>>         result = self.kldiv_loss(x, y)
        >>>         return result
        >>>
        >>> net = Net()
        >>> input_x = Tensor(np.array([0.2, 0.7, 0.1]), mindspore.float32)
        >>> input_y = Tensor(np.array([0., 1., 0.]), mindspore.float32)
        >>> result = net(input_x, input_y)
    """
    @prim_attr_register
    def __init__(self, reduction='mean'):
        self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)

    def infer_shape(self, x_shape, y_shape):
        validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
        if self.reduction in ('mean', 'sum'):
            shape = []
        else:
            shape = x_shape
        return shape

    def infer_dtype(self, x_type, y_type):
        args = {'x': x_type, 'y': y_type}
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same(args, valid_types, self.name)
        return x_type


class BinaryCrossEntropy(PrimitiveWithInfer):
    r"""
......
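A quick sanity check of the docstring example, assuming the pointwise definition above and the usual convention that zero labels contribute zero to the loss. The numbers below are computed with NumPy for illustration, not output captured from the operator itself.

```python
import numpy as np

x = np.array([0.2, 0.7, 0.1], dtype=np.float32)   # input_x from the docstring example
y = np.array([0.0, 1.0, 0.0], dtype=np.float32)   # input_y

# Pointwise loss l_n = y_n * (log(y_n) - x_n); zero labels contribute 0 by convention.
l = np.where(y > 0, y * (np.log(np.where(y > 0, y, 1.0)) - x), 0.0)

print(l)           # [ 0.  -0.7  0. ]  -> what reduction='none' should return
print(l.mean())    # -0.23333333       -> what reduction='mean' (the default) should return
print(l.sum())     # -0.7              -> what reduction='sum' should return
```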