From 92923d086b7631398fa9302f33a2112e817fb45b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Dec 2019 15:58:15 -0800 Subject: [PATCH] Add builders for Landmarks2TransformMatrix layer PiperOrigin-RevId: 286283327 Change-Id: I8fd9102c04ffe56af92fa19b01bcece139c45bae --- .../delegates/gpu/common/model_builder.cc | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 4e5ee940841..e1397c6a034 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -2333,6 +2333,38 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser { private: }; +class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return CheckInputsOutputs(context, tflite_node, /*inputs=*/1, + /*outputs=*/1); + } + + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + RETURN_IF_ERROR(reader->AddInput(node, 0)); // landmarks + RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix + + const std::string op_name = "landmarks_to_transform_matrix"; + node->operation.type = op_name; + BHWC output_shape; + RETURN_IF_ERROR( + ParseCustomAttributes(op_name, tflite_node->custom_initial_data, + tflite_node->custom_initial_data_size, + &(node->operation.attributes), &output_shape)); + + auto output_value = graph->FindOutputs(node->id)[0]; + output_value->tensor.shape = output_shape; + return OkStatus(); + } + + private: +}; + class UnsupportedOperationParser : public TFLiteOperationParser { public: Status IsSupported(const TfLiteContext* context, @@ -2450,6 +2482,10 @@ std::unique_ptr NewOperationParser( return absl::make_unique(); } + if (custom_name == "Landmarks2TransformMatrix") { + return absl::make_unique(); + } + break; } return absl::make_unique(); -- GitLab