提交 7c4341c5 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Ensure delegated graph's inputs/outputs match the original tflite model's order.

PiperOrigin-RevId: 306518815
Change-Id: Ice094b324d7368914e1d0feecfd7bc129e629a4b
上级 65e3fdb4
......@@ -3124,6 +3124,33 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops) {
return ConvertVectorToTfLiteIntArray(ops_to_replace);
}
// Creates inputs and outputs passed by io_tensors parameters in the resulting
// graph. We force it to make sure that delegated subgraph has same order of
// inputs and outputs with the original one. When delegated model is built from
// the tflite model representation tensors are created lazily, so there is no
// guarantee that the order will match the source model tensors order.
absl::Status PrecreateIOTensors(
TfLiteContext* context, GraphFloat32* graph, TfLiteIntArray* io_tensors,
std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value) {
for (int i = 0; i < io_tensors->size; ++i) {
const int tensor_index = io_tensors->data[i];
if (tensor_to_value->find(tensor_index) != tensor_to_value->end()) {
return absl::AlreadyExistsError(absl::StrCat(
"Tensor with tflite index ", tensor_index, " was already created."));
}
const TfLiteTensor& tflite_tensor = context->tensors[tensor_index];
if (tflite_tensor.allocation_type == TfLiteAllocationType::kTfLiteMmapRo) {
continue;
}
Value<TensorRef<BHWC>>* value = graph->NewValue();
RETURN_IF_ERROR(
ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor));
value->tensor.ref = tensor_index;
(*tensor_to_value)[tensor_index] = value;
}
return absl::OkStatus();
}
absl::Status BuildModel(TfLiteContext* context,
const TfLiteDelegateParams* delegate_params,
GraphFloat32* graph,
......@@ -3154,6 +3181,10 @@ absl::Status BuildModel(TfLiteContext* context,
tflite_nodes.push_back(i);
}
std::unordered_map<int, Value<TensorRef<BHWC>>*> tensor_to_value;
RETURN_IF_ERROR(PrecreateIOTensors(
context, graph, delegate_params->input_tensors, &tensor_to_value));
RETURN_IF_ERROR(PrecreateIOTensors(
context, graph, delegate_params->output_tensors, &tensor_to_value));
for (int i = 0; i < operations.size(); ++i) {
TfLiteNode* tflite_node;
TfLiteRegistration* registration;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册