提交 b6aca330 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #764 from emailweixu/multiple_parse

Correctly handling multiple calls to parse_config()
...@@ -3364,6 +3364,14 @@ def my_fatal(s): ...@@ -3364,6 +3364,14 @@ def my_fatal(s):
logger.critical(s) logger.critical(s)
raise Exception() raise Exception()
_parse_config_hooks = set()
def register_parse_config_hook(f):
"""
Register a hook function for parse_config. parse_config will invoke the hook
at the beginning of parse. This make it possible to reset global state for
for constructing the model.
"""
_parse_config_hooks.add(f)
def parse_config(config_file, config_arg_str): def parse_config(config_file, config_arg_str):
''' '''
...@@ -3371,6 +3379,8 @@ def parse_config(config_file, config_arg_str): ...@@ -3371,6 +3379,8 @@ def parse_config(config_file, config_arg_str):
passed to config script as a dictionary CONFIG_ARGS passed to config script as a dictionary CONFIG_ARGS
''' '''
init_config_environment() init_config_environment()
for hook in _parse_config_hooks:
hook()
config_args = {} config_args = {}
......
...@@ -78,6 +78,17 @@ class DefaultNameFactory(object): ...@@ -78,6 +78,17 @@ class DefaultNameFactory(object):
""" """
pass pass
def reset(self):
self.__counter__ = 0
_name_factories = []
def reset_hook():
for factory in _name_factories:
factory.reset()
register_parse_config_hook(reset_hook)
def wrap_name_default(name_prefix=None): def wrap_name_default(name_prefix=None):
""" """
...@@ -95,7 +106,9 @@ def wrap_name_default(name_prefix=None): ...@@ -95,7 +106,9 @@ def wrap_name_default(name_prefix=None):
:return: a decorator to set default name :return: a decorator to set default name
:rtype: callable :rtype: callable
""" """
return wrap_param_default(["name"], DefaultNameFactory(name_prefix)) factory = DefaultNameFactory(name_prefix)
_name_factories.append(factory)
return wrap_param_default(["name"], factory)
def wrap_param_attr_default(param_names=None, default_factory=None): def wrap_param_attr_default(param_names=None, default_factory=None):
......
...@@ -4,6 +4,11 @@ add_test(NAME layers_test ...@@ -4,6 +4,11 @@ add_test(NAME layers_test
python ${PROJ_ROOT}/python/paddle/trainer_config_helpers/tests/layers_test.py python ${PROJ_ROOT}/python/paddle/trainer_config_helpers/tests/layers_test.py
WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle) WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle)
add_test(NAME test_reset_hook
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
python ${PROJ_ROOT}/python/paddle/trainer_config_helpers/tests/test_reset_hook.py
WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle)
if (PROTOBUF_3) if (PROTOBUF_3)
add_paddle_exe(protobuf_equal add_paddle_exe(protobuf_equal
ProtobufEqualMain.cpp) ProtobufEqualMain.cpp)
......
# Copyright PaddlePaddle contributors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.trainer.config_parser import parse_config
class TestParse(unittest.TestCase):
def test_parse(self):
a = parse_config(
'trainer_config_helpers/tests/layers_test_config.py', '')
b = parse_config(
'trainer_config_helpers/tests/layers_test_config.py', '')
self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册