未验证 提交 37fcf03a 编写于 作者: Z zhongpu 提交者: GitHub

Op (Save/Load) error message enhancement, test=develop (#23650)

上级 c6d14bc8
......@@ -34,26 +34,29 @@ class LoadOpKernel : public framework::OpKernel<T> {
// it to save an output stream.
auto filename = ctx.Attr<std::string>("file_path");
std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
filename);
PADDLE_ENFORCE_EQ(static_cast<bool>(fin), true,
platform::errors::Unavailable(
"Load operator fail to open file %s, please check "
"whether the model file is complete or damaged.",
filename));
auto out_var_name = ctx.OutputNames("Out").data();
auto *out_var = ctx.OutputVar("Out");
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found ",
out_var_name);
PADDLE_ENFORCE(out_var != nullptr, "Output variable cannot be found ");
PADDLE_ENFORCE_NOT_NULL(
out_var,
platform::errors::InvalidArgument(
"The variable %s to be loaded cannot be found.", out_var_name));
if (out_var->IsType<framework::LoDTensor>()) {
LoadLodTensor(fin, place, out_var, ctx);
} else if (out_var->IsType<framework::SelectedRows>()) {
LoadSelectedRows(fin, place, out_var);
} else {
PADDLE_ENFORCE(
false,
"Load only support LoDTensor and SelectedRows, %s has wrong type",
out_var_name);
PADDLE_THROW(platform::errors::InvalidArgument(
"Load operator only supports loading LoDTensor and SelectedRows "
"variable, %s has wrong type",
out_var_name));
}
}
......
......@@ -41,18 +41,19 @@ class SaveOpKernel : public framework::OpKernel<T> {
auto *input_var = ctx.InputVar("X");
auto iname = ctx.InputNames("X").data();
PADDLE_ENFORCE(input_var != nullptr, "Cannot find variable %s for save_op",
iname);
PADDLE_ENFORCE_NOT_NULL(
input_var, platform::errors::InvalidArgument(
"The variable %s to be saved cannot be found.", iname));
if (input_var->IsType<framework::LoDTensor>()) {
SaveLodTensor(ctx, place, input_var);
} else if (input_var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(ctx, place, input_var);
} else {
PADDLE_ENFORCE(
false,
"SaveOp only support LoDTensor and SelectedRows, %s has wrong type",
iname);
PADDLE_THROW(platform::errors::InvalidArgument(
"Save operator only supports saving LoDTensor and SelectedRows "
"variable, %s has wrong type",
iname));
}
}
......@@ -62,10 +63,11 @@ class SaveOpKernel : public framework::OpKernel<T> {
auto filename = ctx.Attr<std::string>("file_path");
auto overwrite = ctx.Attr<bool>("overwrite");
if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
filename, overwrite);
}
PADDLE_ENFORCE_EQ(
FileExists(filename) && !overwrite, false,
platform::errors::PreconditionNotMet(
"%s exists!, cannot save to it when overwrite is set to false.",
filename, overwrite));
MkDirRecursively(DirName(filename).c_str());
......@@ -78,8 +80,9 @@ class SaveOpKernel : public framework::OpKernel<T> {
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
platform::errors::Unavailable(
"Cannot open %s to save variables.", filename));
auto save_as_fp16 = ctx.Attr<bool>("save_as_fp16");
auto in_dtype = tensor.type();
......@@ -117,10 +120,11 @@ class SaveOpKernel : public framework::OpKernel<T> {
}
}
if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
filename, overwrite);
}
PADDLE_ENFORCE_EQ(
FileExists(filename) && !overwrite, false,
platform::errors::PreconditionNotMet(
"%s exists!, cannot save to it when overwrite is set to false.",
filename, overwrite));
VLOG(4) << "SaveSelectedRows get File name: " << filename;
......@@ -135,8 +139,9 @@ class SaveOpKernel : public framework::OpKernel<T> {
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
platform::errors::Unavailable(
"Cannot open %s to save variables.", filename));
framework::SerializeToStream(fout, selectedRows, dev_ctx);
fout.close();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册