diff --git a/paddle/fluid/framework/fleet/box_wrapper.cc b/paddle/fluid/framework/fleet/box_wrapper.cc index 517296494954e74c03f6569f817334e87b4c2f84..3e1d4558b84a92c846e1c74a0b13e7805f900a47 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.cc +++ b/paddle/fluid/framework/fleet/box_wrapper.cc @@ -117,8 +117,8 @@ void BoxWrapper::BeginPass() const { "BeginPass failed in BoxPS.")); } -void BoxWrapper::EndPass() const { - int ret = boxps_ptr_->EndPass(); +void BoxWrapper::EndPass(bool need_save_delta) const { + int ret = boxps_ptr_->EndPass(need_save_delta); PADDLE_ENFORCE_EQ( ret, 0, platform::errors::PreconditionNotMet("EndPass failed in BoxPS.")); } diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index 2f49cbe0a6aee85e91e47842f8f3017466bfd3fe..8749011e61e1f29b483a82eafd9fcf8db84559d1 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -129,7 +129,7 @@ class BoxWrapper { void BeginFeedPass(int date, boxps::PSAgentBase** agent) const; void EndFeedPass(boxps::PSAgentBase* agent) const; void BeginPass() const; - void EndPass() const; + void EndPass(bool need_save_delta) const; void PullSparse(const paddle::platform::Place& place, const std::vector& keys, const std::vector& values, @@ -503,10 +503,10 @@ class BoxHelper { box_ptr->BeginPass(); #endif } - void EndPass() { + void EndPass(bool need_save_delta) { #ifdef PADDLE_WITH_BOX_PS auto box_ptr = BoxWrapper::GetInstance(); - box_ptr->EndPass(); + box_ptr->EndPass(need_save_delta); #endif } void LoadIntoMemory() { diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index f4c17fb7858ca0e37aa7265d603e0fb019995174..97900d02cb7c83bdf22026f33deef72dbce0f013 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -832,7 +832,7 @@ class BoxPSDataset(InMemoryDataset): """ self.boxps.begin_pass() - def end_pass(self): + def end_pass(self, need_save_delta): """ End Pass Notify BoxPS that current pass ended @@ -841,9 +841,9 @@ class BoxPSDataset(InMemoryDataset): import paddle.fluid as fluid dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset") - dataset.end_pass() + dataset.end_pass(True) """ - self.boxps.end_pass() + self.boxps.end_pass(need_save_delta) def wait_preload_done(self): """ diff --git a/python/paddle/fluid/tests/unittests/test_boxps.py b/python/paddle/fluid/tests/unittests/test_boxps.py index c914abbf23d9aa9b982fb40a529480f60ac684c3..563ccc2b8af4654dbb77fb62bacdd7c72f94a5ea 100644 --- a/python/paddle/fluid/tests/unittests/test_boxps.py +++ b/python/paddle/fluid/tests/unittests/test_boxps.py @@ -150,7 +150,7 @@ class TestBoxPSPreload(unittest.TestCase): program=fluid.default_main_program(), dataset=datasets[0], print_period=1) - datasets[0].end_pass() + datasets[0].end_pass(True) datasets[1].wait_preload_done() datasets[1].begin_pass() exe.train_from_dataset( @@ -158,7 +158,7 @@ class TestBoxPSPreload(unittest.TestCase): dataset=datasets[1], print_period=1, debug=True) - datasets[1].end_pass() + datasets[1].end_pass(False) for f in filelist: os.remove(f)