diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 13798c88b1856a6ca0ed3cdb1c4a7bf8d82f4a09..7a0b566ea87c6a3cf2dfbd0b082c53f65eeb042a 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -69,15 +69,6 @@ class SaveOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { - auto filename = Attr("file_path"); - auto overwrite = Attr("overwrite"); - - if (FileExists(filename) && !overwrite) { - PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", - filename, overwrite); - } - - MkDirRecursively(DirName(filename).c_str()); auto iname = Input("X"); auto *var = scope.FindVar(iname); @@ -85,7 +76,7 @@ class SaveOp : public framework::OperatorBase { iname); if (var->IsType()) { - SaveLodTensor(filename, place, var); + SaveLodTensor(place, var); } else if (var->IsType()) { SaveSelectedRows(scope, place, var); } else { @@ -96,8 +87,18 @@ class SaveOp : public framework::OperatorBase { } } - void SaveLodTensor(const std::string &filename, const platform::Place &place, + void SaveLodTensor( const platform::Place &place, framework::Variable *var) const { + auto filename = Attr("file_path"); + auto overwrite = Attr("overwrite"); + + if (FileExists(filename) && !overwrite) { + PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", + filename, overwrite); + } + + MkDirRecursively(DirName(filename).c_str()); + auto &tensor = var->Get(); // get device context from pool diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index ce82b6b904b0aebaea950db9422e61a2cc603d5c..ffe0021e96c3c09e0eba93cb38ee872c80bb1ae5 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -503,7 +503,7 @@ def save_checkpoint(executor, if trainer_id == 0: save_persist_vars_without_grad(executor, cur_dir, main_program) - save_pserver_vars_by_notify(executor, cur_dir, ps_endpoint_list, lookup_table) + save_pserver_vars_by_notify(executor, cur_dir, lookup_table, ps_endpoint_list) _scroll_delete(checkpoint_dir, max_num_checkpoints) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index f9c39262ce32ee97ea1ce5e14420f5d67ad5b120..a1617600d6260f50e8d47f7e8eeb4f30b0bda953 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -846,7 +846,7 @@ class DistributeTranspiler: inputs={'X': [self.table_name]}, outputs={}, attrs={ - 'file_path': self.table_name) + 'file_path': self.table_name }) return checkpoint_save_block.idx