Unverified commit 1fe87a4d authored by TFLM-bot, committed by GitHub

Automated sync from github.com/tensorflow/tensorflow (#69)

Parent 83e656ad
@@ -46,6 +46,22 @@ class OpResolver {
}
virtual ~OpResolver() {}
private:
/// Returns true if this OpResolver may contain any "user defined" ops.
/// By "user defined" ops, we mean any op definitions other than those
/// contained in tflite::ops::builtin::BuiltinOpResolver.
///
/// If this method returns true, it doesn't necessarily mean that the
/// OpResolver contains a user-defined op, just that the absence of
/// user-defined ops can't be guaranteed.
///
/// Note that "user-defined" ops are not the same as "custom" ops;
/// BuiltinOpResolver may support certain "custom" ops, in addition to
/// "builtin" ops, and may not support all of the "builtin" op enum values.
virtual bool MayContainUserDefinedOps() const { return true; }
friend class OpResolverInternal;
};
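For illustration only (not part of this commit): a resolver that derives from tflite::MutableOpResolver and only re-registers definitions taken from BuiltinOpResolver could override the new hook, while any resolver that registers its own kernels should keep the default return of true. The class name and the assumption about its registrations are hypothetical.

#include "tensorflow/lite/mutable_op_resolver.h"

// Hypothetical resolver, shown only to illustrate the new hook. Overriding is
// legal even though the base declares the method as a private virtual; the
// friended OpResolverInternal remains the intended caller.
class BuiltinOnlyResolver : public tflite::MutableOpResolver {
 private:
  // Assumption for this sketch: every registration added to this resolver
  // reuses a tflite::ops::builtin::BuiltinOpResolver definition.
  bool MayContainUserDefinedOps() const override { return false; }
};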
// Handles the logic for converting between an OperatorCode structure extracted
@@ -44,66 +44,6 @@ const char* OpNameFromRegistration(const TfLiteRegistration* registration) {
} // namespace
namespace internal {
ContextHelper::ContextHelper(ErrorReporter* error_reporter,
MicroAllocator* allocator, const Model* model)
: allocator_(allocator), error_reporter_(error_reporter), model_(model) {}
void* ContextHelper::AllocatePersistentBuffer(TfLiteContext* ctx,
size_t bytes) {
return reinterpret_cast<ContextHelper*>(ctx->impl_)
->allocator_->AllocatePersistentBuffer(bytes);
}
TfLiteStatus ContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx,
size_t bytes,
int* buffer_idx) {
ContextHelper* helper = reinterpret_cast<ContextHelper*>(ctx->impl_);
return helper->allocator_->RequestScratchBufferInArena(bytes, buffer_idx);
}
void* ContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) {
ContextHelper* helper = reinterpret_cast<ContextHelper*>(ctx->impl_);
ScratchBufferHandle* handle = helper->scratch_buffer_handles_ + buffer_idx;
return handle->data;
}
void ContextHelper::ReportOpError(struct TfLiteContext* context,
const char* format, ...) {
#ifndef TF_LITE_STRIP_ERROR_STRINGS
ContextHelper* helper = static_cast<ContextHelper*>(context->impl_);
va_list args;
va_start(args, format);
TF_LITE_REPORT_ERROR(helper->error_reporter_, format, args);
va_end(args);
#endif
}
TfLiteTensor* ContextHelper::GetTensor(const struct TfLiteContext* context,
int tensor_idx) {
ContextHelper* helper = static_cast<ContextHelper*>(context->impl_);
return helper->allocator_->AllocateTempTfLiteTensor(
helper->model_, helper->eval_tensors_, tensor_idx);
}
TfLiteEvalTensor* ContextHelper::GetEvalTensor(
const struct TfLiteContext* context, int tensor_idx) {
ContextHelper* helper = reinterpret_cast<ContextHelper*>(context->impl_);
return &helper->eval_tensors_[tensor_idx];
}
void ContextHelper::SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors) {
eval_tensors_ = eval_tensors;
}
void ContextHelper::SetScratchBufferHandles(
ScratchBufferHandle* scratch_buffer_handles) {
scratch_buffer_handles_ = scratch_buffer_handles;
}
} // namespace internal
MicroInterpreter::MicroInterpreter(const Model* model,
const MicroOpResolver& op_resolver,
uint8_t* tensor_arena,
@@ -118,7 +58,6 @@ MicroInterpreter::MicroInterpreter(const Model* model,
tensors_allocated_(false),
initialization_status_(kTfLiteError),
eval_tensors_(nullptr),
context_helper_(error_reporter_, &allocator_, model),
input_tensors_(nullptr),
output_tensors_(nullptr) {
Init(profiler);
@@ -136,7 +75,6 @@ MicroInterpreter::MicroInterpreter(const Model* model,
tensors_allocated_(false),
initialization_status_(kTfLiteError),
eval_tensors_(nullptr),
context_helper_(error_reporter_, &allocator_, model),
input_tensors_(nullptr),
output_tensors_(nullptr) {
Init(profiler);
@@ -168,10 +106,10 @@ void MicroInterpreter::Init(MicroProfiler* profiler) {
}
subgraph_ = (*subgraphs)[0];
context_.impl_ = static_cast<void*>(&context_helper_);
context_.ReportError = context_helper_.ReportOpError;
context_.GetTensor = context_helper_.GetTensor;
context_.GetEvalTensor = context_helper_.GetEvalTensor;
context_.impl_ = static_cast<void*>(this);
context_.ReportError = ReportOpError;
context_.GetTensor = GetTensor;
context_.GetEvalTensor = GetEvalTensor;
context_.recommended_num_threads = 1;
context_.profiler = profiler;
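From a kernel's point of view nothing changes with this hunk: the callbacks are still reached through the TfLiteContext function pointers, which now trampoline into the interpreter through context->impl_. A hedged sketch of that call path follows; the helper name and the tensor index are hypothetical.

#include "tensorflow/lite/c/common.h"

// Hypothetical helper showing how a kernel reaches the two tensor accessors
// that this hunk rebinds to MicroInterpreter's static members.
TfLiteStatus InspectFirstTensor(TfLiteContext* context) {
  // Prepare-time view: a temporary TfLiteTensor allocated on demand by the
  // MicroAllocator behind MicroInterpreter::GetTensor.
  TfLiteTensor* tensor = context->GetTensor(context, /*tensor_idx=*/0);
  if (tensor == nullptr) return kTfLiteError;

  // Eval-time view: an entry of the interpreter-owned TfLiteEvalTensor array,
  // served by MicroInterpreter::GetEvalTensor.
  TfLiteEvalTensor* eval = context->GetEvalTensor(context, /*tensor_idx=*/0);
  return eval != nullptr ? kTfLiteOk : kTfLiteError;
}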
@@ -188,15 +126,10 @@ TfLiteStatus MicroInterpreter::AllocateTensors() {
return kTfLiteError;
}
// Update the pointer now that TfLiteEvalTensor allocation has completed on
// the context helper.
// TODO(b/16157777): This call would not be needed if ContextHelper rolled
// into the interpreter.
context_helper_.SetTfLiteEvalTensors(eval_tensors_);
context_.tensors_size = subgraph_->tensors()->size();
// Only allow AllocatePersistentBuffer in Init stage.
context_.AllocatePersistentBuffer = context_helper_.AllocatePersistentBuffer;
context_.AllocatePersistentBuffer = AllocatePersistentBuffer;
context_.RequestScratchBufferInArena = nullptr;
context_.GetScratchBuffer = nullptr;
@@ -220,8 +153,7 @@ TfLiteStatus MicroInterpreter::AllocateTensors() {
// Both AllocatePersistentBuffer and RequestScratchBufferInArena are
// available in the Prepare stage.
context_.RequestScratchBufferInArena =
context_helper_.RequestScratchBufferInArena;
context_.RequestScratchBufferInArena = RequestScratchBufferInArena;
for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
auto* node = &(node_and_registrations_[i].node);
auto* registration = node_and_registrations_[i].registration;
@@ -242,13 +174,11 @@ TfLiteStatus MicroInterpreter::AllocateTensors() {
// allowed. Kernels can only fetch scratch buffers via GetScratchBuffer.
context_.AllocatePersistentBuffer = nullptr;
context_.RequestScratchBufferInArena = nullptr;
context_.GetScratchBuffer = context_helper_.GetScratchBuffer;
context_.GetScratchBuffer = GetScratchBuffer;
TF_LITE_ENSURE_OK(&context_,
allocator_.FinishModelAllocation(model_, eval_tensors_,
&scratch_buffer_handles_));
// TODO(b/16157777): Remove this when ContextHelper is rolled into this class.
context_helper_.SetScratchBufferHandles(scratch_buffer_handles_);
// TODO(b/162311891): Drop these allocations when the interpreter supports
// handling buffers from TfLiteEvalTensor.
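Taken together, the hunks above encode the kernel lifecycle contract: a persistent buffer may only be allocated while Init runs, a scratch buffer may only be requested while Prepare runs, and at Eval time a previously returned index is exchanged for the actual arena pointer. A hedged kernel-side sketch of that contract follows; OpData and the 256-byte scratch size are invented for illustration.

#include "tensorflow/lite/c/common.h"

// Hypothetical per-op state kept in the persistent arena.
struct OpData {
  int scratch_idx;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  // Init stage: only AllocatePersistentBuffer is wired up on the context.
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* data = static_cast<OpData*>(node->user_data);
  // Prepare stage: request a scratch region by size and remember the index.
  return context->RequestScratchBufferInArena(context, /*bytes=*/256,
                                              &data->scratch_idx);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* data = static_cast<OpData*>(node->user_data);
  // Eval stage: only GetScratchBuffer remains; exchange the saved index for
  // the arena pointer.
  void* scratch = context->GetScratchBuffer(context, data->scratch_idx);
  return scratch != nullptr ? kTfLiteOk : kTfLiteError;
}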
@@ -406,4 +336,54 @@ TfLiteStatus MicroInterpreter::ResetVariableTensors() {
return kTfLiteOk;
}
void* MicroInterpreter::AllocatePersistentBuffer(TfLiteContext* context,
size_t bytes) {
return reinterpret_cast<MicroInterpreter*>(context->impl_)
->allocator_.AllocatePersistentBuffer(bytes);
}
TfLiteStatus MicroInterpreter::RequestScratchBufferInArena(
TfLiteContext* context, size_t bytes, int* buffer_idx) {
// All scratch buffer requests are managed in the allocator. Simply route the
// request and let the allocator manage allocations.
return static_cast<MicroInterpreter*>(context->impl_)
->allocator_.RequestScratchBufferInArena(bytes, buffer_idx);
}
void* MicroInterpreter::GetScratchBuffer(TfLiteContext* context,
int buffer_idx) {
MicroInterpreter* interpreter =
static_cast<MicroInterpreter*>(context->impl_);
ScratchBufferHandle* handle =
interpreter->scratch_buffer_handles_ + buffer_idx;
return handle->data;
}
void MicroInterpreter::ReportOpError(struct TfLiteContext* context,
const char* format, ...) {
#ifndef TF_LITE_STRIP_ERROR_STRINGS
MicroInterpreter* interpreter =
static_cast<MicroInterpreter*>(context->impl_);
va_list args;
va_start(args, format);
TF_LITE_REPORT_ERROR(interpreter->error_reporter_, format, args);
va_end(args);
#endif
}
TfLiteTensor* MicroInterpreter::GetTensor(const struct TfLiteContext* context,
int tensor_idx) {
MicroInterpreter* interpreter =
static_cast<MicroInterpreter*>(context->impl_);
return interpreter->allocator_.AllocateTempTfLiteTensor(
interpreter->model_, interpreter->eval_tensors_, tensor_idx);
}
TfLiteEvalTensor* MicroInterpreter::GetEvalTensor(
const struct TfLiteContext* context, int tensor_idx) {
MicroInterpreter* interpreter =
static_cast<MicroInterpreter*>(context->impl_);
return &interpreter->eval_tensors_[tensor_idx];
}
} // namespace tflite
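The error path works the same way: a kernel calls context->ReportError (directly or through the TF_LITE_ENSURE-style macros in common.h), and the static ReportOpError above forwards the message to the interpreter's ErrorReporter unless TF_LITE_STRIP_ERROR_STRINGS is defined. A small hedged example follows; the helper name and the int8-only check are hypothetical.

#include "tensorflow/lite/c/common.h"

// Hypothetical validation helper inside a kernel's Prepare().
TfLiteStatus CheckInputType(TfLiteContext* context,
                            const TfLiteTensor* input) {
  if (input->type != kTfLiteInt8) {
    // Routed through context->impl_ to MicroInterpreter::error_reporter_.
    context->ReportError(context, "Op only supports int8, got type %d",
                         input->type);
    return kTfLiteError;
  }
  return kTfLiteOk;
}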
@@ -34,46 +34,6 @@ limitations under the License.
namespace tflite {
namespace internal {
// A helper class to encapsulate the implementation of APIs in Context.
// context->impl_ points to an instance of this class.
// Check tensorflow/lite/c/common.h for detailed descriptions.
// TODO(b/16157777): Consider rolling this class into MicroInterpreter.
class ContextHelper {
public:
explicit ContextHelper(ErrorReporter* error_reporter,
MicroAllocator* allocator, const Model* model);
// Functions that will be assigned to function pointers on TfLiteContext:
static void* AllocatePersistentBuffer(TfLiteContext* ctx, size_t bytes);
static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* ctx,
size_t bytes,
int* buffer_idx);
static void* GetScratchBuffer(TfLiteContext* ctx, int buffer_idx);
static void ReportOpError(struct TfLiteContext* context, const char* format,
...);
static TfLiteTensor* GetTensor(const struct TfLiteContext* context,
int tensor_idx);
static TfLiteEvalTensor* GetEvalTensor(const struct TfLiteContext* context,
int tensor_idx);
// Sets the pointer to a list of TfLiteEvalTensor instances.
void SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors);
// Sets the pointer to a list of ScratchBufferHandle instances.
void SetScratchBufferHandles(ScratchBufferHandle* scratch_buffer_handles);
private:
MicroAllocator* allocator_ = nullptr;
ErrorReporter* error_reporter_ = nullptr;
const Model* model_ = nullptr;
TfLiteEvalTensor* eval_tensors_ = nullptr;
ScratchBufferHandle* scratch_buffer_handles_ = nullptr;
};
} // namespace internal
class MicroInterpreter {
public:
// The lifetime of the model, op resolver, tensor arena, error reporter and
@@ -181,6 +141,19 @@ class MicroInterpreter {
// error reporting during initialization.
void Init(MicroProfiler* profiler);
// Static functions that are bound to the TfLiteContext instance:
static void* AllocatePersistentBuffer(TfLiteContext* context, size_t bytes);
static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* context,
size_t bytes,
int* buffer_idx);
static void* GetScratchBuffer(TfLiteContext* context, int buffer_idx);
static void ReportOpError(struct TfLiteContext* context, const char* format,
...);
static TfLiteTensor* GetTensor(const struct TfLiteContext* context,
int tensor_idx);
static TfLiteEvalTensor* GetEvalTensor(const struct TfLiteContext* context,
int tensor_idx);
NodeAndRegistration* node_and_registrations_ = nullptr;
const Model* model_;
@@ -196,9 +169,6 @@ class MicroInterpreter {
TfLiteEvalTensor* eval_tensors_ = nullptr;
ScratchBufferHandle* scratch_buffer_handles_ = nullptr;
// TODO(b/16157777): Drop this reference:
internal::ContextHelper context_helper_;
// TODO(b/162311891): Clean these pointers up when this class supports buffers
// from TfLiteEvalTensor.
TfLiteTensor** input_tensors_;