提交 05f1d078 编写于 作者: M Mihai Maruseac

Validate `NodeDef`s from `FunctionDefLibrary` of a `GraphDef`.

We already validated `NodeDef`s from a `GraphDef` but missed validating those from the `FunctionDefLibrary`. Thus, some maliciously crafted models could evade detection and cause denial of service due to a `CHECK`-fail.

PiperOrigin-RevId: 332536309
Change-Id: I052efe919ff1fe2f90815e286a1aa4c54c7b94ff
上级 2e0cc0a3
......@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.proto.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/io/path.h"
......@@ -71,26 +72,41 @@ uint64 GetLatencyMicroseconds(const uint64 start_microseconds) {
// Ensure that constant tensors loaded from the saved model have valid shape.
// Also ensure that constant nodes have a value assigned to them.
// TODO(b/154763635): this is temporary and will be replaced with a better audit
static Status ValidateNode(const NodeDef& node) {
const auto node_iterator = node.attr().find("value");
if (node_iterator != node.attr().end()) {
AttrValue node_value = node_iterator->second;
if (node_value.has_tensor()) {
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
if (node_shape.num_elements() < 0) {
return errors::FailedPrecondition(
"Saved model contains node \"", node.name(), "\" (op \"", node.op(),
"\") which initializes from a tensor with ",
node_shape.num_elements(), " elements");
}
}
} else if (node.op() == "Const") {
return errors::FailedPrecondition(
"Saved model contains node \"", node.name(),
"\" which is a constant tensor but no value has been provided");
}
return Status::OK();
}
static Status ValidateSavedTensors(const GraphDef& graph_def) {
for (const auto& node : graph_def.node()) {
const auto node_iterator = node.attr().find("value");
if (node_iterator != node.attr().end()) {
AttrValue node_value = node_iterator->second;
if (node_value.has_tensor()) {
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
if (node_shape.num_elements() < 0) {
return errors::FailedPrecondition(
"Saved model contains node \"", node.name(), "\" (op \"",
node.op(), "\") which initializes from a tensor with ",
node_shape.num_elements(), " elements");
}
TF_RETURN_IF_ERROR(ValidateNode(node));
}
if (graph_def.has_library()) {
const FunctionDefLibrary& library = graph_def.library();
for (const auto& function : library.function()) {
for (const auto& node : function.node_def()) {
TF_RETURN_IF_ERROR(ValidateNode(node));
}
} else if (node.op() == "Const") {
return errors::FailedPrecondition(
"Saved model contains node \"", node.name(),
"\" which is a constant tensor but no value has been provided");
}
}
return Status::OK();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册