unique_name.py 7.6 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

Y
Yu Yang 已提交
17
import collections
S
rename  
sneaxiy 已提交
18
from .wrapped_decorator import signature_safe_contextmanager
19
import six
Y
Yu Yang 已提交
20 21
import sys

Y
yuyang18 已提交
22
# Names exported by ``from ... import *`` for this module.
__all__ = ['generate', 'switch', 'guard']
Y
Yu Yang 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53


class UniqueNameGenerator(object):
    """
    Produce unique names by appending a per-key counter to a key.

    Each returned name is ``prefix + key + "_" + <counter>``. The counter
    starts at 0 separately for every distinct key and grows by one per call.

    Args:
        prefix(str, optional): Text placed in front of every generated name.
            ``None`` (the default) is treated as the empty string.
    """

    def __init__(self, prefix=None):
        # Maps each key to the counter value the next call will hand out.
        self.ids = collections.defaultdict(int)
        self.prefix = "" if prefix is None else prefix

    def __call__(self, key):
        """
        Return the next unique name for ``key``.

        Args:
            key(str): The base name to number.

        Returns(str): ``prefix + key + "_" + n``, where ``n`` is the number of
            times ``key`` was requested from this generator before this call.
        """
        count = self.ids[key]
        self.ids[key] = count + 1
        return self.prefix + "_".join((key, str(count)))


H
hong 已提交
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
class DygraphParameterNameChecker(object):
    """
    Remember parameter names and report whether a name has been seen before.
    """

    def __init__(self):
        # Every name this checker has been asked about so far.
        self._name_set = set()

    def __call__(self, name):
        '''
        Report whether ``name`` was already recorded, recording it if not.

        Args:
            name(str): The parameter name to look up.

        Returns(bool): True if ``name`` was recorded by an earlier call;
            False if it is new. A new name is recorded, so a later call with
            the same name returns True.
        '''
        seen = name in self._name_set
        if not seen:
            self._name_set.add(name)
        return seen


# Module-level default instances. switch() and guard() rebind these globals
# to change the active naming namespace.
dygraph_parameter_name_checker = DygraphParameterNameChecker()

generator = UniqueNameGenerator()


Y
Yu Yang 已提交
84
def generate(key):
    """
    Generate a unique name for ``key`` using the currently active name
    generator (see :code:`switch()` and :code:`guard()`). Currently, Paddle
    distinguishes the names of the same key by numbering it from zero. For
    example, when key=fc, it continuously generates fc_0, fc_1, fc_2, etc.
    If the active generator has a prefix (e.g. inside
    :code:`guard('A')`), the prefix is prepended, giving Afc_0, Afc_1, etc.

    Args:
        key(str): The base of the generated name.

    Returns:
        str: A unique string built from the key.

    Examples:

        .. code-block:: python

            import paddle
            name1 = paddle.utils.unique_name.generate('fc')
            name2 = paddle.utils.unique_name.generate('fc')
            print(name1, name2) # fc_0, fc_1
    """
    # ``generator`` is looked up at call time, so a namespace installed by
    # switch()/guard() takes effect immediately.
    return generator(key)
Y
Yu Yang 已提交
106 107


108 109
# FIXME(zjl): The previous naming rule in static graph would
# cause memory leak in dygraph mode. It is because the previous
# naming rule would use `conv_0.tmp` as the key, and in dygraph
# mode, `conv_i` increases as batch increases. Thus, keys would
# increase in a way like `conv_0.tmp`, `conv_1.tmp`, ....
# No better way to fix this bug in dygraph mode has been found yet. In TF,
# variable name is meaningless in eager execution mode, and in
# PyTorch, there is no variable name at all. Maybe we should
# discard variable name in dygraph mode.
#
# Another concern is the save/load interfaces. Usually, users
# save a model in static graph mode and load it in dygraph
# mode. Therefore, we keep the variable name of Parameter currently.
#
# Please fix me if a better method is found.
#
# NOTE(zhiqiu): use c++ unique_name_generator in dygraph mode,
# in order to keep name consistency.
def generate_with_ignorable_key(key):
    """
    Generate a unique name, ignoring ``key`` in dygraph (non-static) mode.

    Args:
        key(str): The key used to build the name in static graph mode.
            It is not used in dygraph mode.

    Returns:
        str: In dygraph mode, the name returned by the current dygraph
        tracer's ``_generate_unique_name()``. Otherwise, the name returned
        by the active module-level ``generator`` for ``key``.
    """
    # Imported inside the function, presumably to avoid a circular import
    # with .framework. TODO(review): confirm.
    from .framework import _non_static_mode, _dygraph_tracer
    if _non_static_mode():
        return _dygraph_tracer()._generate_unique_name()

    return generator(key)


H
hong 已提交
134
def switch(new_generator=None, new_para_name_checker=None):
    """
    Switch the namespace in the current context to a new namespace. Though
    :code:`switch()` and :code:`guard()` can both change namespace,
    :code:`guard()` is recommended since it can manage the context better
    together with :code:`with` statement.

    Args:
        new_generator(UniqueNameGenerator, optional): A new UniqueNameGenerator, not
            required normally. Default is None, which means switch to a new anonymous
            namespace.
        new_para_name_checker(DygraphParameterNameChecker, optional): A new
            DygraphParameterNameChecker, not required normally. Default is None,
            which means switch to a new parameter name checker.

    Returns:
        tuple: A ``(generator, para_name_checker)`` pair holding the previous
        UniqueNameGenerator and the previous DygraphParameterNameChecker, so
        they can be passed back to :code:`switch()` to restore them.

    Examples:

        .. code-block:: python

            import paddle
            name1 = paddle.utils.unique_name.generate('fc')
            name2 = paddle.utils.unique_name.generate('fc')
            print(name1, name2) # fc_0, fc_1

            pre_generator, pre_dygraph_name_checker = paddle.utils.unique_name.switch() # switch to a new anonymous namespace.
            name2 = paddle.utils.unique_name.generate('fc')
            print(name2) # fc_0

            paddle.utils.unique_name.switch(pre_generator, pre_dygraph_name_checker) # switch back to pre_generator.
            name3 = paddle.utils.unique_name.generate('fc')
            print(name3) # fc_2, since pre_generator has generated fc_0, fc_1.
    """
    global generator
    global dygraph_parameter_name_checker

    # Capture the currently active objects so the caller can restore them.
    old_generator = generator
    old_para_name_checker = dygraph_parameter_name_checker

    if new_generator is None:
        generator = UniqueNameGenerator()
    else:
        generator = new_generator

    if new_para_name_checker is None:
        dygraph_parameter_name_checker = DygraphParameterNameChecker()
    else:
        dygraph_parameter_name_checker = new_para_name_checker

    return old_generator, old_para_name_checker
Y
Yu Yang 已提交
184 185


S
rename  
sneaxiy 已提交
186
@signature_safe_contextmanager
def guard(new_generator=None):
    """
    Change the namespace of unique name with :code:`with` statement. After calling it,
    a new namespace in the context of :code:`with` will be created, and it will number
    names from zero again when calling :code:`generate()` with same key.

    Args:
        new_generator(str|bytes|UniqueNameGenerator, optional): Prefix of the new
            namespace, or a ready-made generator. Python 2's str was split into
            str and bytes in Python 3, so both are accepted; bytes are decoded
            first. A str/bytes value is prepended to every name generated by
            :code:`generate()` inside the :code:`with` block. Any other value
            (e.g. a UniqueNameGenerator) is passed to :code:`switch()` as is.
            Default is None, which means a new namespace with no prefix.

    Returns:
        None.

    Examples:

        .. code-block:: python

            import paddle
            with paddle.utils.unique_name.guard():
                name_1 = paddle.utils.unique_name.generate('fc')
            with paddle.utils.unique_name.guard():
                name_2 = paddle.utils.unique_name.generate('fc')
            print(name_1, name_2) # fc_0, fc_0

            with paddle.utils.unique_name.guard('A'):
                name_1 = paddle.utils.unique_name.generate('fc')
            with paddle.utils.unique_name.guard('B'):
                name_2 = paddle.utils.unique_name.generate('fc')
            print(name_1, name_2) # Afc_0, Bfc_0
    """
    if isinstance(new_generator, six.string_types):
        new_generator = UniqueNameGenerator(new_generator)
    elif isinstance(new_generator, six.binary_type):
        new_generator = UniqueNameGenerator(new_generator.decode())

    # switch() also installs a fresh parameter name checker. Both previous
    # objects are restored on exit, even if the with-body raises.
    old_generator, old_para_name_checker = switch(new_generator)
    try:
        yield
    finally:
        switch(old_generator, old_para_name_checker)