“629cbdae01feb06fe10e54a7b5fadbc79b7a5629”上不存在“paddle/operators/batch_norm_op.cu”
entry_attr.py 2.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

# Names exported by a star-import of this module: the two concrete entry
# attribute classes. The EntryAttr base class is not listed.
__all__ = [
    "ProbabilityEntry",
    "CountFilterEntry",
]


class EntryAttr(object):
    """
    Base class for entry attributes.

    An entry attribute is a small configuration object that serializes
    itself to a string through ``_to_attr``. This base class only
    initializes ``_name`` to ``None`` and leaves ``_to_attr``
    unimplemented; concrete subclasses set ``_name`` and override
    ``_to_attr``.

    NOTE(review): the consumer of the serialized string is not visible in
    this module (presumably a sparse-table / embedding configuration);
    confirm against callers.

    Examples:
        .. code-block:: python

            # ProbabilityEntry is a concrete subclass defined in this module.
            entry = ProbabilityEntry(0.1)
            entry._to_attr()  # "probability_entry:0.1"
    """

    def __init__(self):
        # Identifier of the entry kind; subclasses overwrite it and use it
        # as the prefix of the serialized string.
        self._name = None

    def _to_attr(self):
        """
        Serialize this entry attribute.

        Subclasses must override this method.

        Returns:
            str: The serialized entry attribute.

        Raises:
            NotImplementedError: Always, when called on the base class.
        """
        raise NotImplementedError("EntryAttr is base class")


class ProbabilityEntry(EntryAttr):
    """
    Entry attribute carrying a probability strictly between 0 and 1.

    Serialized by ``_to_attr`` as ``"probability_entry:<probability>"``.

    Args:
        probability(float): A Python ``float`` with ``0 < probability < 1``.

    Raises:
        ValueError: If ``probability`` is not a ``float``, lies outside the
            open interval (0, 1), or is NaN.
    """

    def __init__(self, probability):
        # Name this class (not the base) so that EntryAttr.__init__ actually
        # runs; super(EntryAttr, self) skipped it and went straight to object.
        super(ProbabilityEntry, self).__init__()

        if not isinstance(probability, float):
            raise ValueError("probability must be a float in (0,1)")

        # A chained comparison evaluates to False for NaN, so NaN is rejected
        # here; the previous `<= 0 or >= 1` test let NaN through.
        if not 0 < probability < 1:
            raise ValueError("probability must be a float in (0,1)")

        self._name = "probability_entry"
        self._probability = probability

    def _to_attr(self):
        """
        Serialize this entry attribute.

        Returns:
            str: ``"probability_entry:<probability>"``.
        """
        return ":".join([self._name, str(self._probability)])


class CountFilterEntry(EntryAttr):
    """
    Entry attribute carrying a non-negative integer count threshold.

    Serialized by ``_to_attr`` as ``"count_filter_entry:<count_filter>"``.

    Args:
        count_filter(int): A non-negative Python ``int``. ``bool`` values
            are rejected even though ``bool`` subclasses ``int``.

    Raises:
        ValueError: If ``count_filter`` is not an ``int`` (or is a ``bool``),
            or is negative.
    """

    def __init__(self, count_filter):
        # Name this class (not the base) so that EntryAttr.__init__ actually
        # runs; super(EntryAttr, self) skipped it and went straight to object.
        super(CountFilterEntry, self).__init__()

        # bool is a subclass of int; without the explicit exclusion, True
        # would be accepted and serialized as "count_filter_entry:True".
        if not isinstance(count_filter, int) or isinstance(count_filter, bool):
            # Same message as the range check below: 0 is a valid value, so
            # the old "greater than 0" wording was wrong.
            raise ValueError(
                "count_filter must be a valid integer greater or equal than 0")

        if count_filter < 0:
            raise ValueError(
                "count_filter must be a valid integer greater or equal than 0")

        self._name = "count_filter_entry"
        self._count_filter = count_filter

    def _to_attr(self):
        """
        Serialize this entry attribute.

        Returns:
            str: ``"count_filter_entry:<count_filter>"``.
        """
        return ":".join([self._name, str(self._count_filter)])