#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import collections
from .wrapped_decorator import signature_safe_contextmanager
import six
import sys

__all__ = ['generate', 'switch', 'guard']


class UniqueNameGenerator(object):
    """
    Generate unique names that share a common prefix.

    Every distinct key keeps its own counter starting at 0, so repeated calls
    with the same key yield ``<prefix><key>_0``, ``<prefix><key>_1``, ...

    Args:
        prefix(str, optional): String prepended to every generated name.
            Default is None, which is treated as an empty prefix.
    """

    def __init__(self, prefix=None):
        # Per-key counters; a key seen for the first time starts at 0.
        self.ids = collections.defaultdict(int)
        self.prefix = "" if prefix is None else prefix

    def __call__(self, key):
        """
        Generate the next unique name for ``key``.

        Args:
            key(str): The key of the returned string.

        Returns:
            str: ``prefix + key + "_" + counter``, where counter is the number
            of names previously generated for ``key`` by this generator.
        """
        index = self.ids[key]
        self.ids[key] = index + 1
        return self.prefix + key + "_" + str(index)


class DygraphParameterNameChecker(object):
    """
    Track which parameter names have already been used.
    """

    def __init__(self):
        # Every name passed to __call__ so far.
        self._name_set = set()

    def __call__(self, name):
        '''
        Report whether ``name`` was seen before, recording it if it was not.

        Args:
            name(str): The parameter name to check.

        Returns:
            bool: True if ``name`` had already been recorded; False if it is
            new (in which case it is recorded now).
        '''
        already_used = name in self._name_set
        if not already_used:
            self._name_set.add(name)
        return already_used


# Module-level singletons for the current namespace. switch() and guard()
# rebind these names (via ``global``) to install a fresh or caller-supplied
# namespace.
dygraph_parameter_name_checker = DygraphParameterNameChecker()

generator = UniqueNameGenerator()


def generate(key):
    """
    Generate unique name with prefix key. Currently, Paddle distinguishes the
    names of the same key by numbering it from zero. For example, when key=fc,
    it continuously generates fc_0, fc_1, fc_2, etc.

    The name is produced by the current module-level ``generator``, so any
    prefix installed through :code:`guard()` is prepended to it.

    Args:
        key(str): The prefix of generated name.

    Returns:
        str: A unique string with the prefix key.

    Examples:

        .. code-block:: python

            import paddle.fluid as fluid
            name1 = fluid.unique_name.generate('fc')
            name2 = fluid.unique_name.generate('fc')
            print(name1, name2) # fc_0 fc_1
    """
    return generator(key)


# FIXME(zjl): The previous naming rule in static graph would
# cause memory leak in dygraph mode. It is because the previous
# naming rule would use `conv_0.tmp` as the key, and in dygraph
# mode, `conv_i` increases as batch increases. Thus, keys would
# increase in a way like `conv_0.tmp`, `conv_1.tmp`, ....
# We have not found a better way to fix this bug in dygraph mode. In TF,
# variable name is meaningless in eager execution mode, and in
# PyTorch, there is no variable name at all. Maybe we should
# discard variable name in dygraph mode.
#
# Another concern is the save/load interfaces. Usually, users
# save a model in static graph mode and load it in dygraph
# mode. Therefore, we keep the variable name of Parameter currently.
#
# Please fix me if a better method is found.
#
# NOTE(zhiqiu): use c++ unique_name_generator in dygraph mode,
# in order to keep name consistency.
def generate_with_ignorable_key(key):
    """
    Generate a unique name, ignoring ``key`` in dygraph mode.

    In dygraph mode the name comes from the dygraph tracer's
    ``_generate_unique_name()`` and ``key`` is not used. Otherwise this is
    equivalent to :code:`generate(key)`: the name comes from the current
    module-level ``generator``.

    Args:
        key(str): The prefix of the generated name (used only outside
            dygraph mode).

    Returns:
        str: A unique name.
    """
    # Imported lazily, presumably to avoid a circular import with
    # ``framework`` -- TODO confirm.
    from .framework import in_dygraph_mode, _dygraph_tracer
    if in_dygraph_mode():
        return _dygraph_tracer()._generate_unique_name()

    return generator(key)


def switch(new_generator=None, new_para_name_checker=None):
    """
    Switch the namespace in the current context to a new namespace. Though
    :code:`switch()` and :code:`guard()` can both change namespace,
    :code:`guard()` is recommended since it can manage the context better
    together with :code:`with` statement.

    Args:
        new_generator(UniqueNameGenerator, optional): A new UniqueNameGenerator, not
            required normally. Default is None, which means switch to a new anonymous
            namespace.
        new_para_name_checker(DygraphParameterNameChecker, optional): A new DygraphParameterNameChecker,
            not required normally. Default is None, which means switch to a new parameter name
            checker.

    Returns:
        tuple: A ``(UniqueNameGenerator, DygraphParameterNameChecker)`` pair
        holding the generator and parameter name checker that were active
        before the switch. Pass them back to :code:`switch()` to restore the
        previous namespace.

    Examples:

        .. code-block:: python

            import paddle.fluid as fluid
            name1 = fluid.unique_name.generate('fc')
            name2 = fluid.unique_name.generate('fc')
            print(name1, name2) # fc_0 fc_1

            pre_generator, pre_dygraph_name_checker = fluid.unique_name.switch() # switch to a new anonymous namespace.
            name2 = fluid.unique_name.generate('fc')
            print(name2) # fc_0

            fluid.unique_name.switch(pre_generator, pre_dygraph_name_checker) # switch back to pre_generator.
            name3 = fluid.unique_name.generate('fc')
            print(name3) # fc_2, since pre_generator has generated fc_0, fc_1.

    """
    # Both module-level singletons are rebound below.
    global generator
    global dygraph_parameter_name_checker

    old_generator = generator
    old_para_name_checker = dygraph_parameter_name_checker

    # None means "start a fresh, empty namespace".
    if new_generator is None:
        generator = UniqueNameGenerator()
    else:
        generator = new_generator

    if new_para_name_checker is None:
        dygraph_parameter_name_checker = DygraphParameterNameChecker()
    else:
        dygraph_parameter_name_checker = new_para_name_checker
    return old_generator, old_para_name_checker


@signature_safe_contextmanager
def guard(new_generator=None):
    """
    Change the namespace of unique name with :code:`with` statement. After calling it,
    a new namespace in the context of :code:`with` will be created, and it will number
    names from zero again when calling :code:`generate()` with same key.

    Args:
        new_generator(str|bytes|UniqueNameGenerator, optional): New name of global
            namespace. Note that str in Python2 was split into str and bytes in
            Python3, so both types are accepted (bytes are decoded first).
            Default is None. If not None, new_generator will be added into
            the prefix of unique name generated by :code:`generate()`.
            Any other value, e.g. an existing UniqueNameGenerator, is installed
            as the generator unchanged.

    Returns:
        None.

    Examples:

        .. code-block:: python

            import paddle.fluid as fluid
            with fluid.unique_name.guard():
              name_1 = fluid.unique_name.generate('fc')
            with fluid.unique_name.guard():
              name_2 = fluid.unique_name.generate('fc')
            print(name_1, name_2) # fc_0 fc_0

            with fluid.unique_name.guard('A'):
              name_1 = fluid.unique_name.generate('fc')
            with fluid.unique_name.guard('B'):
              name_2 = fluid.unique_name.generate('fc')
            print(name_1, name_2) # Afc_0 Bfc_0
    """
    # A string prefix (text or bytes) is wrapped in a new generator.
    if isinstance(new_generator, six.string_types):
        new_generator = UniqueNameGenerator(new_generator)
    elif isinstance(new_generator, six.binary_type):
        new_generator = UniqueNameGenerator(new_generator.decode())

    # switch() is called without a checker, so a fresh
    # DygraphParameterNameChecker is also installed for the duration of the
    # context.
    old_generator, old_para_name_checker = switch(new_generator)
    try:
        yield
    finally:
        # Restore the previous generator and parameter name checker, even if
        # the body of the ``with`` block raised.
        switch(old_generator, old_para_name_checker)