diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index c6c0c9c151d840963fab1fe689eb5b9c340518ce..699fe6630a37f9f8e3b7f1e4ec043e2da0574c4a 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -141,9 +141,9 @@ def init_config_environment( g_add_submodel_suffix=False, # Whether current layer needs to pass the image height and width. - # Default value is true, but if it encounters recurrent_layer_group, - # it will be false. The reason is that image is converted to be sequence, - # image height will be sequence length, and image width will be feature + # Default value is true, but if it encounters recurrent_layer_group, + # it will be false. The reason is that image is converted to be sequence, + # image height will be sequence length, and image width will be feature # length of each timestep. g_pass_height_width=True, ): @@ -1067,7 +1067,7 @@ def cnn_output_size(img_size, filter_size, padding, stride, caffe_mode): return 1 + int(math.ceil(output)) -#calcualte image_size based on output_size for de-convolution (ConvTransLayer). +#calcualte image_size based on output_size for de-convolution (ConvTransLayer). #It is the reverse function of cnn_output_size def cnn_image_size(output_size, filter_size, padding, stride, caffe_mode): img_size = (output_size - 1) * stride + filter_size - 2 * padding @@ -3364,6 +3364,14 @@ def my_fatal(s): logger.critical(s) raise Exception() +_parse_config_hooks = set() +def register_parse_config_hook(f): + """ + Register a hook function for parse_config. parse_config will invoke the hook + at the beginning of parse. This make it possible to reset global state for + for constructing the model. + """ + _parse_config_hooks.add(f) def parse_config(config_file, config_arg_str): ''' @@ -3371,6 +3379,8 @@ def parse_config(config_file, config_arg_str): passed to config script as a dictionary CONFIG_ARGS ''' init_config_environment() + for hook in _parse_config_hooks: + hook() config_args = {} diff --git a/python/paddle/trainer_config_helpers/default_decorators.py b/python/paddle/trainer_config_helpers/default_decorators.py index c01050e338d5933f49f0504f2e9ef5f15c7743ba..23a4fa241d074e6e7a1f3b420f3f81cac20b1e8a 100644 --- a/python/paddle/trainer_config_helpers/default_decorators.py +++ b/python/paddle/trainer_config_helpers/default_decorators.py @@ -78,6 +78,17 @@ class DefaultNameFactory(object): """ pass + def reset(self): + self.__counter__ = 0 + + +_name_factories = [] + +def reset_hook(): + for factory in _name_factories: + factory.reset() + +register_parse_config_hook(reset_hook) def wrap_name_default(name_prefix=None): """ @@ -95,7 +106,9 @@ def wrap_name_default(name_prefix=None): :return: a decorator to set default name :rtype: callable """ - return wrap_param_default(["name"], DefaultNameFactory(name_prefix)) + factory = DefaultNameFactory(name_prefix) + _name_factories.append(factory) + return wrap_param_default(["name"], factory) def wrap_param_attr_default(param_names=None, default_factory=None): diff --git a/python/paddle/trainer_config_helpers/tests/CMakeLists.txt b/python/paddle/trainer_config_helpers/tests/CMakeLists.txt index 6180b2efbcad87e511a4b981d533f204f45fb5dc..bff82f75050d1c30c53658fe341f77da864d3ae8 100644 --- a/python/paddle/trainer_config_helpers/tests/CMakeLists.txt +++ b/python/paddle/trainer_config_helpers/tests/CMakeLists.txt @@ -4,6 +4,11 @@ add_test(NAME layers_test python ${PROJ_ROOT}/python/paddle/trainer_config_helpers/tests/layers_test.py WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle) +add_test(NAME test_reset_hook + COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/ + python ${PROJ_ROOT}/python/paddle/trainer_config_helpers/tests/test_rest_hook.py + WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle) + if (PROTOBUF_3) add_paddle_exe(protobuf_equal ProtobufEqualMain.cpp) diff --git a/python/paddle/trainer_config_helpers/tests/test_reset_hook.py b/python/paddle/trainer_config_helpers/tests/test_reset_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..dc494d0eef22c927fe8afea5af2f8c36ff405173 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/test_reset_hook.py @@ -0,0 +1,28 @@ +# Copyright PaddlePaddle contributors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +from paddle.trainer.config_parser import parse_config + +class TestParse(unittest.TestCase): + + def test_parse(self): + a = parse_config( + 'trainer_config_helpers/tests/layers_test_config.py', '') + b = parse_config( + 'trainer_config_helpers/tests/layers_test_config.py', '') + self.assertEqual(a, b) + + +if __name__ == '__main__': + unittest.main()