kl.py 7.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import warnings

import paddle
18 19 20 21 22
from paddle.distribution.beta import Beta
from paddle.distribution.categorical import Categorical
from paddle.distribution.dirichlet import Dirichlet
from paddle.distribution.distribution import Distribution
from paddle.distribution.exponential_family import ExponentialFamily
23
from paddle.distribution.laplace import Laplace
24
from paddle.distribution.lognormal import LogNormal
25
from paddle.distribution.normal import Normal
26
from paddle.distribution.uniform import Uniform
27
from paddle.fluid.framework import _non_static_mode
28 29 30 31 32 33 34 35 36 37 38 39

__all__ = ["register_kl", "kl_divergence"]

# Registry of KL implementations, keyed by the (cls_p, cls_q) pair passed to
# ``register_kl``; ``_dispatch`` searches it to serve ``kl_divergence``.
_REGISTER_TABLE = dict()


def kl_divergence(p, q):
    r"""
    Kullback-Leibler divergence between distribution p and q.

    .. math::

        KL(p||q) = \int p(x)log\frac{p(x)}{q(x)} \mathrm{d}x

    The concrete implementation is looked up among the functions registered
    with ``register_kl``: the registered ``(cls_p, cls_q)`` pair whose classes
    are the most specific superclasses of ``type(p)`` and ``type(q)`` wins.

    Args:
        p (Distribution): ``Distribution`` object. Inherits from the Distribution Base class.
        q (Distribution): ``Distribution`` object. Inherits from the Distribution Base class.

    Returns:
        Tensor, Batchwise KL-divergence between distribution p and q.

    Raises:
        NotImplementedError: If no implementation is registered for any pair of
            classes that ``type(p)`` and ``type(q)`` derive from.

    Examples:

        .. code-block:: python

            import paddle

            p = paddle.distribution.Beta(alpha=0.5, beta=0.5)
            q = paddle.distribution.Beta(alpha=0.3, beta=0.7)

            print(paddle.distribution.kl_divergence(p, q))
            # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #        [0.21193528])

    """
    # Dispatch is on the runtime classes; the selected function gets the
    # original instances.
    return _dispatch(type(p), type(q))(p, q)


def register_kl(cls_p, cls_q):
    """Decorator for registering a KL divergence implementation function.

    The ``kl_divergence(p, q)`` function will search concrete implementation
    functions registered by ``register_kl``, according to multi-dispatch pattern.
    If an implementation function is found, it will return the result, otherwise,
    it will raise ``NotImplementedError`` exception. Users can register an
    implementation function with this decorator.

    Args:
        cls_p (Distribution): The Distribution type of Instance p. Subclass derived from ``Distribution``.
        cls_q (Distribution): The Distribution type of Instance q. Subclass derived from ``Distribution``.

    Returns:
        Callable: A decorator that stores the decorated function in the
        registry under ``(cls_p, cls_q)`` and returns the function unchanged.

    Raises:
        TypeError: If ``cls_p`` or ``cls_q`` is not a subclass of ``Distribution``
            (including when it is not a class at all).

    Examples:
        .. code-block:: python

            import paddle

            @paddle.distribution.register_kl(paddle.distribution.Beta, paddle.distribution.Beta)
            def kl_beta_beta(p, q):
                pass # insert implementation here
    """

    def _is_distribution_class(cls):
        # issubclass() itself raises a generic TypeError for non-class inputs;
        # checking isinstance first keeps the error message below consistent.
        return isinstance(cls, type) and issubclass(cls, Distribution)

    if not (_is_distribution_class(cls_p) and _is_distribution_class(cls_q)):
        raise TypeError('cls_p and cls_q must be subclass of Distribution')

    def decorator(f):
        # A later registration for the same pair replaces the earlier one.
        _REGISTER_TABLE[cls_p, cls_q] = f
        return f

    return decorator


def _dispatch(cls_p, cls_q):
    """Multiple dispatch into the concrete KL implementation function.

    Collects every registered pair ``(super_p, super_q)`` such that ``cls_p``
    is a subclass of ``super_p`` and ``cls_q`` is a subclass of ``super_q``.
    It then picks the most specific pair twice (ordering via ``_Compare``):
    once ranking p's class first, once ranking q's class first. If the two
    rankings select different functions, a ``RuntimeWarning`` is emitted and
    the p-first choice is used.

    Args:
        cls_p (type): Runtime class of the first distribution.
        cls_q (type): Runtime class of the second distribution.

    Returns:
        Callable: The registered implementation ``f(p, q)``.

    Raises:
        NotImplementedError: If no registered pair matches.
    """
    # find all matched super class pairs of p and q
    matches = [
        (super_p, super_q)
        for super_p, super_q in _REGISTER_TABLE
        if issubclass(cls_p, super_p) and issubclass(cls_q, super_q)
    ]
    if not matches:
        raise NotImplementedError(
            'No KL divergence registered for ({}, {}).'.format(
                cls_p.__name__, cls_q.__name__
            )
        )

    # Most specific pair when comparing p's class first.
    left_p, left_q = min(_Compare(*m) for m in matches).classes
    # Most specific pair when comparing q's class first. The reversed tuple is
    # (super_q, super_p), so it must be unpacked in that order; unpacking it
    # as (p, q) would look up a swapped, possibly unregistered key.
    right_q, right_p = min(_Compare(*reversed(m)) for m in matches).classes

    left_fn = _REGISTER_TABLE[left_p, left_q]
    if left_fn is not _REGISTER_TABLE[right_p, right_q]:
        warnings.warn(
            'Ambiguous kl_divergence({}, {}). Please register_kl({}, {})'.format(
                cls_p.__name__,
                cls_q.__name__,
                left_p.__name__,
                right_q.__name__,
            ),
            RuntimeWarning,
        )

    return left_fn


@functools.total_ordering
130
class _Compare:
131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
    def __init__(self, *classes):
        self.classes = classes

    def __eq__(self, other):
        return self.classes == other.classes

    def __le__(self, other):
        for cls_x, cls_y in zip(self.classes, other.classes):
            if not issubclass(cls_x, cls_y):
                return False
            if cls_x is not cls_y:
                break
        return True


@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
    """Closed-form KL(p || q) between two Beta distributions.

    Evaluates
        log B(q.alpha, q.beta) - log B(p.alpha, p.beta)
        + (p.alpha - q.alpha) * digamma(p.alpha)
        + (p.beta - q.beta) * digamma(p.beta)
        + (q.alpha + q.beta - p.alpha - p.beta) * digamma(p.alpha + p.beta)
    where log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b).
    """
    p_total = p.alpha + p.beta
    q_total = q.alpha + q.beta

    log_beta_ratio = (
        q.alpha.lgamma() + q.beta.lgamma() + p_total.lgamma()
    ) - (p.alpha.lgamma() + p.beta.lgamma() + q_total.lgamma())
    alpha_term = (p.alpha - q.alpha) * p.alpha.digamma()
    beta_term = (p.beta - q.beta) * p.beta.digamma()
    total_term = (q_total - p_total) * p_total.digamma()

    return log_beta_ratio + alpha_term + beta_term + total_term


@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
    """Closed-form KL(p || q) between two Dirichlet distributions.

    All reductions are taken over the last axis of ``concentration``
    (presumably the category axis -- confirm against Dirichlet).
    """
    alpha = p.concentration
    beta = q.concentration
    alpha_sum = alpha.sum(-1)

    log_norm_diff = alpha_sum.lgamma() - beta.sum(-1).lgamma()
    log_gamma_diff = (alpha.lgamma() - beta.lgamma()).sum(-1)
    # digamma(alpha_i) - digamma(sum(alpha)), broadcast back over the last axis.
    expected_log = alpha.digamma() - alpha_sum.digamma().unsqueeze(-1)
    cross_term = ((alpha - beta) * expected_log).sum(-1)

    return log_norm_diff - log_gamma_diff + cross_term


@register_kl(Categorical, Categorical)
def _kl_categorical_categorical(p, q):
    """KL(p || q) for two Categorical distributions, via ``p.kl_divergence``."""
    divergence = p.kl_divergence(q)
    return divergence


@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
    """KL(p || q) for two Normal distributions, via ``p.kl_divergence``."""
    divergence = p.kl_divergence(q)
    return divergence


@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
    """KL(p || q) for two Uniform distributions, via ``p.kl_divergence``."""
    divergence = p.kl_divergence(q)
    return divergence


@register_kl(Laplace, Laplace)
def _kl_laplace_laplace(p, q):
    """KL(p || q) for two Laplace distributions, via ``p.kl_divergence``."""
    divergence = p.kl_divergence(q)
    return divergence


@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
    """Compute kl-divergence using `Bregman divergences <https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf>`_

    For two members of the same exponential family with log-normalizer ``A``
    and natural parameters ``eta_p`` / ``eta_q``::

        KL(p || q) = A(eta_q) - A(eta_p) - <eta_q - eta_p, grad A(eta_p)>

    ``grad A(eta_p)`` is obtained by differentiating ``p._log_normalizer``
    with respect to detached copies of p's natural parameters.

    Args:
        p (ExponentialFamily): First distribution.
        q (ExponentialFamily): Second distribution; must be the same type as ``p``.

    Returns:
        Tensor: KL(p || q), with the trailing ``len(q.event_shape)`` axes summed out.

    Raises:
        NotImplementedError: If ``p`` and ``q`` are not of the same type.
        TypeError: If gradient computation raises ``RuntimeError``.
    """
    if type(p) is not type(q):
        raise NotImplementedError(
            'Bregman-divergence KL requires p and q of the same type, '
            'got {} and {}.'.format(type(p).__name__, type(q).__name__)
        )

    # Detached copies of p's natural parameters with gradient tracking
    # enabled, so grad A(eta_p) can be taken with respect to them.
    p_natural_params = []
    for param in p._natural_parameters:
        param = param.detach()
        param.stop_gradient = False
        p_natural_params.append(param)

    q_natural_params = q._natural_parameters

    p_log_norm = p._log_normalizer(*p_natural_params)

    try:
        if _non_static_mode():
            # create_graph=True keeps the gradients themselves differentiable.
            p_grads = paddle.grad(
                p_log_norm, p_natural_params, create_graph=True
            )
        else:
            p_grads = paddle.static.gradients(p_log_norm, p_natural_params)
    except RuntimeError as e:
        raise TypeError(
            "Can't compute kl_divergence({cls_p}, {cls_q}) using bregman divergence. Please register_kl({cls_p}, {cls_q}).".format(
                cls_p=type(p).__name__, cls_q=type(q).__name__
            )
        ) from e

    kl = q._log_normalizer(*q_natural_params) - p_log_norm
    for p_param, q_param, p_grad in zip(
        p_natural_params, q_natural_params, p_grads
    ):
        term = (q_param - p_param) * p_grad
        # Inner product: sum over the trailing len(q.event_shape) axes.
        kl -= _sum_rightmost(term, len(q.event_shape))

    return kl


@register_kl(LogNormal, LogNormal)
def _kl_lognormal_lognormal(p, q):
    """KL(p || q) for two LogNormal distributions.

    Delegates to the KL divergence between their ``_base`` distributions.
    """
    p_base = p._base
    q_base = q._base
    return p_base.kl_divergence(q_base)


242 243
def _sum_rightmost(value, n):
    return value.sum(list(range(-n, 0))) if n > 0 else value