提交 2f4c039e 编写于 作者: T tangwei12

rename, modify ckpt structure

上级 461d2fc0
......@@ -68,19 +68,16 @@ class CheckpointSaveOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path");
auto dir = Attr<std::string>("dir");
auto overwrite = Attr<bool>("overwrite");
bool is_present = FileExists(filename);
bool is_present = FileExists(dir);
if (is_present && !overwrite) {
PADDLE_THROW("%s exists!, cannot save_combine to it when overwrite=false",
filename, overwrite);
dir, overwrite);
}
MkDirRecursively(DirName(filename).c_str());
std::ofstream fout(filename);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
MkDirRecursively(dir.c_str());
auto inp_var_names = Inputs("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
......@@ -92,6 +89,10 @@ class CheckpointSaveOp : public framework::OperatorBase {
for (size_t i = 0; i < inp_var_names.size(); i++) {
auto *var = scope.FindVar(inp_var_names[i]);
std::string var_file;
var_file.append(dir);
var_file.append("/");
var_file.append(inp_var_names[i]);
PADDLE_ENFORCE(var != nullptr,
"Cannot find variable %s for save_combine_op",
......@@ -103,23 +104,10 @@ class CheckpointSaveOp : public framework::OperatorBase {
auto &tensor = var->Get<framework::LoDTensor>();
// Serialize tensors one by one
// Check types to see if a fp16 transformation is required
auto in_dtype = framework::ToDataType(tensor.type());
auto out_dtype = in_dtype;
if (in_dtype != out_dtype) {
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor out;
// copy LoD info to the new tensor
out.set_lod(tensor.lod());
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
framework::SerializeToStream(fout, out, dev_ctx);
} else {
framework::SerializeToStream(fout, tensor, dev_ctx);
}
std::ofstream fout(var_file);
framework::SerializeToStream(fout, tensor, dev_ctx);
fout.close();
}
fout.close();
}
};
......
......@@ -38,7 +38,7 @@ TEST(CheckpointSaveOp, CPU) {
}
paddle::framework::AttributeMap attrs;
attrs.insert({"file_path", std::string("tensor.save")});
attrs.insert({"dir", std::string("tensor/ckpt")});
auto save_op = paddle::framework::OpRegistry::CreateOp(
"checkpoint_save", {{"X", {"test_var"}}}, {}, attrs);
......
......@@ -207,6 +207,11 @@ class DistributeTranspiler:
self.pserver_endpoints = pserver_endpoints
self.optimize_ops, params_grads = self._get_optimize_pass()
# is_chief (No. 0 trainer) for checkpoint
# the No. 0 trainer will save all variables and its own reader offset to checkpoint
# other trainers will save their own reader offsets to checkpoint
self.is_chief = trainer_id == 0
# process lookup_table_op
# 1. check all lookup_table_op is distributed
# 2. check all lookup_table_op share the same table.
......@@ -309,6 +314,13 @@ class DistributeTranspiler:
"epmap": eplist,
"sync_mode": self.sync_mode
})
program.global_block().append_op(
type="checkpoint_save",
inputs={"X": send_outputs},
attrs={"overwrite": True,
"file_path": "/workspace/ckpt/"})
# step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册