未验证 提交 e4b9fcdb 编写于 作者: D Dun 提交者: GitHub

More restrict check load_combine_op. (#15479)

* fix && test=develop

* fix && test=develop

* test=develop
上级 48a5cccb
...@@ -64,7 +64,7 @@ class LoadCombineOp : public framework::OperatorBase { ...@@ -64,7 +64,7 @@ class LoadCombineOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>(); auto *tensor = out_var->GetMutable<framework::LoDTensor>();
// Error checking // Error checking
PADDLE_ENFORCE(static_cast<bool>(buffer), "Cannot read more"); PADDLE_ENFORCE(static_cast<bool>(*buffer), "Cannot read more");
// Get data from fin to tensor // Get data from fin to tensor
DeserializeFromStream(*buffer, tensor, dev_ctx); DeserializeFromStream(*buffer, tensor, dev_ctx);
...@@ -90,6 +90,10 @@ class LoadCombineOp : public framework::OperatorBase { ...@@ -90,6 +90,10 @@ class LoadCombineOp : public framework::OperatorBase {
tensor->ShareDataWith(fp16_tensor); tensor->ShareDataWith(fp16_tensor);
} }
} }
buffer->peek();
PADDLE_ENFORCE(buffer->eof(),
"You are not allowed to load partial data via "
"load_combine_op, use load_op instead.");
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册