提交 620698e7 编写于 作者: T tangwei12

bug fux

上级 1296d96e
......@@ -69,15 +69,6 @@ class SaveOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite");
if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
filename, overwrite);
}
MkDirRecursively(DirName(filename).c_str());
auto iname = Input("X");
auto *var = scope.FindVar(iname);
......@@ -85,7 +76,7 @@ class SaveOp : public framework::OperatorBase {
iname);
if (var->IsType<framework::LoDTensor>()) {
SaveLodTensor(filename, place, var);
SaveLodTensor(place, var);
} else if (var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(scope, place, var);
} else {
......@@ -96,8 +87,18 @@ class SaveOp : public framework::OperatorBase {
}
}
void SaveLodTensor(const std::string &filename, const platform::Place &place,
void SaveLodTensor( const platform::Place &place,
framework::Variable *var) const {
auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite");
if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
filename, overwrite);
}
MkDirRecursively(DirName(filename).c_str());
auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool
......
......@@ -503,7 +503,7 @@ def save_checkpoint(executor,
if trainer_id == 0:
save_persist_vars_without_grad(executor, cur_dir, main_program)
save_pserver_vars_by_notify(executor, cur_dir, ps_endpoint_list, lookup_table)
save_pserver_vars_by_notify(executor, cur_dir, lookup_table, ps_endpoint_list)
_scroll_delete(checkpoint_dir, max_num_checkpoints)
......
......@@ -846,7 +846,7 @@ class DistributeTranspiler:
inputs={'X': [self.table_name]},
outputs={},
attrs={
'file_path': self.table_name)
'file_path': self.table_name
})
return checkpoint_save_block.idx
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册