未验证 提交 0c30098f 编写于 作者: H hutuxian 提交者: GitHub

Add need_save_delta parameter to solve OOM (#23097)

上级 2e2da712
......@@ -117,8 +117,8 @@ void BoxWrapper::BeginPass() const {
"BeginPass failed in BoxPS."));
}
void BoxWrapper::EndPass() const {
int ret = boxps_ptr_->EndPass();
void BoxWrapper::EndPass(bool need_save_delta) const {
int ret = boxps_ptr_->EndPass(need_save_delta);
PADDLE_ENFORCE_EQ(
ret, 0, platform::errors::PreconditionNotMet("EndPass failed in BoxPS."));
}
......
......@@ -129,7 +129,7 @@ class BoxWrapper {
void BeginFeedPass(int date, boxps::PSAgentBase** agent) const;
void EndFeedPass(boxps::PSAgentBase* agent) const;
void BeginPass() const;
void EndPass() const;
void EndPass(bool need_save_delta) const;
void PullSparse(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<float*>& values,
......@@ -503,10 +503,10 @@ class BoxHelper {
box_ptr->BeginPass();
#endif
}
void EndPass() {
void EndPass(bool need_save_delta) {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = BoxWrapper::GetInstance();
box_ptr->EndPass();
box_ptr->EndPass(need_save_delta);
#endif
}
void LoadIntoMemory() {
......
......@@ -832,7 +832,7 @@ class BoxPSDataset(InMemoryDataset):
"""
self.boxps.begin_pass()
def end_pass(self):
def end_pass(self, need_save_delta):
"""
End Pass
Notify BoxPS that current pass ended
......@@ -841,9 +841,9 @@ class BoxPSDataset(InMemoryDataset):
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
dataset.end_pass()
dataset.end_pass(True)
"""
self.boxps.end_pass()
self.boxps.end_pass(need_save_delta)
def wait_preload_done(self):
"""
......
......@@ -150,7 +150,7 @@ class TestBoxPSPreload(unittest.TestCase):
program=fluid.default_main_program(),
dataset=datasets[0],
print_period=1)
datasets[0].end_pass()
datasets[0].end_pass(True)
datasets[1].wait_preload_done()
datasets[1].begin_pass()
exe.train_from_dataset(
......@@ -158,7 +158,7 @@ class TestBoxPSPreload(unittest.TestCase):
dataset=datasets[1],
print_period=1,
debug=True)
datasets[1].end_pass()
datasets[1].end_pass(False)
for f in filelist:
os.remove(f)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册