From 4607d517bfba481cbb02b2beb1dfa3773eadfded Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 11 Nov 2016 19:31:28 +0800 Subject: [PATCH] Add unittest for split datasource * Fix #436 --- .../tests/configs/generate_protostr.sh | 10 ++- .../protostr/test_split_datasource.protostr | 72 +++++++++++++++++++ python/paddle/utils/dump_config.py | 10 ++- 3 files changed, 89 insertions(+), 3 deletions(-) create mode 100644 python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr diff --git a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh index e84e2a4b7f..bb594ac2c2 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh @@ -11,12 +11,18 @@ test_sequence_pooling test_lstmemory_layer test_grumemory_layer last_first_seq test_expand_layer test_ntm_layers test_hsigmoid img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cost_layers test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight -test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops -test_split_datasource) +test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops) +whole_configs=(test_split_datasource) for conf in ${configs[*]} do echo "Generating " $conf python -m paddle.utils.dump_config $conf.py > $protostr/$conf.protostr.unitest done + +for conf in ${whole_configs[*]} +do + echo "Generating " $conf + python -m paddle.utils.dump_config $conf.py "" --whole > $protostr/$conf.protostr.unitest +done diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr new file mode 100644 index 0000000000..1cfb92255a --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_split_datasource.protostr @@ -0,0 +1,72 @@ +model_config { + type: "nn" + layers { + name: "a" + type: "data" + size: 10 + active_type: "" + } + input_layer_names: "a" + output_layer_names: "a" + sub_models { + name: "root" + layer_names: "a" + input_layer_names: "a" + output_layer_names: "a" + is_recurrent_layer_group: false + } +} +data_config { + type: "py2" + files: "train.list" + async_load_data: true + for_test: false + load_data_module: "a" + load_data_object: "c" + load_data_args: "" + data_ratio: 1 + is_main_data: true + usage_ratio: 1.0 +} +opt_config { + batch_size: 1000 + algorithm: "sgd" + learning_rate: 0.001 + learning_rate_decay_a: 0.0 + learning_rate_decay_b: 0.0 + l1weight: 0.1 + l2weight: 0.0 + c1: 0.0001 + backoff: 0.5 + owlqn_steps: 10 + max_backoff: 5 + l2weight_zero_iter: 0 + average_window: 0 + learning_method: "momentum" + ada_epsilon: 1e-06 + do_average_in_cpu: false + ada_rou: 0.95 + learning_rate_schedule: "poly" + delta_add_rate: 1.0 + shrink_parameter_value: 0 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_epsilon: 1e-08 + learning_rate_args: "" + async_lagged_grad_discard_ratio: 1.5 +} +test_data_config { + type: "py2" + files: "test.list" + async_load_data: true + for_test: true + load_data_module: "b" + load_data_object: "d" + load_data_args: "" + data_ratio: 1 + is_main_data: true + usage_ratio: 1.0 +} +save_dir: "./output/model" +start_pass: 0 + diff --git a/python/paddle/utils/dump_config.py b/python/paddle/utils/dump_config.py index d8a2722575..c5ce5c8d9a 100644 --- a/python/paddle/utils/dump_config.py +++ b/python/paddle/utils/dump_config.py @@ -19,13 +19,21 @@ import sys __all__ = [] if __name__ == '__main__': + whole_conf = False if len(sys.argv) == 2: conf = parse_config(sys.argv[1], '') elif len(sys.argv) == 3: conf = parse_config(sys.argv[1], sys.argv[2]) + elif len(sys.argv) == 4: + conf = parse_config(sys.argv[1], sys.argv[2]) + if sys.argv[3] == '--whole': + whole_conf = True else: raise RuntimeError() assert isinstance(conf, TrainerConfig_pb2.TrainerConfig) - print conf.model_config + if whole_conf: + print conf + else: + print conf.model_config -- GitLab