diff --git a/python/paddle/trainer_config_helpers/data_sources.py b/python/paddle/trainer_config_helpers/data_sources.py index f51140656d0dcfed14c67fd6f1d60351ba5e8ab2..283a45df3084495384ea98c06e0b6cc5c793fd5e 100644 --- a/python/paddle/trainer_config_helpers/data_sources.py +++ b/python/paddle/trainer_config_helpers/data_sources.py @@ -139,7 +139,7 @@ def define_py_data_sources(train_list, test_list, module, obj, args=None, test_obj = obj train_obj = obj if __is_splitable__(obj): - train_module, test_module = module + train_obj, test_obj = obj if args is None: args = "" diff --git a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh index c5d60443175272c5eed959a0e9d4b6027c055d94..bb594ac2c245d8882569ba2c3cf00623a8fa8e2c 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh @@ -13,9 +13,16 @@ img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cos test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops) +whole_configs=(test_split_datasource) for conf in ${configs[*]} do echo "Generating " $conf python -m paddle.utils.dump_config $conf.py > $protostr/$conf.protostr.unitest done + +for conf in ${whole_configs[*]} +do + echo "Generating " $conf + python -m paddle.utils.dump_config $conf.py "" --whole > $protostr/$conf.protostr.unitest +done diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr new file mode 100644 index 0000000000000000000000000000000000000000..1cfb92255aa92fa3fbc16a816851a5c2f81c2b56 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr @@ -0,0 +1,72 @@ +model_config { + type: "nn" + layers { + name: "a" + type: "data" + size: 10 + active_type: "" + } + input_layer_names: "a" + output_layer_names: "a" + sub_models { + name: "root" + layer_names: "a" + input_layer_names: "a" + output_layer_names: "a" + is_recurrent_layer_group: false + } +} +data_config { + type: "py2" + files: "train.list" + async_load_data: true + for_test: false + load_data_module: "a" + load_data_object: "c" + load_data_args: "" + data_ratio: 1 + is_main_data: true + usage_ratio: 1.0 +} +opt_config { + batch_size: 1000 + algorithm: "sgd" + learning_rate: 0.001 + learning_rate_decay_a: 0.0 + learning_rate_decay_b: 0.0 + l1weight: 0.1 + l2weight: 0.0 + c1: 0.0001 + backoff: 0.5 + owlqn_steps: 10 + max_backoff: 5 + l2weight_zero_iter: 0 + average_window: 0 + learning_method: "momentum" + ada_epsilon: 1e-06 + do_average_in_cpu: false + ada_rou: 0.95 + learning_rate_schedule: "poly" + delta_add_rate: 1.0 + shrink_parameter_value: 0 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_epsilon: 1e-08 + learning_rate_args: "" + async_lagged_grad_discard_ratio: 1.5 +} +test_data_config { + type: "py2" + files: "test.list" + async_load_data: true + for_test: true + load_data_module: "b" + load_data_object: "d" + load_data_args: "" + data_ratio: 1 + is_main_data: true + usage_ratio: 1.0 +} +save_dir: "./output/model" +start_pass: 0 + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_split_datasource.py b/python/paddle/trainer_config_helpers/tests/configs/test_split_datasource.py new file mode 100644 index 0000000000000000000000000000000000000000..c8dcb1bd8a47b9b0296c8907e5e9474deb549ec2 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_split_datasource.py @@ -0,0 +1,12 @@ +from paddle.trainer_config_helpers import * + +define_py_data_sources2(train_list="train.list", + test_list="test.list", + module=["a", "b"], + obj=("c", "d")) +settings( + learning_rate=1e-3, + batch_size=1000 +) + +outputs(data_layer(name="a", size=10)) diff --git a/python/paddle/utils/dump_config.py b/python/paddle/utils/dump_config.py index d8a2722575d539447fbed90d055df71863b4ba01..c5ce5c8d9a084d68b250d091808f528459f46921 100644 --- a/python/paddle/utils/dump_config.py +++ b/python/paddle/utils/dump_config.py @@ -19,13 +19,21 @@ import sys __all__ = [] if __name__ == '__main__': + whole_conf = False if len(sys.argv) == 2: conf = parse_config(sys.argv[1], '') elif len(sys.argv) == 3: conf = parse_config(sys.argv[1], sys.argv[2]) + elif len(sys.argv) == 4: + conf = parse_config(sys.argv[1], sys.argv[2]) + if sys.argv[3] == '--whole': + whole_conf = True else: raise RuntimeError() assert isinstance(conf, TrainerConfig_pb2.TrainerConfig) - print conf.model_config + if whole_conf: + print conf + else: + print conf.model_config