#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import parl.layers as layers
from abc import ABCMeta, abstractmethod
from paddle.fluid.framework import Variable
from parl.layers import common_functions as comf
from paddle.fluid.framework import convert_np_dtype_to_dtype_


class PolicyDistribution(object):
    __metaclass__ = ABCMeta

    def __init__(self, dist):
        """
        self.dist represents the quantities that characterize the distribution.
        For example, for a Normal distribution, this can be a tuple of (mean, std).
        The actual form of self.dist is defined by the user.
        """
        self.dist = dist

    @abstractmethod
    def __call__(self):
        """
        Implement __call__ to sample an action from the distribution.
        """
        pass

    @property
    @abstractmethod
    def dim(self):
        """
        For discrete policies, this function returns the number of actions.
        For continuous policies, this function returns the action vector length.
        For sequential policies (e.g., sentences), this function returns the number
        of choices at each step.
        """
        pass

    def add_uniform_exploration(self, rate):
        """
        Given a uniform exploration rate, this function modifies the distribution.
        The rate can be a floating-point number or a Variable.
        """
        raise NotImplementedError()

    def loglikelihood(self, action):
        """
        Given an action, this function returns the log likelihood of this action under
        the current distribution.
        """
        raise NotImplementedError()


class CategoricalDistribution(PolicyDistribution):
    def __init__(self, dist):
        super(CategoricalDistribution, self).__init__(dist)
        assert isinstance(dist, Variable)
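        ## self.dist is expected to hold per-row action probabilities
        ## (e.g., a softmax output) with shape [batch_size, num_actions]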

    def __call__(self):
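        ## sampling_id draws one action index per row according to the
        ## probabilities in self.dist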
        return layers.sampling_id(self.dist)

    @property
    def dim(self):
        assert len(self.dist.shape) == 2
        return self.dist.shape[1]

    def add_uniform_exploration(self, rate):
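        ## Mix the current distribution with a uniform distribution:
        ##     new_dist = (1 - rate) * dist + rate * (1 / dim)
        ## e.g., with dist = [0.9, 0.1] and rate = 0.2, the result is
        ##     [0.9 * 0.8 + 0.1, 0.1 * 0.8 + 0.1] = [0.82, 0.18]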
        if not (isinstance(rate, float) and rate == 0):
            self.dist = self.dist * (1 - rate) + \
                1 / float(self.dim) * rate

    def loglikelihood(self, action):
        assert isinstance(action, Variable)
        assert action.dtype == convert_np_dtype_to_dtype_("int") \
            or action.dtype == convert_np_dtype_to_dtype_("int64")
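        ## For an integer label a, cross_entropy(input, label=a) computes
        ## -log(input[a]), so negating it yields the action's log likelihood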
        return 0 - layers.cross_entropy(input=self.dist, label=action)


class Deterministic(PolicyDistribution):
    def __init__(self, dist):
        super(Deterministic, self).__init__(dist)
        ## For deterministic policies, we only support continuous actions
        assert isinstance(dist, Variable)
        assert dist.dtype == convert_np_dtype_to_dtype_("float32") \
            or dist.dtype == convert_np_dtype_to_dtype_("float64")

    @property
    def dim(self):
        assert len(self.dist.shape) == 2
        return self.dist.shape[1]

    def __call__(self):
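        ## A deterministic policy returns self.dist itself as the action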
        return self.dist


def q_categorical_distribution(q_value):
    """
    Generate a PolicyDistribution object given a Q value.
    We construct a one-hot distribution according to the Q value.
    """
    assert len(q_value.shape) == 2, \
        "q_value should have shape [batch_size, num_actions]"
    max_id = comf.argmax_layer(q_value)
    prob = layers.cast(
        x=layers.one_hot(input=max_id, depth=q_value.shape[-1]),
        dtype="float32")
    return CategoricalDistribution(prob)
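
## Example usage (a minimal sketch): `q_value` and `action_label` are
## hypothetical Variables assumed to be built elsewhere in a fluid program,
## with q_value of shape [batch_size, num_actions] and action_label holding
## integer action indices:
##
##     dist = q_categorical_distribution(q_value)
##     action = dist()  # always the argmax action, since prob is one-hot
##     logli = dist.loglikelihood(action_label)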