unique_name.py 7.4 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

Y
Yu Yang 已提交
17
import collections
S
rename  
sneaxiy 已提交
18
from .wrapped_decorator import signature_safe_contextmanager
19
import six
Y
Yu Yang 已提交
20 21
import sys

Y
yuyang18 已提交
22
# Public API of this module: the names exported by `from ... import *`.
__all__ = ['generate', 'switch', 'guard']
Y
Yu Yang 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53


class UniqueNameGenerator(object):
    """
    Produce names that are unique per key, optionally with a common prefix.

    Every key owns an independent counter that starts at zero. Calling the
    generator with a key returns ``prefix + key + "_" + counter`` and then
    advances that key's counter by one.

    Args:
        prefix(str, optional): String prepended to every generated name.
            ``None`` (the default) is treated as the empty string.
    """

    def __init__(self, prefix=None):
        # Per-key counters; a key that has never been seen starts at 0.
        self.ids = collections.defaultdict(int)
        self.prefix = "" if prefix is None else prefix

    def __call__(self, key):
        """
        Return the next unique name for ``key``.

        Args:
            key(str): The key the name is built from.

        Returns(str): ``prefix + key + "_" + n``, where ``n`` is the number
            of names previously generated for ``key`` by this instance.
        """
        index = self.ids[key]
        self.ids[key] = index + 1
        suffix = str(index)
        return self.prefix + "_".join((key, suffix))


54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
class DygraphParameterNameChecker(object):
    """
    Remember parameter names and report whether a name was already seen.
    """

    def __init__(self):
        # Every name this checker has been asked about so far.
        self._name_set = set()

    def __call__(self, name):
        '''
        Report whether ``name`` was seen before, recording it if it was not.

        Args:
            name(str): The parameter name to check.

        Returns(bool): True if ``name`` had already been recorded; False if
            it is new (in which case it is recorded now).
        '''
        seen = name in self._name_set
        if not seen:
            self._name_set.add(name)
        return seen


# Module-level DygraphParameterNameChecker instance. switch() (and therefore
# guard()) rebinds this global and returns the previous instance.
dygraph_parameter_name_checker = DygraphParameterNameChecker()

Y
Yu Yang 已提交
81 82 83
# Module-level generator (no prefix) that generate() calls; switch() and
# guard() rebind it to change the active namespace.
generator = UniqueNameGenerator()


Y
Yu Yang 已提交
84
def generate(key):
    """
    Generate a unique name with prefix key. Currently, Paddle distinguishes
    names that share the same key by numbering them from zero. For example,
    when key=fc, it successively generates fc_0, fc_1, fc_2, etc. If the
    active namespace has its own prefix (see :code:`guard()`), that prefix
    is prepended as well.

    Args:
        key(str): The prefix of the generated name.

    Returns:
        str: A unique string with the prefix key.

    Examples:

        .. code-block:: python

            import paddle.fluid as fluid
            name1 = fluid.unique_name.generate('fc')
            name2 = fluid.unique_name.generate('fc')
            print(name1, name2) # fc_0, fc_1
    """
    # `generator` is looked up at call time, so namespaces installed by
    # switch()/guard() (which rebind the module global) take effect here.
    return generator(key)
Y
Yu Yang 已提交
106 107


108 109
# FIXME(zjl): The previous naming rule in static graph mode would
# cause a memory leak in dygraph mode. The previous naming rule
# used keys such as `conv_0.tmp`, and in dygraph mode `conv_i`
# increases with every batch. Thus, the keys keep growing as
# `conv_0.tmp`, `conv_1.tmp`, ..., and each new key adds another
# entry to the generator's per-key counter dict.
# No better way to fix this bug in dygraph mode has been found yet.
# In TF, variable names are meaningless in eager execution mode, and
# in PyTorch there are no variable names at all. Maybe we should
# discard variable names in dygraph mode.
#
# Another concern is the save/load interfaces: users usually save a
# model in static graph mode and load it in dygraph mode. Therefore,
# the variable names of Parameters are kept for now.
#
# Please fix me if a better method is found.
def generate_with_ignorable_key(key):
    """
    Generate a unique name like :code:`generate()`, except that in dygraph
    mode the given key is ignored and ``"tmp"`` is used instead.

    Args:
        key(str): The key to use when not in dygraph mode.

    Returns:
        str: A unique name produced by the module-level generator.
    """
    # Imported at call time rather than at module top; presumably to avoid
    # a circular import with .framework -- TODO confirm.
    from .framework import in_dygraph_mode
    effective_key = "tmp" if in_dygraph_mode() else key
    return generator(effective_key)


131
def switch(new_generator=None, new_para_name_checker=None):
    """
    Switch the namespace of the current context to a new namespace. Though
    :code:`switch()` and :code:`guard()` can both change the namespace,
    :code:`guard()` is recommended since, together with the :code:`with`
    statement, it manages the context better.

    Args:
        new_generator(UniqueNameGenerator, optional): A new
            UniqueNameGenerator, not required normally. Default is None,
            which means switch to a new anonymous namespace.
        new_para_name_checker(DygraphParameterNameChecker, optional): A new
            DygraphParameterNameChecker, not required normally. Default is
            None, which means switch to a new parameter name checker.

    Returns:
        tuple: ``(old_generator, old_para_name_checker)``, i.e. the previous
        UniqueNameGenerator and the previous DygraphParameterNameChecker.
        Pass them back to :code:`switch()` to restore the previous namespace.

    Examples:

        .. code-block:: python

            import paddle.fluid as fluid
            name1 = fluid.unique_name.generate('fc')
            name2 = fluid.unique_name.generate('fc')
            print(name1, name2) # fc_0, fc_1

            pre_generator, pre_dygraph_name_checker = fluid.unique_name.switch() # switch to a new anonymous namespace.
            name2 = fluid.unique_name.generate('fc')
            print(name2) # fc_0

            fluid.unique_name.switch(pre_generator, pre_dygraph_name_checker) # switch back to pre_generator.
            name3 = fluid.unique_name.generate('fc')
            print(name3) # fc_2, since pre_generator has generated fc_0, fc_1.
    """
    global generator
    global dygraph_parameter_name_checker

    old_generator = generator
    old_para_name_checker = dygraph_parameter_name_checker

    # A None argument means "start fresh": build a new, empty instance.
    if new_generator is None:
        new_generator = UniqueNameGenerator()
    if new_para_name_checker is None:
        new_para_name_checker = DygraphParameterNameChecker()

    generator = new_generator
    dygraph_parameter_name_checker = new_para_name_checker
    return old_generator, old_para_name_checker
Y
Yu Yang 已提交
182 183


S
rename  
sneaxiy 已提交
184
@signature_safe_contextmanager
def guard(new_generator=None):
    """
    Change the namespace of unique names with the :code:`with` statement.
    Inside the :code:`with` block a new namespace is active, so
    :code:`generate()` numbers names from zero again for every key. A fresh
    parameter name checker is installed for the block as well. Both the
    previous namespace and the previous checker are restored when the block
    exits, including when it exits by raising an exception.

    Args:
        new_generator(str|bytes|UniqueNameGenerator, optional): The new
            namespace. A str is used as the prefix of every name generated
            by :code:`generate()` inside the block. Bytes are decoded to str
            first; str in Python 2 was split into str and bytes in Python 3,
            so both types are accepted. A UniqueNameGenerator instance is
            installed as-is. Default is None, which means a new anonymous
            namespace without a prefix.

    Returns:
        None.

    Examples:

        .. code-block:: python

            import paddle.fluid as fluid
            with fluid.unique_name.guard():
              name_1 = fluid.unique_name.generate('fc')
            with fluid.unique_name.guard():
              name_2 = fluid.unique_name.generate('fc')
            print(name_1, name_2) # fc_0, fc_0

            with fluid.unique_name.guard('A'):
              name_1 = fluid.unique_name.generate('fc')
            with fluid.unique_name.guard('B'):
              name_2 = fluid.unique_name.generate('fc')
            print(name_1, name_2) # Afc_0, Bfc_0
    """
    if isinstance(new_generator, six.string_types):
        new_generator = UniqueNameGenerator(new_generator)
    elif isinstance(new_generator, six.binary_type):
        new_generator = UniqueNameGenerator(new_generator.decode())

    old_generator, old_para_name_checker = switch(new_generator)
    try:
        yield
    finally:
        # Restore the outer namespace even if the with-body raised; without
        # this, an exception would leave the inner generator and checker
        # installed globally for all subsequent code.
        switch(old_generator, old_para_name_checker)