未验证 提交 0b6f09e7 编写于 作者: C Chen Weihang 提交者: GitHub

Op (Save/LoadCombine) error message enhancement (#23647)

* op save/load_combine error msg polish, test=develop

* fix detail error, test=develop
上级 b61aaa2c
......@@ -35,21 +35,27 @@ class LoadCombineOpKernel : public framework::OpKernel<T> {
auto model_from_memory = ctx.Attr<bool>("model_from_memory");
auto out_var_names = ctx.OutputNames("Out");
PADDLE_ENFORCE_GT(
static_cast<int>(out_var_names.size()), 0,
"The number of output variables should be greater than 0.");
PADDLE_ENFORCE_GT(out_var_names.size(), 0UL,
platform::errors::InvalidArgument(
"The number of variables to be loaded is %d, expect "
"it to be greater than 0.",
out_var_names.size()));
if (!model_from_memory) {
std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin),
"OP(LoadCombine) fail to open file %s, please check "
"whether the model file is complete or damaged.",
filename);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fin), true,
platform::errors::Unavailable(
"LoadCombine operator fails to open file %s, please check "
"whether the model file is complete or damaged.",
filename));
LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
} else {
PADDLE_ENFORCE(!filename.empty(),
"OP(LoadCombine) fail to open file %s, please check "
"whether the model file is complete or damaged.",
filename);
PADDLE_ENFORCE_NE(
filename.empty(), true,
platform::errors::Unavailable(
"LoadCombine operator fails to open file %s, please check "
"whether the model file is complete or damaged.",
filename));
std::stringstream fin(filename, std::ios::in | std::ios::binary);
LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
}
......@@ -64,16 +70,19 @@ class LoadCombineOpKernel : public framework::OpKernel<T> {
auto out_vars = context.MultiOutputVar("Out");
for (size_t i = 0; i < out_var_names.size(); i++) {
PADDLE_ENFORCE(out_vars[i] != nullptr,
"Output variable %s cannot be found", out_var_names[i]);
PADDLE_ENFORCE_NOT_NULL(
out_vars[i], platform::errors::InvalidArgument(
"The variable %s to be loaded cannot be found.",
out_var_names[i]));
auto *tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
// Error checking
PADDLE_ENFORCE(
static_cast<bool>(*buffer),
"There is a problem with loading model parameters. "
"Please check whether the model file is complete or damaged.");
PADDLE_ENFORCE_EQ(
static_cast<bool>(*buffer), true,
platform::errors::Unavailable(
"An error occurred while loading model parameters. "
"Please check whether the model file is complete or damaged."));
// Get data from fin to tensor
DeserializeFromStream(*buffer, tensor, dev_ctx);
......@@ -100,9 +109,10 @@ class LoadCombineOpKernel : public framework::OpKernel<T> {
}
}
buffer->peek();
PADDLE_ENFORCE(buffer->eof(),
"You are not allowed to load partial data via "
"load_combine_op, use load_op instead.");
PADDLE_ENFORCE_EQ(buffer->eof(), true,
platform::errors::Unavailable(
"Not allowed to load partial data via "
"load_combine_op, please use load_op instead."));
}
};
......
......@@ -41,31 +41,40 @@ class SaveCombineOpKernel : public framework::OpKernel<T> {
bool is_present = FileExists(filename);
if (is_present && !overwrite) {
PADDLE_THROW("%s exists!, cannot save_combine to it when overwrite=false",
filename, overwrite);
PADDLE_THROW(platform::errors::PreconditionNotMet(
"%s exists! Cannot save_combine to it when overwrite is set to "
"false.",
filename, overwrite));
}
MkDirRecursively(DirName(filename).c_str());
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
platform::errors::Unavailable(
"Cannot open %s to save variables.", filename));
auto inp_var_names = ctx.InputNames("X");
auto &inp_vars = ctx.MultiInputVar("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
"The number of input variables should be greater than 0");
PADDLE_ENFORCE_GT(inp_var_names.size(), 0UL,
platform::errors::InvalidArgument(
"The number of variables to be saved is %d, expect "
"it to be greater than 0.",
inp_var_names.size()));
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
for (size_t i = 0; i < inp_var_names.size(); i++) {
PADDLE_ENFORCE(inp_vars[i] != nullptr,
"Cannot find variable %s for save_combine_op",
inp_var_names[i]);
PADDLE_ENFORCE(inp_vars[i]->IsType<framework::LoDTensor>(),
"SaveCombineOp only supports LoDTensor, %s has wrong type",
inp_var_names[i]);
PADDLE_ENFORCE_NOT_NULL(
inp_vars[i],
platform::errors::InvalidArgument("Cannot find variable %s to save.",
inp_var_names[i]));
PADDLE_ENFORCE_EQ(inp_vars[i]->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"SaveCombine operator only supports saving "
"LoDTensor variable, %s has wrong type.",
inp_var_names[i]));
auto &tensor = inp_vars[i]->Get<framework::LoDTensor>();
// Serialize tensors one by one
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册