Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2ce050bb
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2ce050bb
编写于
12月 20, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(lite): add warnning to TensorBatchCollector
GitOrigin-RevId: ba45e6a5a48a2ea3c5a0554f6b4e63665954150d
上级
ce119ef5
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
24 addition
and
9 deletion
+24
-9
lite/pylite/megenginelite/utils.py
lite/pylite/megenginelite/utils.py
+24
-9
未找到文件。
lite/pylite/megenginelite/utils.py
浏览文件 @
2ce050bb
...
...
@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
threading
import
warnings
import
numpy
as
np
...
...
@@ -51,15 +52,24 @@ class TensorBatchCollector:
)
def
collect_id
(
self
,
array
,
batch_id
):
# get the batch index
with
self
.
_mutex
:
if
batch_id
in
self
.
_free_list
:
self
.
_free_list
.
remove
(
batch_id
)
else
:
warnings
.
warn
(
"batch {} has been collected, please call free before collected it again."
.
format
(
batch_id
)
)
self
.
_collect_with_id
(
array
,
batch_id
)
def
_collect_with_id
(
self
,
array
,
batch_id
):
if
isinstance
(
array
,
np
.
ndarray
):
shape
=
array
.
shape
assert
list
(
shape
)
==
self
.
shape
[
1
:]
in_dtype
=
ctype_to_lite_dtypes
[
np
.
ctypeslib
.
as_ctypes_type
(
array
.
dtype
)]
assert
in_dtype
==
self
.
dtype
# get the batch index
with
self
.
_mutex
:
if
batch_id
in
self
.
_free_list
:
self
.
_free_list
.
remove
(
batch_id
)
# get the subtensor
subtensor
=
self
.
_tensor
.
slice
([
batch_id
],
[
batch_id
+
1
])
if
subtensor
.
device_type
==
LiteDeviceType
.
LITE_CPU
:
...
...
@@ -77,10 +87,6 @@ class TensorBatchCollector:
assert
list
(
shape
)
==
self
.
shape
[
1
:]
in_dtype
=
array
.
layout
.
data_type
assert
in_dtype
==
self
.
dtype
# get the batch index
with
self
.
_mutex
:
if
batch_id
in
self
.
_free_list
:
self
.
_free_list
.
remove
(
batch_id
)
# get the subtensor
subtensor
=
self
.
_tensor
.
slice
([
batch_id
],
[
batch_id
+
1
])
subtensor
.
copy_from
(
array
)
...
...
@@ -90,9 +96,12 @@ class TensorBatchCollector:
def
collect
(
self
,
array
):
with
self
.
_mutex
:
if
len
(
self
.
_free_list
)
==
0
:
warnings
.
warn
(
"all batch has been collected, please call free before collect again."
)
return
-
1
idx
=
self
.
_free_list
.
pop
(
0
)
return
self
.
collect
_id
(
array
,
idx
)
return
self
.
_collect_with
_id
(
array
,
idx
)
def
collect_by_ctypes
(
self
,
data
,
length
):
"""
...
...
@@ -115,6 +124,12 @@ class TensorBatchCollector:
def
free
(
self
,
indexes
):
with
self
.
_mutex
:
for
i
in
indexes
:
if
i
in
self
.
_free_list
:
warnings
.
warn
(
"batch id {} has not collected before free it."
.
format
(
i
)
)
self
.
_free_list
.
remove
(
i
)
self
.
_free_list
.
extend
(
indexes
)
def
get
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录