未验证 提交 0c30098f 编写于 作者: H hutuxian 提交者: GitHub

Add need_save_delta parameter to solve OOM (#23097)

上级 2e2da712
@@ -117,8 +117,8 @@ void BoxWrapper::BeginPass() const {
                         "BeginPass failed in BoxPS."));
 }
-void BoxWrapper::EndPass() const {
-  int ret = boxps_ptr_->EndPass();
+void BoxWrapper::EndPass(bool need_save_delta) const {
+  int ret = boxps_ptr_->EndPass(need_save_delta);
   PADDLE_ENFORCE_EQ(
       ret, 0, platform::errors::PreconditionNotMet("EndPass failed in BoxPS."));
 }
......
@@ -129,7 +129,7 @@ class BoxWrapper {
   void BeginFeedPass(int date, boxps::PSAgentBase** agent) const;
   void EndFeedPass(boxps::PSAgentBase* agent) const;
   void BeginPass() const;
-  void EndPass() const;
+  void EndPass(bool need_save_delta) const;
   void PullSparse(const paddle::platform::Place& place,
                   const std::vector<const uint64_t*>& keys,
                   const std::vector<float*>& values,
@@ -503,10 +503,10 @@ class BoxHelper {
     box_ptr->BeginPass();
 #endif
   }
-  void EndPass() {
+  void EndPass(bool need_save_delta) {
 #ifdef PADDLE_WITH_BOX_PS
     auto box_ptr = BoxWrapper::GetInstance();
-    box_ptr->EndPass();
+    box_ptr->EndPass(need_save_delta);
 #endif
   }
   void LoadIntoMemory() {
......
@@ -832,7 +832,7 @@ class BoxPSDataset(InMemoryDataset):
         """
         self.boxps.begin_pass()
-    def end_pass(self):
+    def end_pass(self, need_save_delta):
         """
         End Pass
         Notify BoxPS that current pass ended
@@ -841,9 +841,9 @@ class BoxPSDataset(InMemoryDataset):
               import paddle.fluid as fluid
               dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
-              dataset.end_pass()
+              dataset.end_pass(True)
         """
-        self.boxps.end_pass()
+        self.boxps.end_pass(need_save_delta)
     def wait_preload_done(self):
         """
......
@@ -150,7 +150,7 @@ class TestBoxPSPreload(unittest.TestCase):
             program=fluid.default_main_program(),
             dataset=datasets[0],
             print_period=1)
-        datasets[0].end_pass()
+        datasets[0].end_pass(True)
         datasets[1].wait_preload_done()
         datasets[1].begin_pass()
         exe.train_from_dataset(
@@ -158,7 +158,7 @@ class TestBoxPSPreload(unittest.TestCase):
             dataset=datasets[1],
             print_period=1,
             debug=True)
-        datasets[1].end_pass()
+        datasets[1].end_pass(False)
         for f in filelist:
             os.remove(f)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册