diff --git a/tools/train.py b/tools/train.py index 149490123599fd194c662c90937736444c00f150..0f5e9039cfd4cc23e418434323b74c6612587ed2 100755 --- a/tools/train.py +++ b/tools/train.py @@ -92,7 +92,7 @@ def main(): 'fetch_name_list':eval_fetch_name_list,\ 'fetch_varname_list':eval_fetch_varname_list} - if contain_det: + if train_alg_type == 'det': program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict) else: program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict) @@ -117,6 +117,6 @@ def test_reader(): if __name__ == '__main__': - startup_program, train_program, place, config, contain_det = program.preprocess() + startup_program, train_program, place, config, train_alg_type = program.preprocess() main() # test_reader()