diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc index ca9d2091397a27dd61a5ec4c4c4b12e4f7f24784..e2e9565f60c7e60770ebfaedc4b7c9d7dd8a0134 100755 --- a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc +++ b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc @@ -32,7 +32,7 @@ namespace feed { _inference_model_base_done_path = fs->path_join(_inference_model_path, config["inference_base_done_name"].as("xbox_base_done.txt")); _inference_model_delta_done_path = fs->path_join(_inference_model_path, - config["inference_delta_done_name"].as("xbox_delta_done.txt")); + config["inference_delta_done_name"].as("xbox_patch_done.txt")); return 0; } @@ -149,7 +149,7 @@ namespace feed { if (is_last_epoch(epoch_id)) { return true; } - return delta_id(epoch_id) % 24 == 0; + return delta_id(epoch_id) % 78 == 0; case ModelSaveWay::ModelSaveTrainCheckpointBase: return is_last_epoch(epoch_id); } diff --git a/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc b/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc index 1ccfad1784d35803416daa2e6dd779c25c3ebe10..ae607ccb2d899af71cee57b61b67ff77f1b657a4 100644 --- a/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc +++ b/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc @@ -79,7 +79,7 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config, CHECK(accessor_ptr->collect_persistables_name(_persistables) == 0) << "collect_persistables Failed, class:" << accessor_class; } - // std::sort(_persistables.begin(), _persistables.end()); // 持久化变量名一定要排序 + std::sort(_persistables.begin(), _persistables.end()); // 持久化变量名一定要排序 // Monitor组件 for (const auto& monitor_config : _model_config["monitor"]) { diff --git a/paddle/fluid/train/custom_trainer/feed/scripts/create_programs.py b/paddle/fluid/train/custom_trainer/feed/scripts/create_programs.py index 728dbf1cf98341798068c8c94831563a6b0ec20a..d0f54fd75b1a19804a310f35282ca25ee62444fe 100644 --- a/paddle/fluid/train/custom_trainer/feed/scripts/create_programs.py +++ b/paddle/fluid/train/custom_trainer/feed/scripts/create_programs.py @@ -131,6 +131,11 @@ class ModelBuilder: main_program=test_program, model_filename='inference_program', program_only=True) + with open(os.path.join(self._save_path, 'inference_program'), "rb") as f: + program_desc_str = f.read() + infer_program = fluid.Program.parse_from_string(program_desc_str) + with open(os.path.join(self._save_path, 'inference_program.pbtxt'), 'w') as fout: + fout.write(str(infer_program)) params = filter(fluid.io.is_parameter, main_program.list_vars()) vars = []