提交 620698e7 编写于 作者: T tangwei12

bug fux

上级 1296d96e
...@@ -69,15 +69,6 @@ class SaveOp : public framework::OperatorBase { ...@@ -69,15 +69,6 @@ class SaveOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite");
if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
filename, overwrite);
}
MkDirRecursively(DirName(filename).c_str());
auto iname = Input("X"); auto iname = Input("X");
auto *var = scope.FindVar(iname); auto *var = scope.FindVar(iname);
...@@ -85,7 +76,7 @@ class SaveOp : public framework::OperatorBase { ...@@ -85,7 +76,7 @@ class SaveOp : public framework::OperatorBase {
iname); iname);
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
SaveLodTensor(filename, place, var); SaveLodTensor(place, var);
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(scope, place, var); SaveSelectedRows(scope, place, var);
} else { } else {
...@@ -96,8 +87,18 @@ class SaveOp : public framework::OperatorBase { ...@@ -96,8 +87,18 @@ class SaveOp : public framework::OperatorBase {
} }
} }
void SaveLodTensor(const std::string &filename, const platform::Place &place, void SaveLodTensor( const platform::Place &place,
framework::Variable *var) const { framework::Variable *var) const {
auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite");
if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
filename, overwrite);
}
MkDirRecursively(DirName(filename).c_str());
auto &tensor = var->Get<framework::LoDTensor>(); auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool // get device context from pool
......
...@@ -503,7 +503,7 @@ def save_checkpoint(executor, ...@@ -503,7 +503,7 @@ def save_checkpoint(executor,
if trainer_id == 0: if trainer_id == 0:
save_persist_vars_without_grad(executor, cur_dir, main_program) save_persist_vars_without_grad(executor, cur_dir, main_program)
save_pserver_vars_by_notify(executor, cur_dir, ps_endpoint_list, lookup_table) save_pserver_vars_by_notify(executor, cur_dir, lookup_table, ps_endpoint_list)
_scroll_delete(checkpoint_dir, max_num_checkpoints) _scroll_delete(checkpoint_dir, max_num_checkpoints)
......
...@@ -846,7 +846,7 @@ class DistributeTranspiler: ...@@ -846,7 +846,7 @@ class DistributeTranspiler:
inputs={'X': [self.table_name]}, inputs={'X': [self.table_name]},
outputs={}, outputs={},
attrs={ attrs={
'file_path': self.table_name) 'file_path': self.table_name
}) })
return checkpoint_save_block.idx return checkpoint_save_block.idx
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册