Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
码匠许师傅
Tflite Micro
提交
1fe87a4d
T
Tflite Micro
项目概览
码匠许师傅
/
Tflite Micro
12 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tflite Micro
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
1fe87a4d
编写于
5月 11, 2021
作者:
T
TFLM-bot
提交者:
GitHub
5月 11, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Automated sync from github.com/tensorflow/tensorflow (#69)
上级
83e656ad
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
86 addition
and
120 deletion
+86
-120
tensorflow/lite/core/api/op_resolver.h
tensorflow/lite/core/api/op_resolver.h
+16
-0
tensorflow/lite/micro/micro_interpreter.cc
tensorflow/lite/micro/micro_interpreter.cc
+57
-77
tensorflow/lite/micro/micro_interpreter.h
tensorflow/lite/micro/micro_interpreter.h
+13
-43
未找到文件。
tensorflow/lite/core/api/op_resolver.h
浏览文件 @
1fe87a4d
...
...
@@ -46,6 +46,22 @@ class OpResolver {
}
virtual
~
OpResolver
()
{}
private:
/// Returns true if this OpResolver may contain any "user defined" ops.
/// By "user defined" ops, we mean any op definitions other than those
/// contained in tflite::ops::builtin::BuiltinOpResolver.
///
/// If this method returns true, it doesn't necessarily mean that the
/// OpResolver contains a user-defined op, just that the absence of
/// user-defined ops can't be guaranteed.
///
/// Note that "user-defined" ops are not the same as "custom" ops;
/// BuiltinOpResolver may support certain "custom" ops, in addition to
/// "builtin" ops, and may not support all of the "builtin" op enum values.
virtual
bool
MayContainUserDefinedOps
()
const
{
return
true
;
}
friend
class
OpResolverInternal
;
};
// Handles the logic for converting between an OperatorCode structure extracted
...
...
tensorflow/lite/micro/micro_interpreter.cc
浏览文件 @
1fe87a4d
...
...
@@ -44,66 +44,6 @@ const char* OpNameFromRegistration(const TfLiteRegistration* registration) {
}
// namespace
namespace
internal
{
ContextHelper
::
ContextHelper
(
ErrorReporter
*
error_reporter
,
MicroAllocator
*
allocator
,
const
Model
*
model
)
:
allocator_
(
allocator
),
error_reporter_
(
error_reporter
),
model_
(
model
)
{}
void
*
ContextHelper
::
AllocatePersistentBuffer
(
TfLiteContext
*
ctx
,
size_t
bytes
)
{
return
reinterpret_cast
<
ContextHelper
*>
(
ctx
->
impl_
)
->
allocator_
->
AllocatePersistentBuffer
(
bytes
);
}
TfLiteStatus
ContextHelper
::
RequestScratchBufferInArena
(
TfLiteContext
*
ctx
,
size_t
bytes
,
int
*
buffer_idx
)
{
ContextHelper
*
helper
=
reinterpret_cast
<
ContextHelper
*>
(
ctx
->
impl_
);
return
helper
->
allocator_
->
RequestScratchBufferInArena
(
bytes
,
buffer_idx
);
}
void
*
ContextHelper
::
GetScratchBuffer
(
TfLiteContext
*
ctx
,
int
buffer_idx
)
{
ContextHelper
*
helper
=
reinterpret_cast
<
ContextHelper
*>
(
ctx
->
impl_
);
ScratchBufferHandle
*
handle
=
helper
->
scratch_buffer_handles_
+
buffer_idx
;
return
handle
->
data
;
}
void
ContextHelper
::
ReportOpError
(
struct
TfLiteContext
*
context
,
const
char
*
format
,
...)
{
#ifndef TF_LITE_STRIP_ERROR_STRINGS
ContextHelper
*
helper
=
static_cast
<
ContextHelper
*>
(
context
->
impl_
);
va_list
args
;
va_start
(
args
,
format
);
TF_LITE_REPORT_ERROR
(
helper
->
error_reporter_
,
format
,
args
);
va_end
(
args
);
#endif
}
TfLiteTensor
*
ContextHelper
::
GetTensor
(
const
struct
TfLiteContext
*
context
,
int
tensor_idx
)
{
ContextHelper
*
helper
=
static_cast
<
ContextHelper
*>
(
context
->
impl_
);
return
helper
->
allocator_
->
AllocateTempTfLiteTensor
(
helper
->
model_
,
helper
->
eval_tensors_
,
tensor_idx
);
}
TfLiteEvalTensor
*
ContextHelper
::
GetEvalTensor
(
const
struct
TfLiteContext
*
context
,
int
tensor_idx
)
{
ContextHelper
*
helper
=
reinterpret_cast
<
ContextHelper
*>
(
context
->
impl_
);
return
&
helper
->
eval_tensors_
[
tensor_idx
];
}
void
ContextHelper
::
SetTfLiteEvalTensors
(
TfLiteEvalTensor
*
eval_tensors
)
{
eval_tensors_
=
eval_tensors
;
}
void
ContextHelper
::
SetScratchBufferHandles
(
ScratchBufferHandle
*
scratch_buffer_handles
)
{
scratch_buffer_handles_
=
scratch_buffer_handles
;
}
}
// namespace internal
MicroInterpreter
::
MicroInterpreter
(
const
Model
*
model
,
const
MicroOpResolver
&
op_resolver
,
uint8_t
*
tensor_arena
,
...
...
@@ -118,7 +58,6 @@ MicroInterpreter::MicroInterpreter(const Model* model,
tensors_allocated_
(
false
),
initialization_status_
(
kTfLiteError
),
eval_tensors_
(
nullptr
),
context_helper_
(
error_reporter_
,
&
allocator_
,
model
),
input_tensors_
(
nullptr
),
output_tensors_
(
nullptr
)
{
Init
(
profiler
);
...
...
@@ -136,7 +75,6 @@ MicroInterpreter::MicroInterpreter(const Model* model,
tensors_allocated_
(
false
),
initialization_status_
(
kTfLiteError
),
eval_tensors_
(
nullptr
),
context_helper_
(
error_reporter_
,
&
allocator_
,
model
),
input_tensors_
(
nullptr
),
output_tensors_
(
nullptr
)
{
Init
(
profiler
);
...
...
@@ -168,10 +106,10 @@ void MicroInterpreter::Init(MicroProfiler* profiler) {
}
subgraph_
=
(
*
subgraphs
)[
0
];
context_
.
impl_
=
static_cast
<
void
*>
(
&
context_helper_
);
context_
.
ReportError
=
context_helper_
.
ReportOpError
;
context_
.
GetTensor
=
context_helper_
.
GetTensor
;
context_
.
GetEvalTensor
=
context_helper_
.
GetEvalTensor
;
context_
.
impl_
=
static_cast
<
void
*>
(
this
);
context_
.
ReportError
=
ReportOpError
;
context_
.
GetTensor
=
GetTensor
;
context_
.
GetEvalTensor
=
GetEvalTensor
;
context_
.
recommended_num_threads
=
1
;
context_
.
profiler
=
profiler
;
...
...
@@ -188,15 +126,10 @@ TfLiteStatus MicroInterpreter::AllocateTensors() {
return
kTfLiteError
;
}
// Update the pointer now that TfLiteEvalTensor allocation has completed on
// the context helper.
// TODO(b/16157777): This call would not be needed if ContextHelper rolled
// into the interpreter.
context_helper_
.
SetTfLiteEvalTensors
(
eval_tensors_
);
context_
.
tensors_size
=
subgraph_
->
tensors
()
->
size
();
// Only allow AllocatePersistentBuffer in Init stage.
context_
.
AllocatePersistentBuffer
=
context_helper_
.
AllocatePersistentBuffer
;
context_
.
AllocatePersistentBuffer
=
AllocatePersistentBuffer
;
context_
.
RequestScratchBufferInArena
=
nullptr
;
context_
.
GetScratchBuffer
=
nullptr
;
...
...
@@ -220,8 +153,7 @@ TfLiteStatus MicroInterpreter::AllocateTensors() {
// Both AllocatePersistentBuffer and RequestScratchBufferInArena is
// available in Prepare stage.
context_
.
RequestScratchBufferInArena
=
context_helper_
.
RequestScratchBufferInArena
;
context_
.
RequestScratchBufferInArena
=
RequestScratchBufferInArena
;
for
(
size_t
i
=
0
;
i
<
subgraph_
->
operators
()
->
size
();
++
i
)
{
auto
*
node
=
&
(
node_and_registrations_
[
i
].
node
);
auto
*
registration
=
node_and_registrations_
[
i
].
registration
;
...
...
@@ -242,13 +174,11 @@ TfLiteStatus MicroInterpreter::AllocateTensors() {
// allowed. Kernels can only fetch scratch buffers via GetScratchBuffer.
context_
.
AllocatePersistentBuffer
=
nullptr
;
context_
.
RequestScratchBufferInArena
=
nullptr
;
context_
.
GetScratchBuffer
=
context_helper_
.
GetScratchBuffer
;
context_
.
GetScratchBuffer
=
GetScratchBuffer
;
TF_LITE_ENSURE_OK
(
&
context_
,
allocator_
.
FinishModelAllocation
(
model_
,
eval_tensors_
,
&
scratch_buffer_handles_
));
// TODO(b/16157777): Remove this when ContextHelper is rolled into this class.
context_helper_
.
SetScratchBufferHandles
(
scratch_buffer_handles_
);
// TODO(b/162311891): Drop these allocations when the interpreter supports
// handling buffers from TfLiteEvalTensor.
...
...
@@ -406,4 +336,54 @@ TfLiteStatus MicroInterpreter::ResetVariableTensors() {
return
kTfLiteOk
;
}
void
*
MicroInterpreter
::
AllocatePersistentBuffer
(
TfLiteContext
*
context
,
size_t
bytes
)
{
return
reinterpret_cast
<
MicroInterpreter
*>
(
context
->
impl_
)
->
allocator_
.
AllocatePersistentBuffer
(
bytes
);
}
TfLiteStatus
MicroInterpreter
::
RequestScratchBufferInArena
(
TfLiteContext
*
context
,
size_t
bytes
,
int
*
buffer_idx
)
{
// All scratch buffer requests are managed in the allocator. Simply route the
// request and let the allocator manage allocations.
return
static_cast
<
MicroInterpreter
*>
(
context
->
impl_
)
->
allocator_
.
RequestScratchBufferInArena
(
bytes
,
buffer_idx
);
}
void
*
MicroInterpreter
::
GetScratchBuffer
(
TfLiteContext
*
context
,
int
buffer_idx
)
{
MicroInterpreter
*
interpreter
=
static_cast
<
MicroInterpreter
*>
(
context
->
impl_
);
ScratchBufferHandle
*
handle
=
interpreter
->
scratch_buffer_handles_
+
buffer_idx
;
return
handle
->
data
;
}
void
MicroInterpreter
::
ReportOpError
(
struct
TfLiteContext
*
context
,
const
char
*
format
,
...)
{
#ifndef TF_LITE_STRIP_ERROR_STRINGS
MicroInterpreter
*
interpreter
=
static_cast
<
MicroInterpreter
*>
(
context
->
impl_
);
va_list
args
;
va_start
(
args
,
format
);
TF_LITE_REPORT_ERROR
(
interpreter
->
error_reporter_
,
format
,
args
);
va_end
(
args
);
#endif
}
TfLiteTensor
*
MicroInterpreter
::
GetTensor
(
const
struct
TfLiteContext
*
context
,
int
tensor_idx
)
{
MicroInterpreter
*
interpreter
=
static_cast
<
MicroInterpreter
*>
(
context
->
impl_
);
return
interpreter
->
allocator_
.
AllocateTempTfLiteTensor
(
interpreter
->
model_
,
interpreter
->
eval_tensors_
,
tensor_idx
);
}
TfLiteEvalTensor
*
MicroInterpreter
::
GetEvalTensor
(
const
struct
TfLiteContext
*
context
,
int
tensor_idx
)
{
MicroInterpreter
*
interpreter
=
static_cast
<
MicroInterpreter
*>
(
context
->
impl_
);
return
&
interpreter
->
eval_tensors_
[
tensor_idx
];
}
}
// namespace tflite
tensorflow/lite/micro/micro_interpreter.h
浏览文件 @
1fe87a4d
...
...
@@ -34,46 +34,6 @@ limitations under the License.
namespace
tflite
{
namespace
internal
{
// A helper class to encapsulate the implementation of APIs in Context.
// context->impl_ points to an instance of this class.
// Check tensorflow/lite/c/common.h for detailed descriptions.
// TODO(b/16157777): Consider rolling this class into MicroInterpreter.
class
ContextHelper
{
public:
explicit
ContextHelper
(
ErrorReporter
*
error_reporter
,
MicroAllocator
*
allocator
,
const
Model
*
model
);
// Functions that will be assigned to function pointers on TfLiteContext:
static
void
*
AllocatePersistentBuffer
(
TfLiteContext
*
ctx
,
size_t
bytes
);
static
TfLiteStatus
RequestScratchBufferInArena
(
TfLiteContext
*
ctx
,
size_t
bytes
,
int
*
buffer_idx
);
static
void
*
GetScratchBuffer
(
TfLiteContext
*
ctx
,
int
buffer_idx
);
static
void
ReportOpError
(
struct
TfLiteContext
*
context
,
const
char
*
format
,
...);
static
TfLiteTensor
*
GetTensor
(
const
struct
TfLiteContext
*
context
,
int
tensor_idx
);
static
TfLiteEvalTensor
*
GetEvalTensor
(
const
struct
TfLiteContext
*
context
,
int
tensor_idx
);
// Sets the pointer to a list of TfLiteEvalTensor instances.
void
SetTfLiteEvalTensors
(
TfLiteEvalTensor
*
eval_tensors
);
// Sets the pointer to a list of ScratchBufferHandle instances.
void
SetScratchBufferHandles
(
ScratchBufferHandle
*
scratch_buffer_handles
);
private:
MicroAllocator
*
allocator_
=
nullptr
;
ErrorReporter
*
error_reporter_
=
nullptr
;
const
Model
*
model_
=
nullptr
;
TfLiteEvalTensor
*
eval_tensors_
=
nullptr
;
ScratchBufferHandle
*
scratch_buffer_handles_
=
nullptr
;
};
}
// namespace internal
class
MicroInterpreter
{
public:
// The lifetime of the model, op resolver, tensor arena, error reporter and
...
...
@@ -181,6 +141,19 @@ class MicroInterpreter {
// error reporting during initialization.
void
Init
(
MicroProfiler
*
profiler
);
// Static functions that are bound to the TfLiteContext instance:
static
void
*
AllocatePersistentBuffer
(
TfLiteContext
*
Context
,
size_t
bytes
);
static
TfLiteStatus
RequestScratchBufferInArena
(
TfLiteContext
*
context
,
size_t
bytes
,
int
*
buffer_idx
);
static
void
*
GetScratchBuffer
(
TfLiteContext
*
context
,
int
buffer_idx
);
static
void
ReportOpError
(
struct
TfLiteContext
*
context
,
const
char
*
format
,
...);
static
TfLiteTensor
*
GetTensor
(
const
struct
TfLiteContext
*
context
,
int
tensor_idx
);
static
TfLiteEvalTensor
*
GetEvalTensor
(
const
struct
TfLiteContext
*
context
,
int
tensor_idx
);
NodeAndRegistration
*
node_and_registrations_
=
nullptr
;
const
Model
*
model_
;
...
...
@@ -196,9 +169,6 @@ class MicroInterpreter {
TfLiteEvalTensor
*
eval_tensors_
=
nullptr
;
ScratchBufferHandle
*
scratch_buffer_handles_
=
nullptr
;
// TODO(b/16157777): Drop this reference:
internal
::
ContextHelper
context_helper_
;
// TODO(b/162311891): Clean these pointers up when this class supports buffers
// from TfLiteEvalTensor.
TfLiteTensor
**
input_tensors_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录