提交 3cf7337f 编写于 作者: X xuwei06

Correctly handling multiple calls to parse_config()

To solve this, we maintain the list of DefaultNameFactory used in by trainer_config_helper,
and reset the state at the beginning of each parse_config call.

Change-Id: I13c7574dc8f0b6bc6f6b7c92eb425e2c52c926e8
上级 db379811
......@@ -141,9 +141,9 @@ def init_config_environment(
g_add_submodel_suffix=False,
# Whether current layer needs to pass the image height and width.
# Default value is true, but if it encounters recurrent_layer_group,
# it will be false. The reason is that image is converted to be sequence,
# image height will be sequence length, and image width will be feature
# Default value is true, but if it encounters recurrent_layer_group,
# it will be false. The reason is that image is converted to be sequence,
# image height will be sequence length, and image width will be feature
# length of each timestep.
g_pass_height_width=True, ):
......@@ -1067,7 +1067,7 @@ def cnn_output_size(img_size, filter_size, padding, stride, caffe_mode):
return 1 + int(math.ceil(output))
#calcualte image_size based on output_size for de-convolution (ConvTransLayer).
#calcualte image_size based on output_size for de-convolution (ConvTransLayer).
#It is the reverse function of cnn_output_size
def cnn_image_size(output_size, filter_size, padding, stride, caffe_mode):
img_size = (output_size - 1) * stride + filter_size - 2 * padding
......@@ -3364,6 +3364,14 @@ def my_fatal(s):
logger.critical(s)
raise Exception()
_parse_config_hooks = set()
def register_parse_config_hook(f):
"""
Register a hook function for parse_config. parse_config will invoke the hook
at the beginning of parse. This make it possible to reset global state for
for constructing the model.
"""
_parse_config_hooks.add(f)
def parse_config(config_file, config_arg_str):
'''
......@@ -3371,6 +3379,8 @@ def parse_config(config_file, config_arg_str):
passed to config script as a dictionary CONFIG_ARGS
'''
init_config_environment()
for hook in _parse_config_hooks:
hook()
config_args = {}
......
......@@ -78,6 +78,17 @@ class DefaultNameFactory(object):
"""
pass
def reset(self):
self.__counter__ = 0
_name_factories = []
def reset_hook():
for factory in _name_factories:
factory.reset()
register_parse_config_hook(reset_hook)
def wrap_name_default(name_prefix=None):
"""
......@@ -95,7 +106,9 @@ def wrap_name_default(name_prefix=None):
:return: a decorator to set default name
:rtype: callable
"""
return wrap_param_default(["name"], DefaultNameFactory(name_prefix))
factory = DefaultNameFactory(name_prefix)
_name_factories.append(factory)
return wrap_param_default(["name"], factory)
def wrap_param_attr_default(param_names=None, default_factory=None):
......
......@@ -4,6 +4,11 @@ add_test(NAME layers_test
python ${PROJ_ROOT}/python/paddle/trainer_config_helpers/tests/layers_test.py
WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle)
add_test(NAME test_reset_hook
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
python ${PROJ_ROOT}/python/paddle/trainer_config_helpers/tests/test_rest_hook.py
WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle)
if (PROTOBUF_3)
add_paddle_exe(protobuf_equal
ProtobufEqualMain.cpp)
......
# Copyright PaddlePaddle contributors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.trainer.config_parser import parse_config
class TestParse(unittest.TestCase):
def test_parse(self):
a = parse_config(
'trainer_config_helpers/tests/layers_test_config.py', '')
b = parse_config(
'trainer_config_helpers/tests/layers_test_config.py', '')
self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册