diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py index 839e3ed4d6e04b13f69e6c2cfc463e83aef130f7..0a1cec8421448223c2147d53635bbb22f8758e36 100644 --- a/core/trainers/framework/runner.py +++ b/core/trainers/framework/runner.py @@ -209,12 +209,20 @@ class RunnerBase(object): if save_step_interval >= 1 and batch_id % save_step_interval == 0 and context[ "is_infer"] == False: - if context["fleet_mode"].upper() == "PS": - train_prog = context["model"][model_dict["name"]][ - "main_program"] + if context["is_fleet"]: + if context["fleet_mode"].upper() == "PS": + train_prog = context["model"][model_dict[ + "name"]]["main_program"] + print("condition 1 of bath id:{}".format( + batch_id)) + else: + train_prog = context["model"][model_dict[ + "name"]]["default_main_program"] + print("condition 2") else: train_prog = context["model"][model_dict["name"]][ "default_main_program"] + print("condition 3") startup_prog = context["model"][model_dict["name"]][ "startup_program"] with fluid.program_guard(train_prog, startup_prog):