diff --git a/python/paddle/trainer_config_helpers/data_sources.py b/python/paddle/trainer_config_helpers/data_sources.py
index f51140656d0dcfed14c67fd6f1d60351ba5e8ab2..283a45df3084495384ea98c06e0b6cc5c793fd5e 100644
--- a/python/paddle/trainer_config_helpers/data_sources.py
+++ b/python/paddle/trainer_config_helpers/data_sources.py
@@ -139,7 +139,7 @@ def define_py_data_sources(train_list, test_list, module, obj, args=None,
     test_obj = obj
     train_obj = obj
     if __is_splitable__(obj):
-        train_module, test_module = module
+        train_obj, test_obj = obj
 
     if args is None:
         args = ""
diff --git a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh
index 9f614e3983ffa95b6f573b3223e26da8d4aa93a5..cafc2142f25c74d54ec8a1ab937db23306da2904 100755
--- a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh
+++ b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh
@@ -11,7 +11,8 @@
 test_sequence_pooling test_lstmemory_layer test_grumemory_layer
 last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
 img_layers img_trans_layers util_layers simple_rnn_layers unused_layers
 test_cost_layers test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight
-test_bilinear_interp test_maxout test_bi_grumemory math_ops)
+test_bilinear_interp test_maxout test_bi_grumemory math_ops
+test_split_datasource)
 
 for conf in ${configs[*]}
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_split_datasource.py b/python/paddle/trainer_config_helpers/tests/configs/test_split_datasource.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8dcb1bd8a47b9b0296c8907e5e9474deb549ec2
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_split_datasource.py
@@ -0,0 +1,12 @@
+from paddle.trainer_config_helpers import *
+
+define_py_data_sources2(train_list="train.list",
+                        test_list="test.list",
+                        module=["a", "b"],
+                        obj=("c", "d"))
+settings(
+    learning_rate=1e-3,
+    batch_size=1000
+)
+
+outputs(data_layer(name="a", size=10))