From 2f4c039e6218c68f6047c6ef8f1ba23431689e68 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 14 May 2018 21:36:34 +0800 Subject: [PATCH] rename, modify ckpt structure --- paddle/fluid/operators/checkpoint_save_op.cc | 34 ++++++------------- ..._op_test.cc => checkpoint_save_op_test.cc} | 2 +- .../fluid/transpiler/distribute_transpiler.py | 12 +++++++ 3 files changed, 24 insertions(+), 24 deletions(-) rename paddle/fluid/operators/{che'ck'po'in't_save_op_test.cc => checkpoint_save_op_test.cc} (96%) diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc index 2462ec09d6..1e621a00e5 100644 --- a/paddle/fluid/operators/checkpoint_save_op.cc +++ b/paddle/fluid/operators/checkpoint_save_op.cc @@ -68,19 +68,16 @@ class CheckpointSaveOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { - auto filename = Attr("file_path"); + auto dir = Attr("dir"); auto overwrite = Attr("overwrite"); - bool is_present = FileExists(filename); + bool is_present = FileExists(dir); if (is_present && !overwrite) { PADDLE_THROW("%s exists!, cannot save_combine to it when overwrite=false", - filename, overwrite); + dir, overwrite); } - MkDirRecursively(DirName(filename).c_str()); - std::ofstream fout(filename); - PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", - filename); + MkDirRecursively(dir.c_str()); auto inp_var_names = Inputs("X"); PADDLE_ENFORCE_GT(static_cast(inp_var_names.size()), 0, @@ -92,6 +89,10 @@ class CheckpointSaveOp : public framework::OperatorBase { for (size_t i = 0; i < inp_var_names.size(); i++) { auto *var = scope.FindVar(inp_var_names[i]); + std::string var_file; + var_file.append(dir); + var_file.append("/"); + var_file.append(inp_var_names[i]); PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_combine_op", @@ -103,23 +104,10 @@ class CheckpointSaveOp : public framework::OperatorBase { auto &tensor = var->Get(); // Serialize tensors one by one - // Check types to see if a fp16 transformation is required - auto in_dtype = framework::ToDataType(tensor.type()); - auto out_dtype = in_dtype; - - if (in_dtype != out_dtype) { - auto in_kernel_type = framework::OpKernelType(in_dtype, place); - auto out_kernel_type = framework::OpKernelType(out_dtype, place); - framework::LoDTensor out; - // copy LoD info to the new tensor - out.set_lod(tensor.lod()); - framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out); - framework::SerializeToStream(fout, out, dev_ctx); - } else { - framework::SerializeToStream(fout, tensor, dev_ctx); - } + std::ofstream fout(var_file); + framework::SerializeToStream(fout, tensor, dev_ctx); + fout.close(); } - fout.close(); } }; diff --git a/paddle/fluid/operators/che'ck'po'in't_save_op_test.cc b/paddle/fluid/operators/checkpoint_save_op_test.cc similarity index 96% rename from paddle/fluid/operators/che'ck'po'in't_save_op_test.cc rename to paddle/fluid/operators/checkpoint_save_op_test.cc index b49bbd1a58..7b5aa7bcde 100644 --- a/paddle/fluid/operators/che'ck'po'in't_save_op_test.cc +++ b/paddle/fluid/operators/checkpoint_save_op_test.cc @@ -38,7 +38,7 @@ TEST(CheckpointSaveOp, CPU) { } paddle::framework::AttributeMap attrs; - attrs.insert({"file_path", std::string("tensor.save")}); + attrs.insert({"dir", std::string("tensor/ckpt")}); auto save_op = paddle::framework::OpRegistry::CreateOp( "checkpoint_save", {{"X", {"test_var"}}}, {}, attrs); diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index b45cb987d8..b76f8de504 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -207,6 +207,11 @@ class DistributeTranspiler: self.pserver_endpoints = pserver_endpoints self.optimize_ops, params_grads = self._get_optimize_pass() + # is_chief (no.0 triner) for checkpoint + # the no.0 trainer will save all variables and its own reader offset to checkpoint + # other trianers will save its own reader offset to checkpoint + self.is_chief = trainer_id == 0 + # process lookup_table_op # 1. check all lookup_table_op is distributed # 2. check all lookup_table_op share the same table. @@ -309,6 +314,13 @@ class DistributeTranspiler: "epmap": eplist, "sync_mode": self.sync_mode }) + + program.global_block().append_op( + type="checkpoint_save", + inputs={"X": send_outputs}, + attrs={"overwrite": True, + "file_path": "/workspace/ckpt/"}) + # step4: Concat the parameters splits together after recv. for varname, splited_var in param_var_mapping.iteritems(): if len(splited_var) <= 1: -- GitLab