diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index abe2b114492007ec19f2fcdb09aa173c88badbf5..b3e03f33470810a685dc7bfe29f8da50454b2238 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -49,11 +49,6 @@ PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);
 namespace paddle {
 namespace pybind {

-static size_t UniqueIntegerGenerator(const std::string &prefix) {
-  static std::unordered_map<std::string, std::atomic<size_t>> generators;
-  return generators[prefix].fetch_add(1);
-}
-
 bool IsCompiledWithCUDA() {
 #ifndef PADDLE_WITH_CUDA
   return false;
@@ -410,7 +405,6 @@ All parameter, weight, gradient are variables in Paddle.
       (void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) &
           Executor::Run);

-  m.def("unique_integer", UniqueIntegerGenerator);
   m.def("init_gflags", framework::InitGflags);
   m.def("init_glog", framework::InitGLOG);
   m.def("init_devices", &framework::InitDevices);
diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py
index 361fb3f5ad9394a5cc1a9927005e7276ee056e90..39d13d3ab5fb8340509e01b0bd1de6f66ce99c21 100644
--- a/python/paddle/v2/fluid/__init__.py
+++ b/python/paddle/v2/fluid/__init__.py
@@ -39,6 +39,7 @@ from concurrency import (Go, make_channel, channel_send, channel_recv,
 import clip
 from memory_optimization_transpiler import memory_optimize
 import profiler
+import unique_name

 Tensor = LoDTensor

@@ -63,6 +64,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [
     'DistributeTranspiler',
     'memory_optimize',
     'profiler',
+    'unique_name',
 ]


diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py
index ba27aaa24601bd72bcdbd064242ea2b1c345340c..4da73bb9963cd7870e8496686c9618b4da1728e8 100644
--- a/python/paddle/v2/fluid/backward.py
+++ b/python/paddle/v2/fluid/backward.py
@@ -16,6 +16,7 @@ from paddle.v2.fluid import framework as framework
 from . import core
 import collections
 import copy
+import unique_name

 __all__ = [
     'append_backward',
@@ -391,7 +392,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):

         for name in op_desc.output_arg_names():
             if block.desc.find_var(name.encode("ascii")):
-                new_name = "%s_%s" % (name, core.unique_integer(name))
+                new_name = unique_name.generate(name)
                 op_desc.rename_output(name, new_name)
                 var_map[name] = new_name

diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py
index 1f4618310cbba8ef3fdce6a3beb01876c5074e32..8cc49053337a25d917b85a69a453cf29b1597548 100644
--- a/python/paddle/v2/fluid/evaluator.py
+++ b/python/paddle/v2/fluid/evaluator.py
@@ -15,7 +15,8 @@
 import numpy as np

 import layers
-from framework import Program, unique_name, Variable, program_guard
+from framework import Program, Variable, program_guard
+import unique_name
 from layer_helper import LayerHelper

 __all__ = [
@@ -96,7 +97,7 @@ class Evaluator(object):

         """
         state = self.helper.create_variable(
-            name="_".join([unique_name(self.helper.name), suffix]),
+            name="_".join([unique_name.generate(self.helper.name), suffix]),
             persistable=True,
             dtype=dtype,
             shape=shape)
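The backward-pass change keeps the old renaming scheme for clashing gradient outputs (suffix the name with a per-key counter); only the counter moves from the C++ `unique_integer` binding into the new Python `unique_name` module added at the end of this patch. A small usage sketch, assuming the patched `paddle.v2.fluid` package is importable:

```python
import paddle.v2.fluid as fluid

# Each key keeps its own counter; the exact suffix depends on what the
# process has generated so far (it starts at 0 in a fresh interpreter).
print(fluid.unique_name.generate('x@GRAD'))  # e.g. 'x@GRAD_0'
print(fluid.unique_name.generate('x@GRAD'))  # e.g. 'x@GRAD_1'
print(fluid.unique_name.generate('y@GRAD'))  # e.g. 'y@GRAD_0'
```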
diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py
index 0f6cb90e27c714d02d00402aeee4e5d718f77502..64441e8fa491dd71101c95e14bedf956eb61ee3e 100644
--- a/python/paddle/v2/fluid/framework.py
+++ b/python/paddle/v2/fluid/framework.py
@@ -20,6 +20,7 @@ import numpy as np

 import proto.framework_pb2 as framework_pb2
 from . import core
+import unique_name

 __all__ = [
     'Block',
@@ -47,20 +48,6 @@ def grad_var_name(var_name):
     return var_name + GRAD_VAR_SUFFIX


-def unique_name(prefix):
-    """
-    Generate unique names with prefix
-
-    Args:
-        prefix(str): The prefix of return string
-
-    Returns(str): A unique string with the prefix
-
-    """
-    uid = core.unique_integer(prefix)  # unique during whole process.
-    return "_".join([prefix, str(uid)])
-
-
 def convert_np_dtype_to_dtype_(np_dtype):
     """
     Convert the data type in numpy to the data type in Paddle
@@ -175,7 +162,7 @@ class Variable(object):
         self.error_clip = error_clip

         if name is None:
-            name = Variable._unique_var_name_()
+            name = unique_name.generate('_generated_var')
         is_new_var = False
         self.desc = self.block.desc.find_var(name)

@@ -307,12 +294,6 @@ class Variable(object):
     def type(self):
         return self.desc.type()

-    @staticmethod
-    def _unique_var_name_():
-        prefix = "_generated_var"
-        uid = core.unique_integer(prefix)  # unique during whole process.
-        return "_".join([prefix, str(uid)])
-
     def set_error_clip(self, error_clip):
         self.error_clip = error_clip
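Auto-named Variables keep the `_generated_var_<n>` pattern of the removed `Variable._unique_var_name_` helper, but the counter now lives in a module-level generator, so it keeps growing for the whole process unless it is swapped out with `unique_name.guard()` or `unique_name.switch()` (see the new module at the end of the patch). A hedged sketch:

```python
import paddle.v2.fluid as fluid

# Variables created without an explicit name are auto-named; the suffix
# depends on how many '_generated_var' names the process has handed out.
block = fluid.default_main_program().global_block()
v0 = block.create_var(dtype='float32', shape=[1])
v1 = block.create_var(dtype='float32', shape=[1])
print(v0.name, v1.name)  # e.g. '_generated_var_0 _generated_var_1'
```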
diff --git a/python/paddle/v2/fluid/layer_helper.py b/python/paddle/v2/fluid/layer_helper.py
index e7abc23f2f1da967ed6fff2e758f3dc6f80d60a8..dc4f992ddc3c349a74d62003befbac923a747c07 100644
--- a/python/paddle/v2/fluid/layer_helper.py
+++ b/python/paddle/v2/fluid/layer_helper.py
@@ -15,8 +15,8 @@
 import copy
 import itertools

-from framework import Variable, Parameter, default_main_program, default_startup_program, \
-    unique_name, dtype_is_floating
+from framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating
+import unique_name
 from paddle.v2.fluid.initializer import Constant, Xavier
 from param_attr import ParamAttr, WeightNormParamAttr

@@ -27,7 +27,7 @@ class LayerHelper(object):
         self.layer_type = layer_type
         name = self.kwargs.get('name', None)
         if name is None:
-            self.kwargs['name'] = unique_name(self.layer_type)
+            self.kwargs['name'] = unique_name.generate(self.layer_type)

     @property
     def name(self):
@@ -117,17 +117,20 @@ class LayerHelper(object):
                 block=self.startup_program.global_block()):
             if out is None:
                 out = block.create_var(
-                    name=unique_name(".".join([self.name, 'weight_norm_norm'])),
+                    name=unique_name.generate(".".join(
+                        [self.name, 'weight_norm_norm'])),
                     dtype=dtype,
                     persistable=False)
             abs_out = block.create_var(
-                name=unique_name(".".join([self.name, 'weight_norm_abs'])),
+                name=unique_name.generate(".".join(
+                    [self.name, 'weight_norm_abs'])),
                 dtype=dtype,
                 persistable=False)
             block.append_op(
                 type='abs', inputs={'X': x}, outputs={'Out': abs_out})
             pow_out = block.create_var(
-                name=unique_name(".".join([self.name, 'weight_norm_pow'])),
+                name=unique_name.generate(".".join(
+                    [self.name, 'weight_norm_pow'])),
                 dtype=dtype,
                 persistable=False)
             block.append_op(
@@ -136,7 +139,8 @@ class LayerHelper(object):
                 outputs={'Out': pow_out},
                 attrs={'factor': float(p)})
             sum_out = block.create_var(
-                name=unique_name(".".join([self.name, 'weight_norm_sum'])),
+                name=unique_name.generate(".".join(
+                    [self.name, 'weight_norm_sum'])),
                 dtype=dtype,
                 persistable=False)
             block.append_op(
@@ -161,7 +165,7 @@ class LayerHelper(object):
                 block=self.startup_program.global_block()):
             if out is None:
                 out = block.create_var(
-                    name=unique_name(".".join(
+                    name=unique_name.generate(".".join(
                         [self.name, 'weight_norm_reshape'])),
                     dtype=dtype,
                     persistable=False)
@@ -178,7 +182,7 @@ class LayerHelper(object):
                 block=self.startup_program.global_block()):
             if out is None:
                 out = block.create_var(
-                    name=unique_name(".".join(
+                    name=unique_name.generate(".".join(
                         [self.name, 'weight_norm_transpose'])),
                     dtype=dtype,
                     persistable=False)
@@ -196,7 +200,8 @@ class LayerHelper(object):
             """Computes the norm over all dimensions except dim"""
             if out is None:
                 out = block.create_var(
-                    name=unique_name(".".join([self.name, 'weight_norm_norm'])),
+                    name=unique_name.generate(".".join(
+                        [self.name, 'weight_norm_norm'])),
                     dtype=dtype,
                     persistable=False)
             if dim is None:
@@ -286,7 +291,7 @@ class LayerHelper(object):
         assert isinstance(attr, ParamAttr)
         suffix = 'b' if is_bias else 'w'
         if attr.name is None:
-            attr.name = unique_name(".".join([self.name, suffix]))
+            attr.name = unique_name.generate(".".join([self.name, suffix]))

         if default_initializer is None and attr.initializer is None:
             if is_bias:
@@ -316,7 +321,7 @@

     def create_tmp_variable(self, dtype, stop_gradient=False):
         return self.main_program.current_block().create_var(
-            name=unique_name(".".join([self.name, 'tmp'])),
+            name=unique_name.generate(".".join([self.name, 'tmp'])),
             dtype=dtype,
             persistable=False,
             stop_gradient=stop_gradient)
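The LayerHelper changes are mechanical: the dotted naming scheme for layers, parameters, and temporaries is unchanged, only the generator behind it is the new module. Roughly, a layer built without an explicit name ends up with names of this shape (illustrative, not captured from a real run):

```python
import paddle.v2.fluid as fluid

with fluid.unique_name.guard():
    layer = fluid.unique_name.generate('fc')           # 'fc_0'
    weight = fluid.unique_name.generate(layer + '.w')  # 'fc_0.w_0'
    bias = fluid.unique_name.generate(layer + '.b')    # 'fc_0.b_0'
    tmp = fluid.unique_name.generate(layer + '.tmp')   # 'fc_0.tmp_0'
```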
diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py
index 72056cc7cdfb861f4bcdf2afdd3ba62c1a996f22..1bb1aa30ee1019c6f80eb64b6dc20459e7a3073b 100644
--- a/python/paddle/v2/fluid/layers/control_flow.py
+++ b/python/paddle/v2/fluid/layers/control_flow.py
@@ -428,7 +428,8 @@ class StaticRNN(object):
                 raise ValueError(
                     "if init is None, memory at least need shape and batch_ref")
             parent_block = self.parent_block()
-            var_name = unique_name("@".join([self.helper.name, "memory_boot"]))
+            var_name = unique_name.generate("@".join(
+                [self.helper.name, "memory_boot"]))
             boot_var = parent_block.create_var(
                 name=var_name,
                 shape=shape,
@@ -450,7 +451,7 @@ class StaticRNN(object):
             return self.memory(init=boot_var)
         else:
             pre_mem = self.helper.create_variable(
-                name=unique_name("@".join([self.helper.name, "mem"])),
+                name=unique_name.generate("@".join([self.helper.name, "mem"])),
                 dtype=init.dtype,
                 shape=init.shape)
             self.memories[pre_mem.name] = StaticRNNMemoryLink(
@@ -710,7 +711,7 @@ def lod_rank_table(x, level=0):
     helper = LayerHelper("lod_rank_table", **locals())
     table = helper.create_variable(
         type=core.VarDesc.VarType.LOD_RANK_TABLE,
-        name=unique_name("lod_rank_table"))
+        name=unique_name.generate("lod_rank_table"))
     helper.append_op(
         type='lod_rank_table',
         inputs={'X': x},
@@ -808,7 +809,7 @@ def lod_tensor_to_array(x, table):
     """
     helper = LayerHelper("lod_tensor_to_array", **locals())
     array = helper.create_variable(
-        name=unique_name("lod_tensor_to_array"),
+        name=unique_name.generate("lod_tensor_to_array"),
         type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
         dtype=x.dtype)
     helper.append_op(
@@ -1265,11 +1266,11 @@ class IfElse(object):
         if id(x) not in self.input_table:
             parent_block = self.parent_block()
             out_true = parent_block.create_var(
-                name=unique_name('ifelse_input' + self.helper.name),
+                name=unique_name.generate('ifelse_input' + self.helper.name),
                 dtype=x.dtype)

             out_false = parent_block.create_var(
-                name=unique_name('ifelse_input' + self.helper.name),
+                name=unique_name.generate('ifelse_input' + self.helper.name),
                 dtype=x.dtype)
             parent_block.append_op(
                 type='split_lod_tensor',
@@ -1311,7 +1312,8 @@ class IfElse(object):
                 raise TypeError("Each output should be a variable")
             # create outside tensor
             outside_out = parent_block.create_var(
-                name=unique_name("_".join([self.helper.name, 'output'])),
+                name=unique_name.generate("_".join(
+                    [self.helper.name, 'output'])),
                 dtype=each_out.dtype)
             out_table.append(outside_out)

@@ -1374,7 +1376,7 @@ class DynamicRNN(object):
         parent_block = self._parent_block_()
         if self.lod_rank_table is None:
             self.lod_rank_table = parent_block.create_var(
-                name=unique_name('lod_rank_table'),
+                name=unique_name.generate('lod_rank_table'),
                 type=core.VarDesc.VarType.LOD_RANK_TABLE)
             self.lod_rank_table.stop_gradient = True
             parent_block.append_op(
@@ -1382,7 +1384,8 @@
                 inputs={"X": x},
                 outputs={"Out": self.lod_rank_table})
             self.max_seq_len = parent_block.create_var(
-                name=unique_name('dynamic_rnn_max_seq_len'), dtype='int64')
+                name=unique_name.generate('dynamic_rnn_max_seq_len'),
+                dtype='int64')
             self.max_seq_len.stop_gradient = False
             parent_block.append_op(
                 type='max_sequence_len',
@@ -1396,7 +1399,7 @@
                 outputs={'Out': self.cond})

         input_array = parent_block.create_var(
-            name=unique_name('dynamic_rnn_input_array'),
+            name=unique_name.generate('dynamic_rnn_input_array'),
             type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
             dtype=x.dtype)
         self.input_array.append((input_array, x.dtype))
@@ -1417,7 +1420,7 @@
                 "static_input() must be called after step_input().")
         parent_block = self._parent_block_()
         x_reordered = parent_block.create_var(
-            name=unique_name("dynamic_rnn_static_input_reordered"),
+            name=unique_name.generate("dynamic_rnn_static_input_reordered"),
             type=core.VarDesc.VarType.LOD_TENSOR,
             dtype=x.dtype)
         parent_block.append_op(
@@ -1479,7 +1482,7 @@
                     'invoked before '
                     'memory(init=init, need_reordered=True, ...).')
                 init_reordered = parent_block.create_var(
-                    name=unique_name('dynamic_rnn_mem_init_reordered'),
+                    name=unique_name.generate('dynamic_rnn_mem_init_reordered'),
                     type=core.VarDesc.VarType.LOD_TENSOR,
                     dtype=init.dtype)
                 parent_block.append_op(
@@ -1491,7 +1494,7 @@
                     outputs={'Out': [init_reordered]})
                 init_tensor = init_reordered
             mem_array = parent_block.create_var(
-                name=unique_name('dynamic_rnn_mem_array'),
+                name=unique_name.generate('dynamic_rnn_mem_array'),
                 type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
                 dtype=init.dtype)
             parent_block.append_op(
@@ -1511,9 +1514,10 @@
                 )
             parent_block = self._parent_block_()
             init = parent_block.create_var(
-                name=unique_name('mem_init'), dtype=dtype)
+                name=unique_name.generate('mem_init'), dtype=dtype)
             arr, dtype = self.input_array[0]
-            in0 = parent_block.create_var(name=unique_name('in0'), dtype=dtype)
+            in0 = parent_block.create_var(
+                name=unique_name.generate('in0'), dtype=dtype)
             parent_block.append_op(
                 type='read_from_array',
                 inputs={'X': [arr],
@@ -1552,7 +1556,7 @@
         parent_block = self._parent_block_()
         for each in outputs:
             outside_array = parent_block.create_var(
-                name=unique_name("_".join(
+                name=unique_name.generate("_".join(
                     [self.helper.name, "output_array", each.name])),
                 type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
                 dtype=each.dtype)
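Routing these internal names (rank tables, memory arrays, boot memories) through `unique_name.generate` is what keeps two RNN or IfElse constructs in one program from colliding; each instance simply draws the next suffix for its key. A sketch, under the same assumptions as above:

```python
import paddle.v2.fluid as fluid

with fluid.unique_name.guard():
    first = fluid.unique_name.generate('lod_rank_table')   # 'lod_rank_table_0'
    second = fluid.unique_name.generate('lod_rank_table')  # 'lod_rank_table_1'
    assert first != second
```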
diff --git a/python/paddle/v2/fluid/layers/device.py b/python/paddle/v2/fluid/layers/device.py
index 3fee263ac0fc0aa794b290c35e6a929572d83e6d..e0c1aab230aeed7fb858e91e7da7eae58032ee16 100644
--- a/python/paddle/v2/fluid/layers/device.py
+++ b/python/paddle/v2/fluid/layers/device.py
@@ -25,7 +25,8 @@ __all__ = ['get_places']
 @autodoc()
 def get_places(device_count=None, device_type=None):
     helper = LayerHelper('get_places', **locals())
-    out_places = helper.create_variable(name=unique_name(helper.name + ".out"))
+    out_places = helper.create_variable(
+        name=unique_name.generate(helper.name + ".out"))
     attrs = dict()
     if device_count is not None:
         attrs['device_count'] = int(device_count)
diff --git a/python/paddle/v2/fluid/layers/math_op_patch.py b/python/paddle/v2/fluid/layers/math_op_patch.py
index 417a01b76f16336d38a3f7589f660b1a7779594e..beebc1a85f88511822e7f8ad4cd62fc024318430 100644
--- a/python/paddle/v2/fluid/layers/math_op_patch.py
+++ b/python/paddle/v2/fluid/layers/math_op_patch.py
@@ -21,7 +21,7 @@ __all__ = ['monkey_patch_variable']

 def monkey_patch_variable():
     def unique_tmp_name():
-        return unique_name("tmp")
+        return unique_name.generate("tmp")

     def safe_get_dtype(var):
         try:
diff --git a/python/paddle/v2/fluid/optimizer.py b/python/paddle/v2/fluid/optimizer.py
index ecc42f6215bdd13f6ea4284dcd67b6026ad33129..61febc4e383b5eb0e75c0005330b92fa90ddbe44 100644
--- a/python/paddle/v2/fluid/optimizer.py
+++ b/python/paddle/v2/fluid/optimizer.py
@@ -17,7 +17,8 @@ from collections import defaultdict
 import framework
 import layers
 from backward import append_backward
-from framework import unique_name, program_guard
+from framework import program_guard
+import unique_name
 from initializer import Constant
 from layer_helper import LayerHelper
 from regularizer import append_regularization_ops
@@ -49,7 +50,7 @@ class Optimizer(object):
     def _create_global_learning_rate(self):
         if isinstance(self._global_learning_rate, float):
             self._global_learning_rate = layers.create_global_var(
-                name=unique_name("learning_rate"),
+                name=unique_name.generate("learning_rate"),
                 shape=[1],
                 value=float(self._global_learning_rate),
                 dtype='float32',
@@ -118,7 +119,7 @@ class Optimizer(object):

         assert isinstance(self.helper, LayerHelper)
         var = self.helper.create_global_variable(
-            name=unique_name(name),
+            name=unique_name.generate(name),
             persistable=True,
             dtype=dtype or param.dtype,
             type=param.type,
@@ -379,7 +380,7 @@ class AdamOptimizer(Optimizer):
         # Create beta1 and beta2 power tensors
         beta_shape = [1]
         self._beta1_pow_acc = self.helper.create_global_variable(
-            name=unique_name('beta1_pow_acc'),
+            name=unique_name.generate('beta1_pow_acc'),
             dtype='float32',
             shape=beta_shape,
             lod_level=0,
@@ -388,7 +389,7 @@
             self._beta1_pow_acc, initializer=Constant(self._beta1))

         self._beta2_pow_acc = self.helper.create_global_variable(
-            name=unique_name('beta2_pow_acc'),
+            name=unique_name.generate('beta2_pow_acc'),
             dtype='float32',
             shape=beta_shape,
             lod_level=0,
@@ -481,7 +482,7 @@ class AdamaxOptimizer(Optimizer):
         # Create beta1 power accumulator tensor
         beta_shape = [1]
         self._beta1_pow_acc = self.helper.create_global_variable(
-            name=unique_name('beta1_pow_acc'),
+            name=unique_name.generate('beta1_pow_acc'),
             dtype='float32',
             shape=beta_shape,
             lod_level=0,
diff --git a/python/paddle/v2/fluid/tests/unittests/test_unique_name.py b/python/paddle/v2/fluid/tests/unittests/test_unique_name.py
new file mode 100644
index 0000000000000000000000000000000000000000..e28810c96b8bc829b7034914b429ce21a59260bb
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/unittests/test_unique_name.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle.v2.fluid as fluid
+
+
+class TestUniqueName(unittest.TestCase):
+    def test_guard(self):
+        with fluid.unique_name.guard():
+            name_1 = fluid.unique_name.generate('')
+
+        with fluid.unique_name.guard():
+            name_2 = fluid.unique_name.generate('')
+
+        self.assertEqual(name_1, name_2)
+
+        with fluid.unique_name.guard("A"):
+            name_1 = fluid.unique_name.generate('')
+
+        with fluid.unique_name.guard('B'):
+            name_2 = fluid.unique_name.generate('')
+
+        self.assertNotEqual(name_1, name_2)
+
+    def test_generate(self):
+        with fluid.unique_name.guard():
+            name1 = fluid.unique_name.generate('fc')
+            name2 = fluid.unique_name.generate('fc')
+            name3 = fluid.unique_name.generate('tmp')
+        self.assertNotEqual(name1, name2)
+        self.assertEqual(name1[-2:], name3[-2:])
diff --git a/python/paddle/v2/fluid/unique_name.py b/python/paddle/v2/fluid/unique_name.py
new file mode 100644
index 0000000000000000000000000000000000000000..33c53113ae7e8ed9aeada31f2aed6990b6fea110
--- /dev/null
+++ b/python/paddle/v2/fluid/unique_name.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import contextlib
+import sys
+
+__all__ = ['generate', 'switch', 'guard', 'UniqueNameGenerator']
+
+
+class UniqueNameGenerator(object):
+    """
+    Generate unique name with prefix.
+
+    Args:
+        prefix(str): The generated name prefix. All generated name will be
+                     started with this prefix.
+    """
+
+    def __init__(self, prefix=None):
+        self.ids = collections.defaultdict(int)
+        if prefix is None:
+            prefix = ""
+        self.prefix = prefix
+
+    def __call__(self, key):
+        """
+        Generate unique names with prefix
+
+        Args:
+            key(str): The key of return string.
+
+        Returns(str): A unique string with the prefix
+        """
+        tmp = self.ids[key]
+        self.ids[key] += 1
+        return self.prefix + "_".join([key, str(tmp)])
+
+
+generator = UniqueNameGenerator()
+
+
+def generate(key):
+    return generator(key)
+
+
+def switch(new_generator=None):
+    global generator
+    old = generator
+    if new_generator is None:
+        generator = UniqueNameGenerator()
+    else:
+        generator = new_generator
+    return old
+
+
+@contextlib.contextmanager
+def guard(new_generator=None):
+    if isinstance(new_generator, basestring):
+        new_generator = UniqueNameGenerator(new_generator)
+    old = switch(new_generator)
+    yield
+    switch(old)
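With the generator in Python, name allocation becomes controllable per scope: wrapping construction in `guard()` gives each block of code a fresh (or explicitly shared) `UniqueNameGenerator`, and `switch()` does the same without a context manager. One practical consequence, mirroring the unit test above: building the same network twice produces identical parameter names, so state saved from one copy loads into the other. A sketch, assuming this patch is applied:

```python
import paddle.v2.fluid as fluid


def build_names():
    # Stand-in for real layer construction, which routes its names
    # through unique_name.generate() in the same way.
    return [fluid.unique_name.generate('fc'),
            fluid.unique_name.generate('fc.w')]


with fluid.unique_name.guard():
    train_names = build_names()

with fluid.unique_name.guard():
    infer_names = build_names()

assert train_names == infer_names  # each guard() starts from a fresh counter

# switch() swaps the active generator and returns the previous one.
old = fluid.unique_name.switch(fluid.unique_name.UniqueNameGenerator())
print(fluid.unique_name.generate('fc'))  # 'fc_0' again
fluid.unique_name.switch(old)            # restore the process-wide generator
```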