未验证 提交 37fcf03a 编写于 作者: Z zhongpu 提交者: GitHub

Op (Save/Load) error message enhancement, test=develop (#23650)

上级 c6d14bc8
...@@ -34,26 +34,29 @@ class LoadOpKernel : public framework::OpKernel<T> { ...@@ -34,26 +34,29 @@ class LoadOpKernel : public framework::OpKernel<T> {
// it to save an output stream. // it to save an output stream.
auto filename = ctx.Attr<std::string>("file_path"); auto filename = ctx.Attr<std::string>("file_path");
std::ifstream fin(filename, std::ios::binary); std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op", PADDLE_ENFORCE_EQ(static_cast<bool>(fin), true,
filename); platform::errors::Unavailable(
"Load operator fail to open file %s, please check "
"whether the model file is complete or damaged.",
filename));
auto out_var_name = ctx.OutputNames("Out").data(); auto out_var_name = ctx.OutputNames("Out").data();
auto *out_var = ctx.OutputVar("Out"); auto *out_var = ctx.OutputVar("Out");
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found ", PADDLE_ENFORCE_NOT_NULL(
out_var_name); out_var,
platform::errors::InvalidArgument(
PADDLE_ENFORCE(out_var != nullptr, "Output variable cannot be found "); "The variable %s to be loaded cannot be found.", out_var_name));
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
LoadLodTensor(fin, place, out_var, ctx); LoadLodTensor(fin, place, out_var, ctx);
} else if (out_var->IsType<framework::SelectedRows>()) { } else if (out_var->IsType<framework::SelectedRows>()) {
LoadSelectedRows(fin, place, out_var); LoadSelectedRows(fin, place, out_var);
} else { } else {
PADDLE_ENFORCE( PADDLE_THROW(platform::errors::InvalidArgument(
false, "Load operator only supports loading LoDTensor and SelectedRows "
"Load only support LoDTensor and SelectedRows, %s has wrong type", "variable, %s has wrong type",
out_var_name); out_var_name));
} }
} }
......
...@@ -41,18 +41,19 @@ class SaveOpKernel : public framework::OpKernel<T> { ...@@ -41,18 +41,19 @@ class SaveOpKernel : public framework::OpKernel<T> {
auto *input_var = ctx.InputVar("X"); auto *input_var = ctx.InputVar("X");
auto iname = ctx.InputNames("X").data(); auto iname = ctx.InputNames("X").data();
PADDLE_ENFORCE(input_var != nullptr, "Cannot find variable %s for save_op", PADDLE_ENFORCE_NOT_NULL(
iname); input_var, platform::errors::InvalidArgument(
"The variable %s to be saved cannot be found.", iname));
if (input_var->IsType<framework::LoDTensor>()) { if (input_var->IsType<framework::LoDTensor>()) {
SaveLodTensor(ctx, place, input_var); SaveLodTensor(ctx, place, input_var);
} else if (input_var->IsType<framework::SelectedRows>()) { } else if (input_var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(ctx, place, input_var); SaveSelectedRows(ctx, place, input_var);
} else { } else {
PADDLE_ENFORCE( PADDLE_THROW(platform::errors::InvalidArgument(
false, "Save operator only supports saving LoDTensor and SelectedRows "
"SaveOp only support LoDTensor and SelectedRows, %s has wrong type", "variable, %s has wrong type",
iname); iname));
} }
} }
...@@ -62,10 +63,11 @@ class SaveOpKernel : public framework::OpKernel<T> { ...@@ -62,10 +63,11 @@ class SaveOpKernel : public framework::OpKernel<T> {
auto filename = ctx.Attr<std::string>("file_path"); auto filename = ctx.Attr<std::string>("file_path");
auto overwrite = ctx.Attr<bool>("overwrite"); auto overwrite = ctx.Attr<bool>("overwrite");
if (FileExists(filename) && !overwrite) { PADDLE_ENFORCE_EQ(
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", FileExists(filename) && !overwrite, false,
filename, overwrite); platform::errors::PreconditionNotMet(
} "%s exists!, cannot save to it when overwrite is set to false.",
filename, overwrite));
MkDirRecursively(DirName(filename).c_str()); MkDirRecursively(DirName(filename).c_str());
...@@ -78,8 +80,9 @@ class SaveOpKernel : public framework::OpKernel<T> { ...@@ -78,8 +80,9 @@ class SaveOpKernel : public framework::OpKernel<T> {
// FIXME(yuyang18): We save variable to local file now, but we should change // FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream. // it to save an output stream.
std::ofstream fout(filename, std::ios::binary); std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write", PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
filename); platform::errors::Unavailable(
"Cannot open %s to save variables.", filename));
auto save_as_fp16 = ctx.Attr<bool>("save_as_fp16"); auto save_as_fp16 = ctx.Attr<bool>("save_as_fp16");
auto in_dtype = tensor.type(); auto in_dtype = tensor.type();
...@@ -117,10 +120,11 @@ class SaveOpKernel : public framework::OpKernel<T> { ...@@ -117,10 +120,11 @@ class SaveOpKernel : public framework::OpKernel<T> {
} }
} }
if (FileExists(filename) && !overwrite) { PADDLE_ENFORCE_EQ(
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", FileExists(filename) && !overwrite, false,
filename, overwrite); platform::errors::PreconditionNotMet(
} "%s exists!, cannot save to it when overwrite is set to false.",
filename, overwrite));
VLOG(4) << "SaveSelectedRows get File name: " << filename; VLOG(4) << "SaveSelectedRows get File name: " << filename;
...@@ -135,8 +139,9 @@ class SaveOpKernel : public framework::OpKernel<T> { ...@@ -135,8 +139,9 @@ class SaveOpKernel : public framework::OpKernel<T> {
// FIXME(yuyang18): We save variable to local file now, but we should change // FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream. // it to save an output stream.
std::ofstream fout(filename, std::ios::binary); std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write", PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
filename); platform::errors::Unavailable(
"Cannot open %s to save variables.", filename));
framework::SerializeToStream(fout, selectedRows, dev_ctx); framework::SerializeToStream(fout, selectedRows, dev_ctx);
fout.close(); fout.close();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册