unique_name.py 7.5 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
S
rename  
sneaxiy 已提交
16
from .wrapped_decorator import signature_safe_contextmanager
Y
Yu Yang 已提交
17

Y
yuyang18 已提交
18
__all__ = ['generate', 'switch', 'guard']
Y
Yu Yang 已提交
19 20


21
class UniqueNameGenerator:
    """
    Generate unique names of the form ``prefix + key + "_" + index``.

    Every distinct key owns an independent counter that starts at 0, so
    repeated calls with the same key produce ``key_0``, ``key_1``, ...

    Args:
        prefix(str, optional): Text prepended to every generated name.
            Default is None, which is treated as the empty string.
    """

    def __init__(self, prefix=None):
        # key -> how many names have already been generated for that key.
        self.ids = collections.defaultdict(int)
        self.prefix = "" if prefix is None else prefix

    def __call__(self, key):
        """
        Return the next unique name for ``key``.

        Args:
            key(str): The key the name is built from.

        Returns:
            str: ``self.prefix`` followed by ``key``, an underscore, and the
            number of names previously generated for ``key``.
        """
        index = self.ids[key]
        self.ids[key] = index + 1
        return self.prefix + "_".join((key, str(index)))


50
class DygraphParameterNameChecker:
    """
    Track parameter names and report whether a name has been seen before.
    """

    def __init__(self):
        # Every name that has been passed to __call__ so far.
        self._name_set = set()

    def __call__(self, name):
        """
        Report whether ``name`` was already recorded, recording it if not.

        Args:
            name(str): The parameter name to check.

        Returns:
            bool: True if ``name`` was recorded by an earlier call; False
            otherwise, in which case ``name`` is recorded now.
        """
        seen_before = name in self._name_set
        if not seen_before:
            self._name_set.add(name)
        return seen_before


# Name generator used by generate() (and by generate_with_ignorable_key()
# when _non_static_mode() is False); rebound by switch().
generator = UniqueNameGenerator()

# Parameter-name checker of the current namespace; rebound by switch().
dygraph_parameter_name_checker = DygraphParameterNameChecker()


Y
Yu Yang 已提交
80
def generate(key):
    """
    Generate a unique name from ``key`` in the current namespace.

    Names sharing the same key are distinguished by a counter starting at
    zero. For example, key='fc' yields fc_0, fc_1, fc_2, etc. If the current
    namespace was created with a prefix (see :code:`guard()`), that prefix
    is prepended to the result.

    Args:
        key(str): The key the generated name is built from.

    Returns:
        str: A unique string built from ``key``.

    Examples:

        .. code-block:: python

            import paddle
            name1 = paddle.utils.unique_name.generate('fc')
            name2 = paddle.utils.unique_name.generate('fc')
            print(name1, name2) # fc_0, fc_1
    """
    # Delegate to whichever generator the current namespace installed.
    unique_name = generator(key)
    return unique_name
Y
Yu Yang 已提交
102 103


104 105
# FIXME(zjl): The previous naming rule in static graph mode would
# cause a memory leak in dygraph mode. The previous naming rule
# used keys such as `conv_0.tmp`, and in dygraph mode `conv_i`
# increases as the number of batches increases. Thus, the set of keys
# would keep growing: `conv_0.tmp`, `conv_1.tmp`, ....
# No better way has been found to fix this bug in dygraph mode. In TF,
# variable names are meaningless in eager execution mode, and in
# PyTorch there are no variable names at all. Maybe we should
# discard variable names in dygraph mode.
#
# Another concern is the save/load interfaces. Usually, users
# save a model in static graph mode and load it in dygraph
# mode. Therefore, we keep the variable names of Parameters for now.
#
# Please fix me if a better method is found.
#
# NOTE(zhiqiu): use the C++ unique_name_generator in dygraph mode,
# in order to keep names consistent.
def generate_with_ignorable_key(key):
    """
    Generate a unique name, ignoring ``key`` in non-static (dygraph) mode.

    When ``_non_static_mode()`` returns True, the name is produced by the
    dygraph tracer's ``_generate_unique_name()`` and ``key`` is unused.
    Otherwise this behaves exactly like :code:`generate()`.

    Args:
        key(str): The key used to build the name in static mode.

    Returns:
        str: A unique name.
    """
    # Imported lazily, presumably to avoid a circular import with the
    # framework module.
    from .framework import _dygraph_tracer, _non_static_mode

    if not _non_static_mode():
        return generator(key)
    return _dygraph_tracer()._generate_unique_name()


H
hong 已提交
131
def switch(new_generator=None, new_para_name_checker=None):
    """
    Switch the unique-name namespace of the current context to a new one.
    Though :code:`switch()` and :code:`guard()` can both change the
    namespace, :code:`guard()` is recommended because it manages the
    context better together with the :code:`with` statement.

    Args:
        new_generator(UniqueNameGenerator, optional): The generator to
            install. Not required normally. Default is None, which installs
            a fresh UniqueNameGenerator (a new anonymous namespace).
        new_para_name_checker(DygraphParameterNameChecker, optional): The
            parameter name checker to install. Not required normally.
            Default is None, which installs a fresh
            DygraphParameterNameChecker.

    Returns:
        tuple: ``(previous UniqueNameGenerator, previous
        DygraphParameterNameChecker)``. Pass both back to :code:`switch()`
        to restore the earlier namespace.

    Examples:

        .. code-block:: python

            import paddle
            name1 = paddle.utils.unique_name.generate('fc')
            name2 = paddle.utils.unique_name.generate('fc')
            print(name1, name2) # fc_0, fc_1

            pre_generator, pre_dygraph_name_checker = paddle.utils.unique_name.switch() # switch to a new anonymous namespace.
            name2 = paddle.utils.unique_name.generate('fc')
            print(name2) # fc_0

            paddle.utils.unique_name.switch(pre_generator, pre_dygraph_name_checker) # switch back to pre_generator.
            name3 = paddle.utils.unique_name.generate('fc')
            print(name3) # fc_2, since pre_generator has generated fc_0, fc_1.
    """
    global generator, dygraph_parameter_name_checker

    # Capture the current namespace before rebinding anything.
    previous = (generator, dygraph_parameter_name_checker)

    generator = (
        new_generator if new_generator is not None else UniqueNameGenerator()
    )
    dygraph_parameter_name_checker = (
        new_para_name_checker
        if new_para_name_checker is not None
        else DygraphParameterNameChecker()
    )
    return previous
Y
Yu Yang 已提交
181 182


S
rename  
sneaxiy 已提交
183
@signature_safe_contextmanager
def guard(new_generator=None):
    """
    Change the namespace of unique names with a :code:`with` statement.
    Inside the :code:`with` block a new namespace is active, so
    :code:`generate()` numbers names from zero again for every key. A fresh
    DygraphParameterNameChecker is also installed for the block. The
    previous generator and checker are restored in a ``finally`` clause
    when the block exits.

    Args:
        new_generator(str|bytes|UniqueNameGenerator, optional): The new
            namespace. A str value, or a bytes value (decoded to str first),
            becomes the prefix of every name produced by :code:`generate()`
            inside the block. Any other value, such as a UniqueNameGenerator
            instance, is passed to :code:`switch()` unchanged. Default is
            None, which uses a fresh generator with no prefix.

    Returns:
        None.

    Examples:

        .. code-block:: python

            import paddle
            with paddle.utils.unique_name.guard():
                name_1 = paddle.utils.unique_name.generate('fc')
            with paddle.utils.unique_name.guard():
                name_2 = paddle.utils.unique_name.generate('fc')
            print(name_1, name_2) # fc_0, fc_0

            with paddle.utils.unique_name.guard('A'):
                name_1 = paddle.utils.unique_name.generate('fc')
            with paddle.utils.unique_name.guard('B'):
                name_2 = paddle.utils.unique_name.generate('fc')
            print(name_1, name_2) # Afc_0, Bfc_0
    """
    # Normalize bytes to str, then wrap a str prefix in a generator.
    if isinstance(new_generator, bytes):
        new_generator = new_generator.decode()
    if isinstance(new_generator, str):
        new_generator = UniqueNameGenerator(new_generator)

    saved_state = switch(new_generator)
    try:
        yield
    finally:
        # Reinstall the namespace that was active before entering.
        switch(*saved_state)