From 33b81648a37c282f6128548ae7eea47faf77b6d7 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 11 Nov 2016 19:11:36 +0800 Subject: [PATCH] Fix bug in multple objects in define_py_sources --- python/paddle/trainer_config_helpers/data_sources.py | 2 +- .../tests/configs/generate_protostr.sh | 3 ++- .../tests/configs/test_split_datasource.py | 12 ++++++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 python/paddle/trainer_config_helpers/tests/configs/test_split_datasource.py diff --git a/python/paddle/trainer_config_helpers/data_sources.py b/python/paddle/trainer_config_helpers/data_sources.py index f51140656d..283a45df30 100644 --- a/python/paddle/trainer_config_helpers/data_sources.py +++ b/python/paddle/trainer_config_helpers/data_sources.py @@ -139,7 +139,7 @@ def define_py_data_sources(train_list, test_list, module, obj, args=None, test_obj = obj train_obj = obj if __is_splitable__(obj): - train_module, test_module = module + train_obj, test_obj = obj if args is None: args = "" diff --git a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh index 9f614e3983..cafc2142f2 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh @@ -11,7 +11,8 @@ test_sequence_pooling test_lstmemory_layer test_grumemory_layer last_first_seq test_expand_layer test_ntm_layers test_hsigmoid img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cost_layers test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight -test_bilinear_interp test_maxout test_bi_grumemory math_ops) +test_bilinear_interp test_maxout test_bi_grumemory math_ops +test_spilit_datasource) for conf in ${configs[*]} diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_split_datasource.py b/python/paddle/trainer_config_helpers/tests/configs/test_split_datasource.py new file mode 100644 index 0000000000..c8dcb1bd8a --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_split_datasource.py @@ -0,0 +1,12 @@ +from paddle.trainer_config_helpers import * + +define_py_data_sources2(train_list="train.list", + test_list="test.list", + module=["a", "b"], + obj=("c", "d")) +settings( + learning_rate=1e-3, + batch_size=1000 +) + +outputs(data_layer(name="a", size=10)) -- GitLab