From 620698e7e6f37188ba5bbd6851933a558c97f10b Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jun 2018 09:41:15 +0800 Subject: [PATCH] bug fux --- paddle/fluid/operators/save_op.cc | 23 ++++++++++--------- python/paddle/fluid/io.py | 2 +- .../fluid/transpiler/distribute_transpiler.py | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 13798c88b1..7a0b566ea8 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -69,15 +69,6 @@ class SaveOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { - auto filename = Attr("file_path"); - auto overwrite = Attr("overwrite"); - - if (FileExists(filename) && !overwrite) { - PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", - filename, overwrite); - } - - MkDirRecursively(DirName(filename).c_str()); auto iname = Input("X"); auto *var = scope.FindVar(iname); @@ -85,7 +76,7 @@ class SaveOp : public framework::OperatorBase { iname); if (var->IsType()) { - SaveLodTensor(filename, place, var); + SaveLodTensor(place, var); } else if (var->IsType()) { SaveSelectedRows(scope, place, var); } else { @@ -96,8 +87,18 @@ class SaveOp : public framework::OperatorBase { } } - void SaveLodTensor(const std::string &filename, const platform::Place &place, + void SaveLodTensor( const platform::Place &place, framework::Variable *var) const { + auto filename = Attr("file_path"); + auto overwrite = Attr("overwrite"); + + if (FileExists(filename) && !overwrite) { + PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", + filename, overwrite); + } + + MkDirRecursively(DirName(filename).c_str()); + auto &tensor = var->Get(); // get device context from pool diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index ce82b6b904..ffe0021e96 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -503,7 +503,7 @@ def save_checkpoint(executor, if trainer_id == 0: save_persist_vars_without_grad(executor, cur_dir, main_program) - save_pserver_vars_by_notify(executor, cur_dir, ps_endpoint_list, lookup_table) + save_pserver_vars_by_notify(executor, cur_dir, lookup_table, ps_endpoint_list) _scroll_delete(checkpoint_dir, max_num_checkpoints) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index f9c39262ce..a1617600d6 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -846,7 +846,7 @@ class DistributeTranspiler: inputs={'X': [self.table_name]}, outputs={}, attrs={ - 'file_path': self.table_name) + 'file_path': self.table_name }) return checkpoint_save_block.idx -- GitLab