diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 13dfe45bb292e4509fdb20176ea4154e26157571..698ff22997891bc31eae628fe8eaf2ce71238544 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -99,7 +99,8 @@ static int64_t GetTimestamp() { void ListenAndServOp::RunSyncLoop( framework::Executor *executor, framework::ProgramDesc *program, framework::Scope *recv_scope, - const std::vector &prefetch_block_id_list) const { + const std::vector &prefetch_block_id_list, + const int checkpoint_point_block_id) const { size_t num_blocks = program->Size(); PADDLE_ENFORCE_GE(num_blocks, 2, "server program should have at least 2 blocks"); @@ -107,7 +108,8 @@ void ListenAndServOp::RunSyncLoop( std::vector optimize_block_id_list; for (int blkid = 1; blkid < num_blocks; ++blkid) { if (std::find(prefetch_block_id_list.begin(), prefetch_block_id_list.end(), - blkid) == prefetch_block_id_list.end()) { + blkid) == prefetch_block_id_list.end() && + blkid != checkpoint_point_block_id) { optimize_block_id_list.push_back(blkid); } } diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index b00ad195e9e162a4c911af64dd19bc7cc0ef3775..ca2dafb737a826cd1ac98987a2bf6b2cc910a03e 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -48,7 +48,8 @@ class ListenAndServOp : public framework::OperatorBase { void RunSyncLoop(framework::Executor* executor, framework::ProgramDesc* program, framework::Scope* recv_scope, - const std::vector& prefetch_block_id_list) const; + const std::vector& prefetch_block_id_list, + const int checkpoint_point_block_id) const; void RunAsyncLoop(framework::Executor* executor, framework::ProgramDesc* program) const;