Commit 904a43b0 authored by QiJune, committed by GitHub

Merge pull request #440 from reyoung/feature/fix_pydataprovider_multiple_obj_bugs

Feature/fix pydataprovider multiple obj bugs
@@ -139,7 +139,7 @@ def define_py_data_sources(train_list, test_list, module, obj, args=None,
     test_obj = obj
     train_obj = obj
     if __is_splitable__(obj):
-        train_module, test_module = module
+        train_obj, test_obj = obj
     if args is None:
         args = ""
...
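The bug: when obj was given as a pair (one data-provider object for training, another for testing), the old line re-split module instead of obj, so the per-phase objects were silently dropped. A minimal self-contained sketch of the intended split semantics follows; the exact definition of __is_splitable__ is assumed here, not taken from the source:

# Sketch, not the library code: assume "splitable" means a 2-element list/tuple.
def __is_splitable__(x):
    return isinstance(x, (list, tuple)) and len(x) == 2

def split_sources(module, obj):
    train_module = test_module = module
    train_obj = test_obj = obj
    if __is_splitable__(module):
        train_module, test_module = module
    if __is_splitable__(obj):
        train_obj, test_obj = obj  # the fixed line: split obj, not module
    return (train_module, train_obj), (test_module, test_obj)

# module=["a", "b"], obj=("c", "d") -> training loads a.c, testing loads b.d
assert split_sources(["a", "b"], ("c", "d")) == (("a", "c"), ("b", "d"))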
@@ -13,9 +13,16 @@ img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cos
 test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight
 test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops)
+whole_configs=(test_split_datasource)
 for conf in ${configs[*]}
 do
     echo "Generating " $conf
     python -m paddle.utils.dump_config $conf.py > $protostr/$conf.protostr.unitest
 done
+
+for conf in ${whole_configs[*]}
+do
+    echo "Generating " $conf
+    python -m paddle.utils.dump_config $conf.py "" --whole > $protostr/$conf.protostr.unitest
+done
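The new --whole flag makes dump_config print the entire TrainerConfig rather than only model_config, which is why the expected output checked in for the new test (presumably test_split_datasource.protostr, shown next) carries data_config, opt_config, and test_data_config sections alongside the model: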
model_config {
type: "nn"
layers {
name: "a"
type: "data"
size: 10
active_type: ""
}
input_layer_names: "a"
output_layer_names: "a"
sub_models {
name: "root"
layer_names: "a"
input_layer_names: "a"
output_layer_names: "a"
is_recurrent_layer_group: false
}
}
data_config {
type: "py2"
files: "train.list"
async_load_data: true
for_test: false
load_data_module: "a"
load_data_object: "c"
load_data_args: ""
data_ratio: 1
is_main_data: true
usage_ratio: 1.0
}
opt_config {
batch_size: 1000
algorithm: "sgd"
learning_rate: 0.001
learning_rate_decay_a: 0.0
learning_rate_decay_b: 0.0
l1weight: 0.1
l2weight: 0.0
c1: 0.0001
backoff: 0.5
owlqn_steps: 10
max_backoff: 5
l2weight_zero_iter: 0
average_window: 0
learning_method: "momentum"
ada_epsilon: 1e-06
do_average_in_cpu: false
ada_rou: 0.95
learning_rate_schedule: "poly"
delta_add_rate: 1.0
shrink_parameter_value: 0
adam_beta1: 0.9
adam_beta2: 0.999
adam_epsilon: 1e-08
learning_rate_args: ""
async_lagged_grad_discard_ratio: 1.5
}
test_data_config {
type: "py2"
files: "test.list"
async_load_data: true
for_test: true
load_data_module: "b"
load_data_object: "d"
load_data_args: ""
data_ratio: 1
is_main_data: true
usage_ratio: 1.0
}
save_dir: "./output/model"
start_pass: 0
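The config that generates the expected output above (presumably test_split_datasource.py, given the script change) exercises the fixed split directly, passing one module/object pair for training and another for testing: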
from paddle.trainer_config_helpers import *
define_py_data_sources2(train_list="train.list",
                        test_list="test.list",
                        module=["a", "b"],
                        obj=("c", "d"))
settings(
    learning_rate=1e-3,
    batch_size=1000
)
outputs(data_layer(name="a", size=10))
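As a quick illustrative check (not library code), this is how the list/tuple arguments land in the dumped fields above:

# Field names refer to the protostr above; the unpacking mirrors the fix.
module = ["a", "b"]
obj = ("c", "d")
train_module, test_module = module
train_obj, test_obj = obj
assert (train_module, train_obj) == ("a", "c")  # -> data_config
assert (test_module, test_obj) == ("b", "d")    # -> test_data_config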
@@ -19,13 +19,21 @@ import sys
 __all__ = []

 if __name__ == '__main__':
+    whole_conf = False
     if len(sys.argv) == 2:
         conf = parse_config(sys.argv[1], '')
     elif len(sys.argv) == 3:
         conf = parse_config(sys.argv[1], sys.argv[2])
+    elif len(sys.argv) == 4:
+        conf = parse_config(sys.argv[1], sys.argv[2])
+        if sys.argv[3] == '--whole':
+            whole_conf = True
     else:
         raise RuntimeError()

     assert isinstance(conf, TrainerConfig_pb2.TrainerConfig)
-    print conf.model_config
+    if whole_conf:
+        print conf
+    else:
+        print conf.model_config
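Usage then mirrors the generation script above; the empty second argument is the config-args string that parse_config already accepted:

python -m paddle.utils.dump_config test_split_datasource.py              # prints model_config only
python -m paddle.utils.dump_config test_split_datasource.py "" --whole   # prints the whole TrainerConfig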