# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function from . import control_flow from . import tensor from . import ops from . import nn import math import numpy as np import warnings from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype __all__ = ['Uniform', 'Normal', 'Categorical', 'MultivariateNormalDiag'] class Distribution(object): """ Distribution is the abstract base class for probability distributions. """ def sample(self): """Sampling from the distribution.""" raise NotImplementedError def entropy(self): """The entropy of the distribution.""" raise NotImplementedError def kl_divergence(self, other): """The KL-divergence between self distributions and other.""" raise NotImplementedError def log_prob(self, value): """Log probability density/mass function.""" raise NotImplementedError def _validate_args(self, *args): """ Argument validation for distribution args Args: value (float, list, numpy.ndarray, Variable) Raises ValueError: if one argument is Variable, all arguments should be Variable """ is_variable = False is_number = False for arg in args: if isinstance(arg, tensor.Variable): is_variable = True else: is_number = True if is_variable and is_number: raise ValueError( 'if one argument is Variable, all arguments should be Variable') return is_variable def _to_variable(self, *args): """ Argument convert args to Variable Args: value (float, list, numpy.ndarray, Variable) Returns: Variable of args. """ numpy_args = [] variable_args = [] tmp = 0. for arg in args: valid_arg = False for cls in [float, list, np.ndarray, tensor.Variable]: if isinstance(arg, cls): valid_arg = True break assert valid_arg, "type of input args must be float, list, numpy.ndarray or Variable." if isinstance(arg, float): arg = np.zeros(1) + arg arg_np = np.array(arg) arg_dtype = arg_np.dtype if str(arg_dtype) not in ['float32']: warnings.warn( "data type of argument only support float32, your argument will be convert to float32." ) arg_np = arg_np.astype('float32') tmp = tmp + arg_np numpy_args.append(arg_np) dtype = tmp.dtype for arg in numpy_args: arg_broadcasted, _ = np.broadcast_arrays(arg, tmp) arg_variable = tensor.create_tensor(dtype=dtype) tensor.assign(arg_broadcasted, arg_variable) variable_args.append(arg_variable) return tuple(variable_args) class Uniform(Distribution): r"""Uniform distribution with `low` and `high` parameters. Mathematical Details The probability density function (pdf) is, .. math:: pdf(x; a, b) = \\frac{1}{Z}, \ a <=x