unique_name.py 3.0 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

Y
Yu Yang 已提交
17
import collections
S
rename  
sneaxiy 已提交
18
from .wrapped_decorator import signature_safe_contextmanager
19
import six
Y
Yu Yang 已提交
20 21
import sys

Y
yuyang18 已提交
22
__all__ = ['generate', 'switch', 'guard']
Y
Yu Yang 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56


class UniqueNameGenerator(object):
    """
    Produce unique names by appending a per-key counter, optionally prefixed.

    Args:
        prefix(str, optional): String prepended to every generated name.
            ``None`` (the default) is treated as an empty prefix.
    """

    def __init__(self, prefix=None):
        # Maps each key to the counter value it will receive next; keys
        # that have never been requested implicitly start at 0.
        self.ids = collections.defaultdict(int)
        self.prefix = "" if prefix is None else prefix

    def __call__(self, key):
        """
        Return the next unique name for ``key``.

        Args:
            key(str): Base name to make unique.

        Returns:
            str: ``prefix + key + "_" + counter``, where ``counter`` is the
            number of names this generator has already produced for ``key``.
        """
        counter = self.ids[key]
        self.ids[key] = counter + 1
        suffixed = "_".join((key, str(counter)))
        return self.prefix + suffixed


# Module-wide generator used by generate(); switch() and guard() rebind it.
generator = UniqueNameGenerator(prefix=None)


Y
Yu Yang 已提交
57 58
def generate(key):
    """
    Generate a unique name for ``key`` with the current global generator.

    Args:
        key(str): Base name to make unique.

    Returns:
        str: The name produced by the module-level ``generator``.
    """
    unique = generator(key)
    return unique
Y
Yu Yang 已提交
59 60


61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
# FIXME(zjl): The previous naming rule in static graph mode would
# cause a memory leak in dygraph mode. This is because the previous
# naming rule used keys such as `conv_0.tmp`, and in dygraph mode
# the `i` in `conv_i` increases with every batch. Thus, the set of
# keys keeps growing: `conv_0.tmp`, `conv_1.tmp`, ...
# No better way to fix this bug in dygraph mode has been found. In TF,
# variable names are meaningless in eager execution mode, and in
# PyTorch, there are no variable names at all. Maybe we should
# discard variable names in dygraph mode.
#
# Another concern is saving/loading inference models. Usually, users
# save a model in static graph mode and load it in dygraph mode.
# Therefore, we keep the variable names of Parameters for now.
#
# Please fix me if a better method is found.
def generate_with_ignorable_key(key):
    """
    Generate a unique name, discarding ``key`` when in dygraph mode.

    Args:
        key(str): Base name to use outside dygraph mode.

    Returns:
        str: A name from the module-level ``generator``. In dygraph mode the
        fixed key ``"tmp"`` is used instead of ``key``.
    """
    from .framework import in_dygraph_mode
    effective_key = "tmp" if in_dygraph_mode() else key
    return generator(effective_key)


Y
Yu Yang 已提交
84 85 86 87 88 89 90 91 92 93
def switch(new_generator=None):
    """
    Replace the module-level generator and hand back the one it replaced.

    Args:
        new_generator(UniqueNameGenerator, optional): Generator to install.
            When ``None``, a fresh ``UniqueNameGenerator()`` is installed.

    Returns:
        The generator that was active before this call.
    """
    global generator
    previous = generator
    generator = UniqueNameGenerator() if new_generator is None else new_generator
    return previous


S
rename  
sneaxiy 已提交
94
@signature_safe_contextmanager
def guard(new_generator=None):
    """
    Context manager that temporarily replaces the global unique-name generator.

    Args:
        new_generator(str|bytes|UniqueNameGenerator|None): The generator to
            use inside the ``with`` block. A ``str`` becomes the prefix of a
            new ``UniqueNameGenerator``. ``bytes`` are decoded and used the
            same way. ``None`` installs a fresh ``UniqueNameGenerator`` with
            no prefix. Any other value is installed as-is.

    The previously active generator is restored when the ``with`` block
    exits, including when it exits because of an exception.
    """
    if isinstance(new_generator, six.string_types):
        new_generator = UniqueNameGenerator(new_generator)
    elif isinstance(new_generator, six.binary_type):
        new_generator = UniqueNameGenerator(new_generator.decode())

    old = switch(new_generator)
    try:
        yield
    finally:
        # Restore even if the body raised; otherwise the global generator
        # would stay swapped out for the rest of the process.
        switch(old)