Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
6cd62518
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
接近 2 年 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6cd62518
编写于
8月 29, 2022
作者:
W
Walter
提交者:
GitHub
8月 29, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2236 from RainFrost1/lite_shitu
update build_gallery and add android demo index support
上级
3a4c7861
52ba23c8
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
119 addition
and
81 deletion
+119
-81
deploy/python/build_gallery.py
deploy/python/build_gallery.py
+119
-81
未找到文件。
deploy/python/build_gallery.py
浏览文件 @
6cd62518
...
@@ -12,16 +12,14 @@
...
@@ -12,16 +12,14 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
import
os
import
pickle
import
cv2
import
cv2
import
faiss
import
faiss
import
numpy
as
np
import
numpy
as
np
from
tqdm
import
tqdm
import
pickle
from
paddleclas.deploy.utils
import
logger
,
config
from
paddleclas.deploy.python.predict_rec
import
RecPredictor
from
paddleclas.deploy.python.predict_rec
import
RecPredictor
from
paddleclas.deploy.python.predict_rec
import
RecPredictor
from
paddleclas.deploy.utils
import
config
,
logger
from
tqdm
import
tqdm
def
split_datafile
(
data_file
,
image_root
,
delimiter
=
"
\t
"
):
def
split_datafile
(
data_file
,
image_root
,
delimiter
=
"
\t
"
):
...
@@ -52,6 +50,7 @@ class GalleryBuilder(object):
...
@@ -52,6 +50,7 @@ class GalleryBuilder(object):
self
.
config
=
config
self
.
config
=
config
self
.
rec_predictor
=
RecPredictor
(
config
)
self
.
rec_predictor
=
RecPredictor
(
config
)
assert
'IndexProcess'
in
config
.
keys
(),
"Index config not found ... "
assert
'IndexProcess'
in
config
.
keys
(),
"Index config not found ... "
self
.
android_demo
=
config
[
"Global"
].
get
(
"android_demo"
,
False
)
self
.
build
(
config
[
'IndexProcess'
])
self
.
build
(
config
[
'IndexProcess'
])
def
build
(
self
,
config
):
def
build
(
self
,
config
):
...
@@ -70,10 +69,86 @@ class GalleryBuilder(object):
...
@@ -70,10 +69,86 @@ class GalleryBuilder(object):
"new"
,
"remove"
,
"append"
"new"
,
"remove"
,
"append"
],
"Only append, remove and new operation are supported"
],
"Only append, remove and new operation are supported"
if
self
.
android_demo
:
self
.
_create_index_for_android_demo
(
config
,
gallery_features
,
gallery_docs
)
return
# vector.index: faiss index file
# vector.index: faiss index file
# id_map.pkl: use this file to map id to image_doc
# id_map.pkl: use this file to map id to image_doc
index
,
ids
=
None
,
None
if
operation_method
in
[
"remove"
,
"append"
]:
if
operation_method
in
[
"remove"
,
"append"
]:
# if remove or append, vector.index and id_map.pkl must exist
# if remove or append, load vector.index and id_map.pkl
index
,
ids
=
self
.
_load_index
(
config
)
index_method
=
config
.
get
(
"index_method"
,
"HNSW32"
)
else
:
index_method
,
index
,
ids
=
self
.
_create_index
(
config
)
if
index_method
==
"HNSW32"
:
logger
.
warning
(
"The HNSW32 method dose not support 'remove' operation"
)
if
operation_method
!=
"remove"
:
# calculate id for new data
index
,
ids
=
self
.
_add_gallery
(
index
,
ids
,
gallery_features
,
gallery_docs
,
config
,
operation_method
)
else
:
if
index_method
==
"HNSW32"
:
raise
RuntimeError
(
"The index_method: HNSW32 dose not support 'remove' operation"
)
# remove ids in id_map, remove index data in faiss index
index
,
ids
=
self
.
_rm_id_in_galllery
(
index
,
ids
,
gallery_docs
)
# store faiss index file and id_map file
self
.
_save_gallery
(
config
,
index
,
ids
)
def
_create_index_for_android_demo
(
self
,
config
,
gallery_features
,
gallery_docs
):
if
not
os
.
path
.
exists
(
config
[
"index_dir"
]):
os
.
makedirs
(
config
[
"index_dir"
],
exist_ok
=
True
)
#build index
index
=
faiss
.
IndexFlatIP
(
config
[
"embedding_size"
])
index
.
add
(
gallery_features
)
# calculate id for data
ids_now
=
(
np
.
arange
(
0
,
len
(
gallery_docs
))).
astype
(
np
.
int64
)
ids
=
{}
for
i
,
d
in
zip
(
list
(
ids_now
),
gallery_docs
):
ids
[
i
]
=
d
self
.
_save_gallery
(
config
,
index
,
ids
)
def
_extract_features
(
self
,
gallery_images
,
config
):
# extract gallery features
if
config
[
"dist_type"
]
==
"hamming"
:
gallery_features
=
np
.
zeros
(
[
len
(
gallery_images
),
config
[
'embedding_size'
]
//
8
],
dtype
=
np
.
uint8
)
else
:
gallery_features
=
np
.
zeros
(
[
len
(
gallery_images
),
config
[
'embedding_size'
]],
dtype
=
np
.
float32
)
#construct batch imgs and do inference
batch_size
=
config
.
get
(
"batch_size"
,
32
)
batch_img
=
[]
for
i
,
image_file
in
enumerate
(
tqdm
(
gallery_images
)):
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
error
(
"img empty, please check {}"
.
format
(
image_file
))
exit
()
img
=
img
[:,
:,
::
-
1
]
batch_img
.
append
(
img
)
if
(
i
+
1
)
%
batch_size
==
0
:
rec_feat
=
self
.
rec_predictor
.
predict
(
batch_img
)
gallery_features
[
i
-
batch_size
+
1
:
i
+
1
,
:]
=
rec_feat
batch_img
=
[]
if
len
(
batch_img
)
>
0
:
rec_feat
=
self
.
rec_predictor
.
predict
(
batch_img
)
gallery_features
[
-
len
(
batch_img
):,
:]
=
rec_feat
batch_img
=
[]
return
gallery_features
def
_load_index
(
self
,
config
):
assert
os
.
path
.
join
(
assert
os
.
path
.
join
(
config
[
"index_dir"
],
"vector.index"
config
[
"index_dir"
],
"vector.index"
),
"The vector.index dose not exist in {} when 'index_operation' is not None"
.
format
(
),
"The vector.index dose not exist in {} when 'index_operation' is not None"
.
format
(
...
@@ -89,7 +164,9 @@ class GalleryBuilder(object):
...
@@ -89,7 +164,9 @@ class GalleryBuilder(object):
ids
=
pickle
.
load
(
fd
)
ids
=
pickle
.
load
(
fd
)
assert
index
.
ntotal
==
len
(
ids
.
keys
(
assert
index
.
ntotal
==
len
(
ids
.
keys
(
)),
"data number in index is not equal in in id_map"
)),
"data number in index is not equal in in id_map"
else
:
return
index
,
ids
def
_create_index
(
self
,
config
):
if
not
os
.
path
.
exists
(
config
[
"index_dir"
]):
if
not
os
.
path
.
exists
(
config
[
"index_dir"
]):
os
.
makedirs
(
config
[
"index_dir"
],
exist_ok
=
True
)
os
.
makedirs
(
config
[
"index_dir"
],
exist_ok
=
True
)
index_method
=
config
.
get
(
"index_method"
,
"HNSW32"
)
index_method
=
config
.
get
(
"index_method"
,
"HNSW32"
)
...
@@ -116,16 +193,12 @@ class GalleryBuilder(object):
...
@@ -116,16 +193,12 @@ class GalleryBuilder(object):
index_method
,
dist_type
)
index_method
,
dist_type
)
index
=
faiss
.
IndexIDMap2
(
index
)
index
=
faiss
.
IndexIDMap2
(
index
)
ids
=
{}
ids
=
{}
return
index_method
,
index
,
ids
if
config
[
"index_method"
]
==
"HNSW32"
:
def
_add_gallery
(
self
,
index
,
ids
,
gallery_features
,
gallery_docs
,
config
,
operation_method
):
logger
.
warning
(
"The HNSW32 method dose not support 'remove' operation"
)
if
operation_method
!=
"remove"
:
# calculate id for new data
start_id
=
max
(
ids
.
keys
())
+
1
if
ids
else
0
start_id
=
max
(
ids
.
keys
())
+
1
if
ids
else
0
ids_now
=
(
ids_now
=
(
np
.
arange
(
0
,
len
(
gallery_image
s
))
+
start_id
).
astype
(
np
.
int64
)
np
.
arange
(
0
,
len
(
gallery_doc
s
))
+
start_id
).
astype
(
np
.
int64
)
# only train when new index file
# only train when new index file
if
operation_method
==
"new"
:
if
operation_method
==
"new"
:
...
@@ -139,12 +212,9 @@ class GalleryBuilder(object):
...
@@ -139,12 +212,9 @@ class GalleryBuilder(object):
for
i
,
d
in
zip
(
list
(
ids_now
),
gallery_docs
):
for
i
,
d
in
zip
(
list
(
ids_now
),
gallery_docs
):
ids
[
i
]
=
d
ids
[
i
]
=
d
else
:
return
index
,
ids
if
config
[
"index_method"
]
==
"HNSW32"
:
raise
RuntimeError
(
def
_rm_id_in_galllery
(
self
,
index
,
ids
,
gallery_docs
):
"The index_method: HNSW32 dose not support 'remove' operation"
)
# remove ids in id_map, remove index data in faiss index
remove_ids
=
list
(
remove_ids
=
list
(
filter
(
lambda
k
:
ids
.
get
(
k
)
in
gallery_docs
,
ids
.
keys
()))
filter
(
lambda
k
:
ids
.
get
(
k
)
in
gallery_docs
,
ids
.
keys
()))
remove_ids
=
np
.
asarray
(
remove_ids
)
remove_ids
=
np
.
asarray
(
remove_ids
)
...
@@ -152,7 +222,9 @@ class GalleryBuilder(object):
...
@@ -152,7 +222,9 @@ class GalleryBuilder(object):
for
k
in
remove_ids
:
for
k
in
remove_ids
:
del
ids
[
k
]
del
ids
[
k
]
# store faiss index file and id_map file
return
index
,
ids
def
_save_gallery
(
self
,
config
,
index
,
ids
):
if
config
[
"dist_type"
]
==
"hamming"
:
if
config
[
"dist_type"
]
==
"hamming"
:
faiss
.
write_index_binary
(
faiss
.
write_index_binary
(
index
,
os
.
path
.
join
(
config
[
"index_dir"
],
"vector.index"
))
index
,
os
.
path
.
join
(
config
[
"index_dir"
],
"vector.index"
))
...
@@ -163,40 +235,6 @@ class GalleryBuilder(object):
...
@@ -163,40 +235,6 @@ class GalleryBuilder(object):
with
open
(
os
.
path
.
join
(
config
[
"index_dir"
],
"id_map.pkl"
),
'wb'
)
as
fd
:
with
open
(
os
.
path
.
join
(
config
[
"index_dir"
],
"id_map.pkl"
),
'wb'
)
as
fd
:
pickle
.
dump
(
ids
,
fd
)
pickle
.
dump
(
ids
,
fd
)
def
_extract_features
(
self
,
gallery_images
,
config
):
# extract gallery features
if
config
[
"dist_type"
]
==
"hamming"
:
gallery_features
=
np
.
zeros
(
[
len
(
gallery_images
),
config
[
'embedding_size'
]
//
8
],
dtype
=
np
.
uint8
)
else
:
gallery_features
=
np
.
zeros
(
[
len
(
gallery_images
),
config
[
'embedding_size'
]],
dtype
=
np
.
float32
)
#construct batch imgs and do inference
batch_size
=
config
.
get
(
"batch_size"
,
32
)
batch_img
=
[]
for
i
,
image_file
in
enumerate
(
tqdm
(
gallery_images
)):
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
error
(
"img empty, please check {}"
.
format
(
image_file
))
exit
()
img
=
img
[:,
:,
::
-
1
]
batch_img
.
append
(
img
)
if
(
i
+
1
)
%
batch_size
==
0
:
rec_feat
=
self
.
rec_predictor
.
predict
(
batch_img
)
gallery_features
[
i
-
batch_size
+
1
:
i
+
1
,
:]
=
rec_feat
batch_img
=
[]
if
len
(
batch_img
)
>
0
:
rec_feat
=
self
.
rec_predictor
.
predict
(
batch_img
)
gallery_features
[
-
len
(
batch_img
):,
:]
=
rec_feat
batch_img
=
[]
return
gallery_features
def
main
(
config
):
def
main
(
config
):
GalleryBuilder
(
config
)
GalleryBuilder
(
config
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录