From 1fe87a4d56988ee70658ed76c7280bf5f194cc5d Mon Sep 17 00:00:00 2001 From: TFLM-bot Date: Tue, 11 May 2021 09:12:02 -0700 Subject: [PATCH] Automated sync from github.com/tensorflow/tensorflow (#69) --- tensorflow/lite/core/api/op_resolver.h | 16 +++ tensorflow/lite/micro/micro_interpreter.cc | 134 +++++++++------------ tensorflow/lite/micro/micro_interpreter.h | 56 ++------- 3 files changed, 86 insertions(+), 120 deletions(-) diff --git a/tensorflow/lite/core/api/op_resolver.h b/tensorflow/lite/core/api/op_resolver.h index f43c6ba5..471db813 100644 --- a/tensorflow/lite/core/api/op_resolver.h +++ b/tensorflow/lite/core/api/op_resolver.h @@ -46,6 +46,22 @@ class OpResolver { } virtual ~OpResolver() {} + + private: + /// Returns true if this OpResolver may contain any "user defined" ops. + /// By "user defined" ops, we mean any op definitions other than those + /// contained in tflite::ops::builtin::BuiltinOpResolver. + /// + /// If this method returns true, it doesn't necessarily mean that the + /// OpResolver contains a user-defined op, just that the absence of + /// user-defined ops can't be guaranteed. + /// + /// Note that "user-defined" ops are not the same as "custom" ops; + /// BuiltinOpResolver may support certain "custom" ops, in addition to + /// "builtin" ops, and may not support all of the "builtin" op enum values. + virtual bool MayContainUserDefinedOps() const { return true; } + + friend class OpResolverInternal; }; // Handles the logic for converting between an OperatorCode structure extracted diff --git a/tensorflow/lite/micro/micro_interpreter.cc b/tensorflow/lite/micro/micro_interpreter.cc index f01ed641..3dc6611e 100644 --- a/tensorflow/lite/micro/micro_interpreter.cc +++ b/tensorflow/lite/micro/micro_interpreter.cc @@ -44,66 +44,6 @@ const char* OpNameFromRegistration(const TfLiteRegistration* registration) { } // namespace -namespace internal { - -ContextHelper::ContextHelper(ErrorReporter* error_reporter, - MicroAllocator* allocator, const Model* model) - : allocator_(allocator), error_reporter_(error_reporter), model_(model) {} - -void* ContextHelper::AllocatePersistentBuffer(TfLiteContext* ctx, - size_t bytes) { - return reinterpret_cast(ctx->impl_) - ->allocator_->AllocatePersistentBuffer(bytes); -} - -TfLiteStatus ContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx, - size_t bytes, - int* buffer_idx) { - ContextHelper* helper = reinterpret_cast(ctx->impl_); - return helper->allocator_->RequestScratchBufferInArena(bytes, buffer_idx); -} - -void* ContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) { - ContextHelper* helper = reinterpret_cast(ctx->impl_); - ScratchBufferHandle* handle = helper->scratch_buffer_handles_ + buffer_idx; - return handle->data; -} - -void ContextHelper::ReportOpError(struct TfLiteContext* context, - const char* format, ...) { -#ifndef TF_LITE_STRIP_ERROR_STRINGS - ContextHelper* helper = static_cast(context->impl_); - va_list args; - va_start(args, format); - TF_LITE_REPORT_ERROR(helper->error_reporter_, format, args); - va_end(args); -#endif -} - -TfLiteTensor* ContextHelper::GetTensor(const struct TfLiteContext* context, - int tensor_idx) { - ContextHelper* helper = static_cast(context->impl_); - return helper->allocator_->AllocateTempTfLiteTensor( - helper->model_, helper->eval_tensors_, tensor_idx); -} - -TfLiteEvalTensor* ContextHelper::GetEvalTensor( - const struct TfLiteContext* context, int tensor_idx) { - ContextHelper* helper = reinterpret_cast(context->impl_); - return &helper->eval_tensors_[tensor_idx]; -} - -void ContextHelper::SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors) { - eval_tensors_ = eval_tensors; -} - -void ContextHelper::SetScratchBufferHandles( - ScratchBufferHandle* scratch_buffer_handles) { - scratch_buffer_handles_ = scratch_buffer_handles; -} - -} // namespace internal - MicroInterpreter::MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver, uint8_t* tensor_arena, @@ -118,7 +58,6 @@ MicroInterpreter::MicroInterpreter(const Model* model, tensors_allocated_(false), initialization_status_(kTfLiteError), eval_tensors_(nullptr), - context_helper_(error_reporter_, &allocator_, model), input_tensors_(nullptr), output_tensors_(nullptr) { Init(profiler); @@ -136,7 +75,6 @@ MicroInterpreter::MicroInterpreter(const Model* model, tensors_allocated_(false), initialization_status_(kTfLiteError), eval_tensors_(nullptr), - context_helper_(error_reporter_, &allocator_, model), input_tensors_(nullptr), output_tensors_(nullptr) { Init(profiler); @@ -168,10 +106,10 @@ void MicroInterpreter::Init(MicroProfiler* profiler) { } subgraph_ = (*subgraphs)[0]; - context_.impl_ = static_cast(&context_helper_); - context_.ReportError = context_helper_.ReportOpError; - context_.GetTensor = context_helper_.GetTensor; - context_.GetEvalTensor = context_helper_.GetEvalTensor; + context_.impl_ = static_cast(this); + context_.ReportError = ReportOpError; + context_.GetTensor = GetTensor; + context_.GetEvalTensor = GetEvalTensor; context_.recommended_num_threads = 1; context_.profiler = profiler; @@ -188,15 +126,10 @@ TfLiteStatus MicroInterpreter::AllocateTensors() { return kTfLiteError; } - // Update the pointer now that TfLiteEvalTensor allocation has completed on - // the context helper. - // TODO(b/16157777): This call would not be needed if ContextHelper rolled - // into the interpreter. - context_helper_.SetTfLiteEvalTensors(eval_tensors_); context_.tensors_size = subgraph_->tensors()->size(); // Only allow AllocatePersistentBuffer in Init stage. - context_.AllocatePersistentBuffer = context_helper_.AllocatePersistentBuffer; + context_.AllocatePersistentBuffer = AllocatePersistentBuffer; context_.RequestScratchBufferInArena = nullptr; context_.GetScratchBuffer = nullptr; @@ -220,8 +153,7 @@ TfLiteStatus MicroInterpreter::AllocateTensors() { // Both AllocatePersistentBuffer and RequestScratchBufferInArena is // available in Prepare stage. - context_.RequestScratchBufferInArena = - context_helper_.RequestScratchBufferInArena; + context_.RequestScratchBufferInArena = RequestScratchBufferInArena; for (size_t i = 0; i < subgraph_->operators()->size(); ++i) { auto* node = &(node_and_registrations_[i].node); auto* registration = node_and_registrations_[i].registration; @@ -242,13 +174,11 @@ TfLiteStatus MicroInterpreter::AllocateTensors() { // allowed. Kernels can only fetch scratch buffers via GetScratchBuffer. context_.AllocatePersistentBuffer = nullptr; context_.RequestScratchBufferInArena = nullptr; - context_.GetScratchBuffer = context_helper_.GetScratchBuffer; + context_.GetScratchBuffer = GetScratchBuffer; TF_LITE_ENSURE_OK(&context_, allocator_.FinishModelAllocation(model_, eval_tensors_, &scratch_buffer_handles_)); - // TODO(b/16157777): Remove this when ContextHelper is rolled into this class. - context_helper_.SetScratchBufferHandles(scratch_buffer_handles_); // TODO(b/162311891): Drop these allocations when the interpreter supports // handling buffers from TfLiteEvalTensor. @@ -406,4 +336,54 @@ TfLiteStatus MicroInterpreter::ResetVariableTensors() { return kTfLiteOk; } +void* MicroInterpreter::AllocatePersistentBuffer(TfLiteContext* context, + size_t bytes) { + return reinterpret_cast(context->impl_) + ->allocator_.AllocatePersistentBuffer(bytes); +} + +TfLiteStatus MicroInterpreter::RequestScratchBufferInArena( + TfLiteContext* context, size_t bytes, int* buffer_idx) { + // All scratch buffer requests are managed in the allocator. Simply route the + // request and let the allocator manage allocations. + return static_cast(context->impl_) + ->allocator_.RequestScratchBufferInArena(bytes, buffer_idx); +} + +void* MicroInterpreter::GetScratchBuffer(TfLiteContext* context, + int buffer_idx) { + MicroInterpreter* interpreter = + static_cast(context->impl_); + ScratchBufferHandle* handle = + interpreter->scratch_buffer_handles_ + buffer_idx; + return handle->data; +} + +void MicroInterpreter::ReportOpError(struct TfLiteContext* context, + const char* format, ...) { +#ifndef TF_LITE_STRIP_ERROR_STRINGS + MicroInterpreter* interpreter = + static_cast(context->impl_); + va_list args; + va_start(args, format); + TF_LITE_REPORT_ERROR(interpreter->error_reporter_, format, args); + va_end(args); +#endif +} + +TfLiteTensor* MicroInterpreter::GetTensor(const struct TfLiteContext* context, + int tensor_idx) { + MicroInterpreter* interpreter = + static_cast(context->impl_); + return interpreter->allocator_.AllocateTempTfLiteTensor( + interpreter->model_, interpreter->eval_tensors_, tensor_idx); +} + +TfLiteEvalTensor* MicroInterpreter::GetEvalTensor( + const struct TfLiteContext* context, int tensor_idx) { + MicroInterpreter* interpreter = + static_cast(context->impl_); + return &interpreter->eval_tensors_[tensor_idx]; +} + } // namespace tflite diff --git a/tensorflow/lite/micro/micro_interpreter.h b/tensorflow/lite/micro/micro_interpreter.h index 39fb09b2..7da4c0b8 100644 --- a/tensorflow/lite/micro/micro_interpreter.h +++ b/tensorflow/lite/micro/micro_interpreter.h @@ -34,46 +34,6 @@ limitations under the License. namespace tflite { -namespace internal { - -// A helper class to encapsulate the implementation of APIs in Context. -// context->impl_ points to an instance of this class. -// Check tensorflow/lite/c/common.h for detailed descriptions. -// TODO(b/16157777): Consider rolling this class into MicroInterpreter. -class ContextHelper { - public: - explicit ContextHelper(ErrorReporter* error_reporter, - MicroAllocator* allocator, const Model* model); - - // Functions that will be assigned to function pointers on TfLiteContext: - static void* AllocatePersistentBuffer(TfLiteContext* ctx, size_t bytes); - static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* ctx, - size_t bytes, - int* buffer_idx); - static void* GetScratchBuffer(TfLiteContext* ctx, int buffer_idx); - static void ReportOpError(struct TfLiteContext* context, const char* format, - ...); - static TfLiteTensor* GetTensor(const struct TfLiteContext* context, - int tensor_idx); - static TfLiteEvalTensor* GetEvalTensor(const struct TfLiteContext* context, - int tensor_idx); - - // Sets the pointer to a list of TfLiteEvalTensor instances. - void SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors); - - // Sets the pointer to a list of ScratchBufferHandle instances. - void SetScratchBufferHandles(ScratchBufferHandle* scratch_buffer_handles); - - private: - MicroAllocator* allocator_ = nullptr; - ErrorReporter* error_reporter_ = nullptr; - const Model* model_ = nullptr; - TfLiteEvalTensor* eval_tensors_ = nullptr; - ScratchBufferHandle* scratch_buffer_handles_ = nullptr; -}; - -} // namespace internal - class MicroInterpreter { public: // The lifetime of the model, op resolver, tensor arena, error reporter and @@ -181,6 +141,19 @@ class MicroInterpreter { // error reporting during initialization. void Init(MicroProfiler* profiler); + // Static functions that are bound to the TfLiteContext instance: + static void* AllocatePersistentBuffer(TfLiteContext* Context, size_t bytes); + static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* context, + size_t bytes, + int* buffer_idx); + static void* GetScratchBuffer(TfLiteContext* context, int buffer_idx); + static void ReportOpError(struct TfLiteContext* context, const char* format, + ...); + static TfLiteTensor* GetTensor(const struct TfLiteContext* context, + int tensor_idx); + static TfLiteEvalTensor* GetEvalTensor(const struct TfLiteContext* context, + int tensor_idx); + NodeAndRegistration* node_and_registrations_ = nullptr; const Model* model_; @@ -196,9 +169,6 @@ class MicroInterpreter { TfLiteEvalTensor* eval_tensors_ = nullptr; ScratchBufferHandle* scratch_buffer_handles_ = nullptr; - // TODO(b/16157777): Drop this reference: - internal::ContextHelper context_helper_; - // TODO(b/162311891): Clean these pointers up when this class supports buffers // from TfLiteEvalTensor. TfLiteTensor** input_tensors_; -- GitLab