Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
e71d3a5e
M
Models
项目概览
曾经的那一瞬间
/
Models
11 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
e71d3a5e
编写于
4月 29, 2023
作者:
H
Hongkun Yu
提交者:
A. Unique TensorFlower
4月 29, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove very expensive unit test.
PiperOrigin-RevId: 528210978
上级
f3922744
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
0 addition
and
192 deletion
+0
-192
official/vision/serving/export_tflite_lib_test.py
official/vision/serving/export_tflite_lib_test.py
+0
-192
未找到文件。
official/vision/serving/export_tflite_lib_test.py
已删除
100644 → 0
浏览文件 @
f3922744
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for export_tflite_lib."""
import
os
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
official.core
import
exp_factory
from
official.vision
import
registry_imports
# pylint: disable=unused-import
from
official.vision.dataloaders
import
tfexample_utils
from
official.vision.serving
import
detection
as
detection_serving
from
official.vision.serving
import
export_tflite_lib
from
official.vision.serving
import
image_classification
as
image_classification_serving
from
official.vision.serving
import
semantic_segmentation
as
semantic_segmentation_serving
class
ExportTfliteLibTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
super
().
setUp
()
# Create test data for image classification.
self
.
test_tfrecord_file_cls
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'cls_test.tfrecord'
)
example
=
tf
.
train
.
Example
.
FromString
(
tfexample_utils
.
create_classification_example
(
image_height
=
224
,
image_width
=
224
))
self
.
_create_test_tfrecord
(
tfrecord_file
=
self
.
test_tfrecord_file_cls
,
example
=
example
,
num_samples
=
10
)
# Create test data for object detection.
self
.
test_tfrecord_file_det
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'det_test.tfrecord'
)
example
=
tfexample_utils
.
create_detection_test_example
(
image_height
=
128
,
image_width
=
128
,
image_channel
=
3
,
num_instances
=
10
)
self
.
_create_test_tfrecord
(
tfrecord_file
=
self
.
test_tfrecord_file_det
,
example
=
example
,
num_samples
=
10
)
# Create test data for semantic segmentation.
self
.
test_tfrecord_file_seg
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'seg_test.tfrecord'
)
example
=
tfexample_utils
.
create_segmentation_test_example
(
image_height
=
512
,
image_width
=
512
,
image_channel
=
3
)
self
.
_create_test_tfrecord
(
tfrecord_file
=
self
.
test_tfrecord_file_seg
,
example
=
example
,
num_samples
=
10
)
def
_create_test_tfrecord
(
self
,
tfrecord_file
,
example
,
num_samples
):
examples
=
[
example
]
*
num_samples
tfexample_utils
.
dump_to_tfrecord
(
record_file
=
tfrecord_file
,
tf_examples
=
examples
)
def
_export_from_module
(
self
,
module
,
input_type
,
saved_model_dir
):
signatures
=
module
.
get_inference_signatures
(
{
input_type
:
'serving_default'
})
tf
.
saved_model
.
save
(
module
,
saved_model_dir
,
signatures
=
signatures
)
@
combinations
.
generate
(
combinations
.
combine
(
experiment
=
[
'mobilenet_imagenet'
],
quant_type
=
[
None
,
'default'
,
'fp16'
,
'int8_fallback'
,
'int8_full'
,
'int8_full_fp32_io'
,
'int8_full_int8_io'
,
]))
def
test_export_tflite_image_classification
(
self
,
experiment
,
quant_type
):
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
.
task
.
validation_data
.
input_path
=
self
.
test_tfrecord_file_cls
params
.
task
.
train_data
.
input_path
=
self
.
test_tfrecord_file_cls
params
.
task
.
train_data
.
shuffle_buffer_size
=
10
temp_dir
=
self
.
get_temp_dir
()
module
=
image_classification_serving
.
ClassificationModule
(
params
=
params
,
batch_size
=
1
,
input_image_size
=
[
224
,
224
],
input_type
=
'tflite'
)
self
.
_export_from_module
(
module
=
module
,
input_type
=
'tflite'
,
saved_model_dir
=
os
.
path
.
join
(
temp_dir
,
'saved_model'
))
tflite_model
=
export_tflite_lib
.
convert_tflite_model
(
saved_model_dir
=
os
.
path
.
join
(
temp_dir
,
'saved_model'
),
quant_type
=
quant_type
,
params
=
params
,
calibration_steps
=
5
)
self
.
assertIsInstance
(
tflite_model
,
bytes
)
@
combinations
.
generate
(
combinations
.
combine
(
experiment
=
[
'retinanet_mobile_coco'
],
quant_type
=
[
None
,
'default'
,
'fp16'
,
'int8_fallback'
,
'int8_full'
,
'int8_full_fp32_io'
,
'int8_full_int8_io'
,
]))
def
test_export_tflite_detection
(
self
,
experiment
,
quant_type
):
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
.
task
.
validation_data
.
input_path
=
self
.
test_tfrecord_file_det
params
.
task
.
train_data
.
input_path
=
self
.
test_tfrecord_file_det
params
.
task
.
model
.
num_classes
=
2
params
.
task
.
model
.
backbone
.
spinenet_mobile
.
model_id
=
'49XS'
params
.
task
.
model
.
input_size
=
[
128
,
128
,
3
]
params
.
task
.
model
.
detection_generator
.
nms_version
=
'v1'
params
.
task
.
train_data
.
shuffle_buffer_size
=
5
temp_dir
=
self
.
get_temp_dir
()
module
=
detection_serving
.
DetectionModule
(
params
=
params
,
batch_size
=
1
,
input_image_size
=
[
128
,
128
],
input_type
=
'tflite'
)
self
.
_export_from_module
(
module
=
module
,
input_type
=
'tflite'
,
saved_model_dir
=
os
.
path
.
join
(
temp_dir
,
'saved_model'
))
tflite_model
=
export_tflite_lib
.
convert_tflite_model
(
saved_model_dir
=
os
.
path
.
join
(
temp_dir
,
'saved_model'
),
quant_type
=
quant_type
,
params
=
params
,
calibration_steps
=
1
)
self
.
assertIsInstance
(
tflite_model
,
bytes
)
@
combinations
.
generate
(
combinations
.
combine
(
experiment
=
[
'mnv2_deeplabv3_pascal'
],
quant_type
=
[
None
,
'default'
,
'fp16'
,
'int8_fallback'
,
'int8_full'
,
'int8_full_fp32_io'
,
'int8_full_int8_io'
,
]))
def
test_export_tflite_semantic_segmentation
(
self
,
experiment
,
quant_type
):
params
=
exp_factory
.
get_exp_config
(
experiment
)
params
.
task
.
validation_data
.
input_path
=
self
.
test_tfrecord_file_seg
params
.
task
.
train_data
.
input_path
=
self
.
test_tfrecord_file_seg
params
.
task
.
train_data
.
shuffle_buffer_size
=
10
temp_dir
=
self
.
get_temp_dir
()
module
=
semantic_segmentation_serving
.
SegmentationModule
(
params
=
params
,
batch_size
=
1
,
input_image_size
=
[
512
,
512
],
input_type
=
'tflite'
)
self
.
_export_from_module
(
module
=
module
,
input_type
=
'tflite'
,
saved_model_dir
=
os
.
path
.
join
(
temp_dir
,
'saved_model'
))
tflite_model
=
export_tflite_lib
.
convert_tflite_model
(
saved_model_dir
=
os
.
path
.
join
(
temp_dir
,
'saved_model'
),
quant_type
=
quant_type
,
params
=
params
,
calibration_steps
=
5
)
self
.
assertIsInstance
(
tflite_model
,
bytes
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录