提交 d3c76ae3 编写于 作者: X Xunkai Zhang 提交者: TensorFlower Gardener

[tfls.codegen] Fix potential nullptr seg fault.

PiperOrigin-RevId: 306544979
Change-Id: I7768511d0ca5ba226d6909852fc902cd282aadb4
上级 4b2cb677
......@@ -112,20 +112,24 @@ TensorInfo CreateTensorInfo(const TensorMetadata* metadata,
tensor_info.upper_camel_name[0] = toupper(tensor_info.upper_camel_name[0]);
tensor_info.normalization_unit =
FindNormalizationUnit(metadata, tensor_identifier, err);
if (metadata->content()->content_properties_type() ==
ContentProperties_ImageProperties) {
if (metadata->content()
->content_properties_as_ImageProperties()
->color_space() == ColorSpaceType_RGB) {
tensor_info.content_type = "image";
tensor_info.wrapper_type = "TensorImage";
tensor_info.processor_type = "ImageProcessor";
return tensor_info;
} else {
err->Warning(
"Found Non-RGB image on tensor (%s). Codegen currently does not "
"support it, and regard it as a plain numeric tensor.",
tensor_identifier.c_str());
if (metadata->content() != nullptr &&
metadata->content()->content_properties() != nullptr) {
// Enter tensor wrapper type inferring
if (metadata->content()->content_properties_type() ==
ContentProperties_ImageProperties) {
if (metadata->content()
->content_properties_as_ImageProperties()
->color_space() == ColorSpaceType_RGB) {
tensor_info.content_type = "image";
tensor_info.wrapper_type = "TensorImage";
tensor_info.processor_type = "ImageProcessor";
return tensor_info;
} else {
err->Warning(
"Found Non-RGB image on tensor (%s). Codegen currently does not "
"support it, and regard it as a plain numeric tensor.",
tensor_identifier.c_str());
}
}
}
tensor_info.content_type = "tensor";
......@@ -154,12 +158,12 @@ ModelInfo CreateModelInfo(const ModelMetadata* metadata,
graph->input_tensor_metadata(), graph->output_tensor_metadata());
std::vector<std::string> input_tensor_names = std::move(names.first);
std::vector<std::string> output_tensor_names = std::move(names.second);
for (int i = 0; i < graph->input_tensor_metadata()->size(); i++) {
for (int i = 0; i < input_tensor_names.size(); i++) {
model_info.inputs.push_back(
CreateTensorInfo(graph->input_tensor_metadata()->Get(i),
input_tensor_names[i], true, i, err));
}
for (int i = 0; i < graph->output_tensor_metadata()->size(); i++) {
for (int i = 0; i < output_tensor_names.size(); i++) {
model_info.outputs.push_back(
CreateTensorInfo(graph->output_tensor_metadata()->Get(i),
output_tensor_names[i], false, i, err));
......@@ -945,6 +949,11 @@ GenerationResult AndroidJavaGenerator::Generate(
const Model* model, const std::string& package_name,
const std::string& model_class_name, const std::string& model_asset_path) {
GenerationResult result;
if (model == nullptr) {
err_.Error(
"Cannot read model from the buffer. Codegen will generate nothing.");
return result;
}
const ModelMetadata* metadata = GetMetadataFromModel(model);
if (metadata == nullptr) {
err_.Error(
......
......@@ -24,14 +24,22 @@ namespace codegen {
constexpr char BUFFER_KEY[] = "TFLITE_METADATA";
const ModelMetadata* GetMetadataFromModel(const Model* model) {
if (model->metadata() == nullptr) {
if (model == nullptr || model->metadata() == nullptr) {
return nullptr;
}
for (auto i = 0; i < model->metadata()->size(); i++) {
if (model->metadata()->Get(i)->name()->str() == BUFFER_KEY) {
const auto* name = model->metadata()->Get(i)->name();
if (name != nullptr && name->str() == BUFFER_KEY) {
const auto buffer_index = model->metadata()->Get(i)->buffer();
const auto* buffer = model->buffers()->Get(buffer_index)->data()->data();
return GetModelMetadata(buffer);
if (model->buffers() == nullptr ||
model->buffers()->size() <= buffer_index) {
continue;
}
const auto* buffer_vec = model->buffers()->Get(buffer_index)->data();
if (buffer_vec == nullptr || buffer_vec->data() == nullptr) {
continue;
}
return GetModelMetadata(buffer_vec->data());
}
}
return nullptr;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册