Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
milvus
提交
99deaf5c
milvus
项目概览
BaiXuePrincess
/
milvus
与 Fork 源项目一致
从无法访问的项目Fork
通知
7
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
99deaf5c
编写于
11月 19, 2019
作者:
Y
yhz
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify shards for v0.5.3
上级
8733063d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
89 additions
and
28 deletions
+89
-28
shards/mishards/connections.py
shards/mishards/connections.py
+10
-0
shards/mishards/service_handler.py
shards/mishards/service_handler.py
+79
-28
未找到文件。
shards/mishards/connections.py
浏览文件 @
99deaf5c
...
...
@@ -2,6 +2,7 @@ import logging
import
threading
from
functools
import
wraps
from
milvus
import
Milvus
from
milvus.client.hooks
import
BaseaSearchHook
from
mishards
import
(
settings
,
exceptions
)
from
utils
import
singleton
...
...
@@ -9,6 +10,12 @@ from utils import singleton
logger
=
logging
.
getLogger
(
__name__
)
class Searchook(BaseaSearchHook):
    """Search hook whose response callback always reports success.

    Installed on the Milvus client via ``_set_hook(search_in_file=...)`` so
    that per-file search responses are accepted as-is and mishards performs
    its own result merging afterwards.
    """

    def on_response(self, *args, **kwargs):
        # Unconditionally accept the response; mishards post-processes it.
        return True
class
Connection
:
def
__init__
(
self
,
name
,
uri
,
max_retry
=
1
,
error_handlers
=
None
,
**
kwargs
):
self
.
name
=
name
...
...
@@ -18,6 +25,9 @@ class Connection:
self
.
conn
=
Milvus
()
self
.
error_handlers
=
[]
if
not
error_handlers
else
error_handlers
self
.
on_retry_func
=
kwargs
.
get
(
'on_retry_func'
,
None
)
# define search hook
self
.
conn
.
_set_hook
(
search_in_file
=
Searchook
())
# self._connect()
def
__str__
(
self
):
...
...
shards/mishards/service_handler.py
浏览文件 @
99deaf5c
...
...
@@ -29,39 +29,88 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
self
.
router
=
router
self
.
max_workers
=
max_workers
def
_reduce
(
self
,
source_ids
,
ids
,
source_diss
,
diss
,
k
,
reverse
):
if
source_diss
[
k
-
1
]
<=
diss
[
0
]:
return
source_ids
,
source_diss
if
diss
[
k
-
1
]
<=
source_diss
[
0
]:
return
ids
,
diss
diss_t
=
enumerate
(
source_diss
.
extend
(
diss
))
diss_m_rst
=
sorted
(
diss_t
,
key
=
lambda
x
:
x
[
1
])[:
k
]
diss_m_out
=
[
id_
for
_
,
id_
in
diss_m_rst
]
id_t
=
source_ids
.
extend
(
ids
)
id_m_out
=
[
id_t
[
i
]
for
i
,
_
in
diss_m_rst
]
return
id_m_out
,
diss_m_out
def
_do_merge
(
self
,
files_n_topk_results
,
topk
,
reverse
=
False
,
**
kwargs
):
status
=
status_pb2
.
Status
(
error_code
=
status_pb2
.
SUCCESS
,
reason
=
"Success"
)
if
not
files_n_topk_results
:
return
status
,
[]
request_results
=
defaultdict
(
list
)
# request_results = defaultdict(list)
# row_num = files_n_topk_results[0].row_num
merge_id_results
=
[]
merge_dis_results
=
[]
calc_time
=
time
.
time
()
for
files_collection
in
files_n_topk_results
:
if
isinstance
(
files_collection
,
tuple
):
status
,
_
=
files_collection
return
status
,
[]
for
request_pos
,
each_request_results
in
enumerate
(
files_collection
.
topk_query_result
):
request_results
[
request_pos
].
extend
(
each_request_results
.
query_result_arrays
)
request_results
[
request_pos
]
=
sorted
(
request_results
[
request_pos
],
key
=
lambda
x
:
x
.
distance
,
reverse
=
reverse
)[:
topk
]
row_num
=
files_collection
.
row_num
ids
=
files_collection
.
ids
diss
=
files_collection
.
distances
# distance collections
batch_len
=
len
(
ids
)
//
row_num
for
row_index
in
range
(
row_num
):
id_batch
=
ids
[
row_index
*
batch_len
:
(
row_index
+
1
)
*
batch_len
]
dis_batch
=
diss
[
row_index
*
batch_len
:
(
row_index
+
1
)
*
batch_len
]
if
len
(
merge_id_results
)
<
row_index
:
raise
ValueError
(
"merge error"
)
elif
len
(
merge_id_results
)
==
row_index
:
# TODO: may bug here
merge_id_results
.
append
(
id_batch
)
merge_dis_results
.
append
(
dis_batch
)
else
:
merge_id_results
[
row_index
].
extend
(
ids
[
row_index
*
batch_len
,
(
row_index
+
1
)
*
batch_len
])
merge_dis_results
[
row_index
].
extend
(
diss
[
row_index
*
batch_len
,
(
row_index
+
1
)
*
batch_len
])
# _reduce(_ids, _diss, k, reverse)
merge_id_results
[
row_index
],
merge_dis_results
[
row_index
]
=
\
self
.
_reduce
(
merge_id_results
[
row_index
],
id_batch
,
merge_dis_results
[
row_index
],
dis_batch
,
batch_len
,
reverse
)
# for request_pos, each_request_results in enumerate(
# files_collection.topk_query_result):
# request_results[request_pos].extend(
# each_request_results.query_result_arrays)
# request_results[request_pos] = sorted(
# request_results[request_pos],
# key=lambda x: x.distance,
# reverse=reverse)[:topk]
calc_time
=
time
.
time
()
-
calc_time
logger
.
info
(
'Merge takes {}'
.
format
(
calc_time
))
results
=
sorted
(
request_results
.
items
())
topk_query_result
=
[]
# results = sorted(request_results.items())
id_mrege_list
=
[]
dis_mrege_list
=
[]
for
id_results
,
dis_results
in
zip
(
merge_id_results
,
merge_dis_results
):
id_mrege_list
.
extend
(
id_results
)
dis_mrege_list
.
extend
(
dis_results
)
for
result
in
results
:
query_result
=
TopKQueryResult
(
query_result_arrays
=
result
[
1
])
topk_query_result
.
append
(
query_result
)
#
for result in results:
#
query_result = TopKQueryResult(query_result_arrays=result[1])
#
topk_query_result.append(query_result)
return
status
,
topk_query_resul
t
return
status
,
id_mrege_list
,
dis_mrege_lis
t
def
_do_query
(
self
,
context
,
...
...
@@ -109,8 +158,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
file_ids
=
query_params
[
'file_ids'
],
query_records
=
vectors
,
top_k
=
topk
,
nprobe
=
nprobe
,
lazy_
=
True
)
nprobe
=
nprobe
)
end
=
time
.
time
()
logger
.
info
(
'search_vectors_in_files takes: {}'
.
format
(
end
-
start
))
...
...
@@ -241,7 +290,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
logger
.
info
(
'Search {}: topk={} nprobe={}'
.
format
(
table_name
,
topk
,
nprobe
))
metadata
=
{
'resp_class'
:
milvus_pb2
.
TopKQueryResult
List
}
metadata
=
{
'resp_class'
:
milvus_pb2
.
TopKQueryResult
}
if
nprobe
>
self
.
MAX_NPROBE
or
nprobe
<=
0
:
raise
exceptions
.
InvalidArgumentError
(
...
...
@@ -275,22 +324,24 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
query_range_array
.
append
(
Range
(
query_range
.
start_value
,
query_range
.
end_value
))
status
,
results
=
self
.
_do_query
(
context
,
table_name
,
table_meta
,
query_record_array
,
topk
,
nprobe
,
query_range_array
,
metadata
=
metadata
)
status
,
id_results
,
dis_
results
=
self
.
_do_query
(
context
,
table_name
,
table_meta
,
query_record_array
,
topk
,
nprobe
,
query_range_array
,
metadata
=
metadata
)
now
=
time
.
time
()
logger
.
info
(
'SearchVector takes: {}'
.
format
(
now
-
start
))
topk_result_list
=
milvus_pb2
.
TopKQueryResult
List
(
topk_result_list
=
milvus_pb2
.
TopKQueryResult
(
status
=
status_pb2
.
Status
(
error_code
=
status
.
error_code
,
reason
=
status
.
reason
),
topk_query_result
=
results
)
row_num
=
len
(
query_record_array
),
ids
=
id_results
,
distances
=
dis_results
)
return
topk_result_list
@
mark_grpc_method
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录