Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
milvus
提交
508131e4
milvus
项目概览
BaiXuePrincess
/
milvus
与 Fork 源项目一致
从无法访问的项目Fork
通知
7
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
508131e4
编写于
6月 10, 2019
作者:
Y
Yang Xuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(python): complete sdk 0.0.1
Former-commit-id: 713b3a629fcf86291d1e35c1d1b038b6fc599cb8
上级
964d3f37
变更
6
展开全部
显示空白变更内容
内联
并排
Showing
6 changed file
with
353 addition
and
727 deletion
+353
-727
python/sdk/client/Abstract.py
python/sdk/client/Abstract.py
+89
-235
python/sdk/client/Client.py
python/sdk/client/Client.py
+106
-269
python/sdk/client/Exceptions.py
python/sdk/client/Exceptions.py
+0
-4
python/sdk/client/Status.py
python/sdk/client/Status.py
+8
-6
python/sdk/examples/connection_exp.py
python/sdk/examples/connection_exp.py
+69
-56
python/sdk/tests/TestClient.py
python/sdk/tests/TestClient.py
+81
-157
未找到文件。
python/sdk/client/Abstract.py
浏览文件 @
508131e4
from
enum
import
IntEnum
from
.Exceptions
import
ConnectParamMissingError
class
AbstactIndexType
(
object
):
RAW
=
'1'
IVFFLAT
=
'2'
class
AbstractColumnType
(
object
):
INVALID
=
1
INT8
=
2
INT16
=
3
INT32
=
4
INT64
=
5
FLOAT32
=
6
FLOAT64
=
7
DATE
=
8
VECTOR
=
9
class
Column
(
object
):
"""
Table column description
:type type: ColumnType
:param type: (Required) type of the column
:type name: str
:param name: (Required) name of the column
"""
def
__init__
(
self
,
name
=
None
,
type
=
AbstractColumnType
.
INVALID
):
self
.
type
=
type
self
.
name
=
name
class
VectorColumn
(
Column
):
"""
Table vector column description
:type dimension: int, int64
:param dimension: (Required) vector dimension
:type index_type: string IndexType
:param index_type: (Required) IndexType
:type store_raw_vector: bool
:param store_raw_vector: (Required) Is vector self stored in the table
`Column`:
:type name: str
:param name: (Required) Name of the column
:type type: ColumnType
:param type: (Required) Default type is ColumnType.VECTOR, can't change
"""
def
__init__
(
self
,
name
,
dimension
=
0
,
index_type
=
None
,
store_raw_vector
=
False
,
type
=
None
):
self
.
dimension
=
dimension
self
.
index_type
=
index_type
self
.
store_raw_vector
=
store_raw_vector
super
(
VectorColumn
,
self
).
__init__
(
name
,
type
=
type
)
class
IndexType
(
IntEnum
):
INVALIDE
=
0
IDMAP
=
1
IVFLAT
=
2
class
TableSchema
(
object
):
...
...
@@ -74,30 +14,26 @@ class TableSchema(object):
:type table_name: str
:param table_name: (Required) name of table
:type vector_columns: list[VectorColumn]
:param vector_columns: (Required) a list of VectorColumns,
Stores different types of vectors
:type attribute_columns: list[Column]
:param attribute_columns: (Optional) Columns description
:type index_type: IndexType
:param index_type: (Optional) index type, default = 0
List of `Columns` whose type isn't VECTOR
`IndexType`: 0-invalid, 1-idmap, 2-ivflat
:type
partition_column_names: list[str]
:param
partition_column_names: (Optional) Partition column name
:type
dimension: int64
:param
dimension: (Required) dimension of vector
`Partition columns` are `attribute columns`, the number of
partition columns may be less than or equal to attribute columns,
this param only stores `column name`
:type store_raw_vector: bool
:param store_raw_vector: (Optional) default = False
"""
def
__init__
(
self
,
table_name
,
vector_columns
,
attribute_columns
,
partition_column_names
,
**
kwargs
):
def
__init__
(
self
,
table_name
,
dimension
=
0
,
index_type
=
IndexType
.
INVALIDE
,
store_raw_vector
=
False
):
self
.
table_name
=
table_name
self
.
vector_columns
=
vector_columns
self
.
attribute_columns
=
attribute_columns
self
.
partition_column_names
=
partition_column_names
self
.
index_type
=
index_type
self
.
dimension
=
dimension
self
.
store_raw_vector
=
store_raw_vector
class
Range
(
object
):
...
...
@@ -105,10 +41,10 @@ class Range(object):
Range information
:type start: str
:param start:
(Required)
Range start value
:param start: Range start value
:type end: str
:param end:
(Required)
Range end value
:param end: Range end value
"""
def
__init__
(
self
,
start
,
end
):
...
...
@@ -116,97 +52,37 @@ class Range(object):
self
.
end
=
end
class
CreateTablePartitionParam
(
object
):
"""
Create table partition parameters
:type table_name: str
:param table_name: (Required) Table name,
VECTOR/FLOAT32/FLOAT64 ColumnType is not allowed for partition
:type partition_name: str
:param partition_name: (Required) partition name, created partition name
:type column_name_to_range: dict{str : Range}
:param column_name_to_range: (Required) Column name to PartitionRange dictionary
"""
# TODO Iterable
def
__init__
(
self
,
table_name
,
partition_name
,
column_name_to_range
):
self
.
table_name
=
table_name
self
.
partition_name
=
partition_name
self
.
column_name_to_range
=
column_name_to_range
class
DeleteTablePartitionParam
(
object
):
"""
Delete table partition parameters
:type table_name: str
:param table_name: (Required) Table name
:type partition_names: iterable, str
:param partition_names: (Required) Partition name array
"""
# TODO Iterable
def
__init__
(
self
,
table_name
,
partition_names
):
self
.
table_name
=
table_name
self
.
partition_names
=
partition_names
class
RowRecord
(
object
):
"""
Record inserted
:type column_name_to_vector: dict{str : list[float]}
:param column_name_to_vector: (Required) Column name to vector map
:type column_name_to_attribute: dict{str: str}
:param column_name_to_attribute: (Optional) Other attribute columns
"""
def
__init__
(
self
,
column_name_to_vector
,
column_name_to_attribute
):
self
.
column_name_to_vector
=
column_name_to_vector
self
.
column_name_to_attribute
=
column_name_to_attribute
:type vector_data: binary str
:param vector_data: (Required) a vector
class
QueryRecord
(
object
):
"""
Query record
:type column_name_to_vector: (Required) dict{str : list[float]}
:param column_name_to_vector: Query vectors, column name to vector map
:type selected_columns: list[str]
:param selected_columns: (Optional) Output column array
:type name_to_partition_ranges: dict{str : list[Range]}
:param name_to_partition_ranges: (Optional) Range used to select partitions
"""
def
__init__
(
self
,
column_name_to_vector
,
selected_columns
,
name_to_partition_ranges
):
self
.
column_name_to_vector
=
column_name_to_vector
self
.
selected_columns
=
selected_columns
self
.
name_to_partition_ranges
=
name_to_partition_ranges
def
__init__
(
self
,
vector_data
):
self
.
vector_data
=
vector_data
class
QueryResult
(
object
):
"""
Query result
:type id: int
:param id:
Output result
:type id: int
64
:param id:
id of the vector
:type score: float
:param score: Vector similarity 0 <= score <= 100
:type column_name_to_attribute: dict{str : str}
:param column_name_to_attribute: Other columns
"""
def
__init__
(
self
,
id
,
score
,
column_name_to_attribute
):
def
__init__
(
self
,
id
,
score
):
self
.
id
=
id
self
.
score
=
score
self
.
column_name_to_value
=
column_name_to_attribute
def
__repr__
(
self
):
L
=
[
'%s=%r'
%
(
key
,
value
)
for
key
,
value
in
self
.
__dict__
.
items
()]
return
'%s(%s)'
%
(
self
.
__class__
.
__name__
,
', '
.
join
(
L
))
class
TopKQueryResult
(
object
):
...
...
@@ -220,6 +96,12 @@ class TopKQueryResult(object):
def
__init__
(
self
,
query_results
):
self
.
query_results
=
query_results
def
__repr__
(
self
):
L
=
[
'%s=%r'
%
(
key
,
value
)
for
key
,
value
in
self
.
__dict__
.
items
()]
return
'%s(%s)'
%
(
self
.
__class__
.
__name__
,
', '
.
join
(
L
))
def
_abstract
():
raise
NotImplementedError
(
'You need to override this function'
)
...
...
@@ -232,114 +114,71 @@ class ConnectIntf(object):
"""
@
staticmethod
def
create
():
"""Create a connection instance and return it
should be implemented
:return connection: Connection
"""
_abstract
()
@
staticmethod
def
destroy
(
connection
):
"""Destroy the connection instance
should be implemented
:type connection: Connection
:param connection: The connection instance to be destroyed
:return bool, return True if destroy is successful
"""
_abstract
()
def
connect
(
self
,
param
=
None
,
uri
=
None
):
def
connect
(
self
,
host
=
None
,
port
=
None
,
uri
=
None
):
"""
Connect method should be called before any operations
Server will be connected after connect return OK
should be implemented
Should be implemented
:type host: str
:param host: host
:type p
aram: ConnectParam
:param p
aram: ConnectParam
:type p
ort: str
:param p
ort: port
:type uri: str
:param uri:
uri param
:param uri:
(Optional) uri
:return
:
Status, indicate if connect is successful
:return Status, indicate if connect is successful
"""
if
(
not
param
and
not
uri
)
or
(
param
and
uri
):
raise
ConnectParamMissingError
(
'You need to parse exact one param'
)
_abstract
()
def
connected
(
self
):
"""
connected, connection status
s
hould be implemented
S
hould be implemented
:return
:
Status, indicate if connect is successful
:return Status, indicate if connect is successful
"""
_abstract
()
def
disconnect
(
self
):
"""
Disconnect, server will be disconnected after disconnect return
OK
s
hould be implemented
Disconnect, server will be disconnected after disconnect return
SUCCESS
S
hould be implemented
:return
:
Status, indicate if connect is successful
:return Status, indicate if connect is successful
"""
_abstract
()
def
create_table
(
self
,
param
):
"""
Create table
s
hould be implemented
S
hould be implemented
:type param: TableSchema
:param param: provide table information to be created
:return
:
Status, indicate if connect is successful
:return Status, indicate if connect is successful
"""
_abstract
()
def
delete_table
(
self
,
table_name
):
"""
Delete table
s
hould be implemented
S
hould be implemented
:type table_name: str
:param table_name: table_name of the deleting table
:return
:
Status, indicate if connect is successful
:return Status, indicate if connect is successful
"""
_abstract
()
def
create_table_partition
(
self
,
param
):
"""
Create table partition
should be implemented
:type param: CreateTablePartitionParam
:param param: provide partition information
:return: Status, indicate if table partition is created successfully
"""
_abstract
()
def
delete_table_partition
(
self
,
param
):
"""
Delete table partition
should be implemented
:type param: DeleteTablePartitionParam
:param param: provide partition information to be deleted
:return: Status, indicate if partition is deleted successfully
"""
_abstract
()
def
add_vector
(
self
,
table_name
,
records
):
def
add_vectors
(
self
,
table_name
,
records
):
"""
Add vectors to table
s
hould be implemented
S
hould be implemented
:type table_name: str
:param table_name: table name been inserted
...
...
@@ -347,27 +186,31 @@ class ConnectIntf(object):
:type records: list[RowRecord]
:param records: list of vectors been inserted
:returns
:
:returns
Status : indicate if vectors inserted successfully
ids :list of id, after inserted every vector is given a id
"""
_abstract
()
def
search_vector
(
self
,
table_name
,
query_record
s
,
top_k
):
def
search_vector
s
(
self
,
table_name
,
query_records
,
query_range
s
,
top_k
):
"""
Query vectors in a table
s
hould be implemented
S
hould be implemented
:type table_name: str
:param table_name: table name been queried
:type query_records: list[
Query
Record]
:type query_records: list[
Row
Record]
:param query_records: all vectors going to be queried
:type query_ranges: list[Range]
:param query_ranges: Optional ranges for conditional search.
If not specified, search whole table
:type top_k: int
:param top_k: how many similar vectors will be searched
:returns
:
:returns
Status: indicate if query is successful
query_results: list[TopKQueryResult]
"""
...
...
@@ -376,23 +219,37 @@ class ConnectIntf(object):
def
describe_table
(
self
,
table_name
):
"""
Show table information
s
hould be implemented
S
hould be implemented
:type table_name: str
:param table_name: which table to be shown
:returns
:
:returns
Status: indicate if query is successful
table_schema: TableSchema, given when operation is successful
"""
_abstract
()
def
get_table_row_count
(
self
,
table_name
):
"""
Get table row count
Should be implemented
:type table_name, str
:param table_name, target table name.
:returns
Status: indicate if operation is successful
count: int, table row count
"""
_abstract
()
def
show_tables
(
self
):
"""
Show all tables in database
should be implemented
:return
:
:return
Status: indicate if this operation is successful
tables: list[str], list of table names
"""
...
...
@@ -403,31 +260,28 @@ class ConnectIntf(object):
Provide client version
should be implemented
:return:
C
lient version
:return:
str, c
lient version
"""
_abstract
()
pass
def
server_version
(
self
):
"""
Provide server version
should be implemented
:return:
S
erver version
:return:
str, s
erver version
"""
_abstract
()
def
server_status
(
self
,
cmd
):
"""
Provide server status
should be implemented
# TODO What is cmd
:type cmd
:param cmd
:type cmd, str
:return:
S
erver status
:return:
str, s
erver status
"""
_abstract
()
pass
...
...
python/sdk/client/Client.py
浏览文件 @
508131e4
此差异已折叠。
点击以展开。
python/sdk/client/Exceptions.py
浏览文件 @
508131e4
...
...
@@ -2,10 +2,6 @@ class ParamError(ValueError):
pass
class
ConnectParamMissingError
(
ParamError
):
pass
class
ConnectError
(
ValueError
):
pass
...
...
python/sdk/client/Status.py
浏览文件 @
508131e4
...
...
@@ -3,13 +3,15 @@ class Status(object):
:attribute code : int (optional) default as ok
:attribute message : str (optional) current status message
"""
OK
=
0
INVALID
=
1
UNKNOWN_ERROR
=
2
NOT_SUPPORTED
=
3
NOT_CONNECTED
=
4
SUCCESS
=
0
CONNECT_FAILED
=
1
PERMISSION_DENIED
=
2
TABLE_NOT_EXISTS
=
3
ILLEGAL_ARGUMENT
=
4
ILLEGAL_RANGE
=
5
ILLEGAL_DIMENSION
=
6
def
__init__
(
self
,
code
=
OK
,
message
=
None
):
def
__init__
(
self
,
code
=
SUCCESS
,
message
=
None
):
self
.
code
=
code
self
.
message
=
message
...
...
python/sdk/examples/connection_exp.py
浏览文件 @
508131e4
from
client.Client
import
MegaSearch
,
Prepare
,
IndexType
,
ColumnType
from
client.Client
import
MegaSearch
,
Prepare
,
IndexType
from
client.Status
import
Status
import
time
import
random
import
struct
from
pprint
import
pprint
from
megasearch.thrift
import
MegasearchService
,
ttypes
def
main
():
# Get client version
mega
=
MegaSearch
()
print
(
mega
.
client_version
(
))
print
(
'# Client version: {}'
.
format
(
mega
.
client_version
()
))
# Connect
param
=
{
'host'
:
'192.168.1.129'
,
'port'
:
'33001'
}
cnn_status
=
mega
.
connect
(
**
param
)
print
(
'Connect Status: {}'
.
format
(
cnn_status
))
print
(
'
#
Connect Status: {}'
.
format
(
cnn_status
))
# Check if connected
is_connected
=
mega
.
connected
print
(
'Connect status: {}'
.
format
(
is_connected
))
# Create table with 1 vector column, 1 attribute column and 1 partition column
# 1. prepare table_schema
# table_schema = Prepare.table_schema(
# table_name='fake_table_name' + time.strftime('%H%M%S'),
#
# vector_columns=[Prepare.vector_column(
# name='fake_vector_name' + time.strftime('%H%M%S'),
# store_raw_vector=False,
# dimension=256)],
#
# attribute_columns=[],
#
# partition_column_names=[]
# )
# get server version
print
(
mega
.
server_status
(
'version'
))
print
(
mega
.
client
.
Ping
(
'version'
))
# show tables and their description
statu
,
tables
=
mega
.
show_tables
()
print
(
tables
)
for
table
in
tables
:
s
,
t
=
mega
.
describe_table
(
table
)
print
(
'table: {}'
.
format
(
t
))
print
(
'# Is connected: {}'
.
format
(
is_connected
))
# Create table
# 1. create table schema
table_schema_full
=
MegasearchService
.
TableSchema
(
table_name
=
'fake'
+
time
.
strftime
(
'%H%M%S'
),
vector_column_array
=
[
MegasearchService
.
VectorColumn
(
base
=
MegasearchService
.
Column
(
name
=
'111'
,
type
=
ttypes
.
TType
.
LIST
),
index_type
=
"aaa"
,
dimension
=
256
,
store_raw_vector
=
False
,
)],
attribute_column_array
=
[],
# Get server version
print
(
'# Server version: {}'
.
format
(
mega
.
server_version
()))
partition_column_name_array
=
[]
)
# Show tables and their description
status
,
tables
=
mega
.
show_tables
()
print
(
'# Show tables: {}'
.
format
(
tables
))
# 2. Create Table
create_status
=
mega
.
client
.
CreateTable
(
param
=
table_schema_full
)
print
(
'Create table status: {}'
.
format
(
create_status
))
# add_vector
# Create table
# 01.Prepare data
param
=
{
'table_name'
:
'test'
+
str
(
random
.
randint
(
0
,
999
)),
'dimension'
:
256
,
'index_type'
:
IndexType
.
IDMAP
,
'store_raw_vector'
:
False
}
# 02.Create table
res_status
=
mega
.
create_table
(
Prepare
.
table_schema
(
**
param
))
print
(
'# Create table status: {}'
.
format
(
res_status
))
# Describe table
table_name
=
'test01'
res_status
,
table
=
mega
.
describe_table
(
table_name
)
print
(
'# Describe table status: {}'
.
format
(
res_status
))
print
(
'# Describe table:{}'
.
format
(
table
))
# Add vectors to table 'test01'
# 01. Prepare data
dim
=
256
# list of binary vectors
vectors
=
[
Prepare
.
row_record
(
struct
.
pack
(
str
(
dim
)
+
'd'
,
*
[
random
.
random
()
for
_
in
range
(
dim
)]))
for
_
in
range
(
20
)]
# 02. Add vectors
status
,
ids
=
mega
.
add_vectors
(
table_name
=
table_name
,
records
=
vectors
)
print
(
'# Add vector status: {}'
.
format
(
status
))
pprint
(
ids
)
# Search vectors
q_records
=
[
Prepare
.
row_record
(
struct
.
pack
(
str
(
dim
)
+
'd'
,
*
[
random
.
random
()
for
_
in
range
(
dim
)]))
for
_
in
range
(
5
)]
param
=
{
'table_name'
:
'test01'
,
'query_records'
:
q_records
,
'top_k'
:
10
,
# 'query_ranges': None # Optional
}
sta
,
results
=
mega
.
search_vectors
(
**
param
)
print
(
'# Search vectors status: {}'
.
format
(
sta
))
pprint
(
results
)
# Get table row count
sta
,
result
=
mega
.
get_table_row_count
(
table_name
)
print
(
'# Status: {}'
.
format
(
sta
))
print
(
'# Count: {}'
.
format
(
result
))
# Delete table 'test01'
res_status
=
mega
.
delete_table
(
table_name
)
print
(
'# Delete table status: {}'
.
format
(
res_status
))
# Disconnect
discnn_status
=
mega
.
disconnect
()
print
(
'
Disconnect Status
{}'
.
format
(
discnn_status
))
print
(
'
# Disconnect Status:
{}'
.
format
(
discnn_status
))
if
__name__
==
'__main__'
:
...
...
python/sdk/tests/TestClient.py
浏览文件 @
508131e4
...
...
@@ -3,17 +3,20 @@ import pytest
import
mock
import
faker
import
random
import
struct
from
faker.providers
import
BaseProvider
from
client.Client
import
MegaSearch
,
Prepare
,
IndexType
,
ColumnType
from
client.Client
import
MegaSearch
,
Prepare
from
client.Abstract
import
IndexType
,
TableSchema
from
client.Status
import
Status
from
client.Exceptions
import
(
RepeatingConnectError
,
DisconnectNotConnectedClientError
)
from
megasearch.thrift
import
ttypes
,
MegasearchService
from
thrift.transport.TSocket
import
TSocket
from
megasearch.thrift
import
ttypes
,
MegasearchService
from
thrift.transport
import
TTransport
from
thrift.transport.TTransport
import
TTransportException
LOGGER
=
logging
.
getLogger
(
__name__
)
...
...
@@ -35,63 +38,37 @@ fake = faker.Faker()
fake
.
add_provider
(
FakerProvider
)
def
vector_column_factory
():
return
{
'name'
:
fake
.
name
(),
'dimension'
:
fake
.
dim
(),
'store_raw_vector'
:
True
}
def
column_factory
():
return
{
'name'
:
fake
.
table_name
(),
'type'
:
ColumnType
.
INT32
}
def
range_factory
():
return
{
param
=
{
'start'
:
str
(
random
.
randint
(
1
,
10
)),
'end'
:
str
(
random
.
randint
(
11
,
20
)),
}
return
Prepare
.
range
(
**
param
)
def
ranges_factory
():
return
[
range_factory
()
for
_
in
range
(
5
)]
def
table_schema_factory
():
vec_params
=
[
vector_column_factory
()
for
i
in
range
(
10
)]
column_params
=
[
column_factory
()
for
i
in
range
(
5
)]
param
=
{
'table_name'
:
fake
.
table_name
(),
'
vector_columns'
:
[
Prepare
.
vector_column
(
**
pa
)
for
pa
in
vec_params
]
,
'
attribute_columns'
:
[
Prepare
.
column
(
**
pa
)
for
pa
in
column_params
]
,
'
partition_column_names'
:
[
str
(
x
)
for
x
in
range
(
2
)]
'
dimension'
:
random
.
randint
(
0
,
999
)
,
'
index_type'
:
IndexType
.
IDMAP
,
'
store_raw_vector'
:
False
}
return
Prepare
.
table_schema
(
**
param
)
def
create_table_partition_param_factory
():
param
=
{
'table_name'
:
fake
.
table_name
(),
'partition_name'
:
fake
.
table_name
(),
'column_name_to_range'
:
{
fake
.
name
():
range_factory
()
for
_
in
range
(
3
)}
}
return
Prepare
.
create_table_partition_param
(
**
param
)
def
row_record_factory
(
dimension
):
vec
=
[
random
.
random
()
+
random
.
randint
(
0
,
9
)
for
_
in
range
(
dimension
)]
bin_vec
=
struct
.
pack
(
str
(
dimension
)
+
"d"
,
*
vec
)
def
delete_table_partition_param_factory
():
param
=
{
'table_name'
:
fake
.
table_name
(),
'partition_names'
:
[
fake
.
name
()
for
i
in
range
(
5
)]
}
return
Prepare
.
delete_table_partition_param
(
**
param
)
return
Prepare
.
row_record
(
vector_data
=
bin_vec
)
def
row_record_factory
():
param
=
{
'column_name_to_vector'
:
{
fake
.
name
():
[
random
.
random
()
for
i
in
range
(
256
)]},
'column_name_to_attribute'
:
{
fake
.
name
():
fake
.
name
()}
}
return
Prepare
.
row_record
(
**
param
)
def
row_records_factory
(
dimension
):
return
[
row_record_factory
(
dimension
)
for
_
in
range
(
20
)]
class
TestConnection
:
...
...
@@ -103,9 +80,8 @@ class TestConnection:
cnn
=
MegaSearch
()
cnn
.
connect
(
**
self
.
param
)
assert
cnn
.
status
==
Status
.
OK
assert
cnn
.
status
==
Status
.
SUCCESS
assert
cnn
.
connected
assert
isinstance
(
cnn
.
client
,
MegasearchService
.
Client
)
with
pytest
.
raises
(
RepeatingConnectError
):
cnn
.
connect
(
**
self
.
param
)
...
...
@@ -114,12 +90,23 @@ class TestConnection:
def
test_false_connect
(
self
):
cnn
=
MegaSearch
()
cnn
.
connect
(
self
.
param
)
assert
cnn
.
status
!=
Status
.
OK
cnn
.
connect
(
**
self
.
param
)
assert
cnn
.
status
!=
Status
.
SUCCESS
@
mock
.
patch
.
object
(
TTransport
.
TBufferedTransport
,
'close'
)
@
mock
.
patch
.
object
(
TSocket
,
'open'
)
def
test_disconnected
(
self
,
close
,
open
):
close
.
return_value
=
None
open
.
return_value
=
None
cnn
=
MegaSearch
()
cnn
.
connect
(
**
self
.
param
)
assert
cnn
.
disconnect
()
==
Status
.
SUCCESS
def
test_disconnected_error
(
self
):
cnn
=
MegaSearch
()
cnn
.
connect_status
=
Status
(
Status
.
INVALI
D
)
cnn
.
connect_status
=
Status
(
Status
.
PERMISSION_DENIE
D
)
with
pytest
.
raises
(
DisconnectNotConnectedClientError
):
cnn
.
disconnect
()
...
...
@@ -142,26 +129,26 @@ class TestTable:
param
=
table_schema_factory
()
res
=
client
.
create_table
(
param
)
assert
res
==
Status
.
OK
assert
res
==
Status
.
SUCCESS
def
test_false_create_table
(
self
,
client
):
param
=
table_schema_factory
()
with
pytest
.
raises
(
TTransportException
):
res
=
client
.
create_table
(
param
)
LOGGER
.
error
(
'{}'
.
format
(
res
))
assert
res
!=
Status
.
OK
assert
res
!=
Status
.
SUCCESS
@
mock
.
patch
.
object
(
MegasearchService
.
Client
,
'DeleteTable'
)
def
test_delete_table
(
self
,
DeleteTable
,
client
):
DeleteTable
.
return_value
=
None
table_name
=
'fake_table_name'
res
=
client
.
delete_table
(
table_name
)
assert
res
==
Status
.
OK
assert
res
==
Status
.
SUCCESS
def
test_false_delete_table
(
self
,
client
):
table_name
=
'fake_table_name'
res
=
client
.
delete_table
(
table_name
)
assert
res
!=
Status
.
OK
assert
res
!=
Status
.
SUCCESS
class
TestVector
:
...
...
@@ -176,70 +163,46 @@ class TestVector:
cnn
.
connect
(
**
param
)
return
cnn
@
mock
.
patch
.
object
(
MegasearchService
.
Client
,
'CreateTablePartition'
)
def
test_create_table_partition
(
self
,
CreateTablePartition
,
client
):
CreateTablePartition
.
return_value
=
None
param
=
create_table_partition_param_factory
()
res
=
client
.
create_table_partition
(
param
)
assert
res
==
Status
.
OK
def
test_false_table_partition
(
self
,
client
):
param
=
create_table_partition_param_factory
()
res
=
client
.
create_table_partition
(
param
)
assert
res
!=
Status
.
OK
@
mock
.
patch
.
object
(
MegasearchService
.
Client
,
'DeleteTablePartition'
)
def
test_delete_table_partition
(
self
,
DeleteTablePartition
,
client
):
DeleteTablePartition
.
return_value
=
None
param
=
delete_table_partition_param_factory
()
res
=
client
.
delete_table_partition
(
param
)
assert
res
==
Status
.
OK
def
test_false_delete_table_partition
(
self
,
client
):
param
=
delete_table_partition_param_factory
()
res
=
client
.
delete_table_partition
(
param
)
assert
res
!=
Status
.
OK
@
mock
.
patch
.
object
(
MegasearchService
.
Client
,
'AddVector'
)
def
test_add_vector
(
self
,
AddVector
,
client
):
AddVector
.
return_value
=
None
param
=
{
'table_name'
:
fake
.
table_name
(),
'records'
:
[
row_record_factory
()
for
_
in
range
(
1000
)]
'records'
:
row_records_factory
(
256
)
}
res
,
ids
=
client
.
add_vector
(
**
param
)
assert
res
==
Status
.
OK
res
,
ids
=
client
.
add_vector
s
(
**
param
)
assert
res
==
Status
.
SUCCESS
def
test_false_add_vector
(
self
,
client
):
param
=
{
'table_name'
:
fake
.
table_name
(),
'records'
:
[
row_record_factory
()
for
_
in
range
(
1000
)]
'records'
:
row_records_factory
(
256
)
}
res
,
ids
=
client
.
add_vector
(
**
param
)
assert
res
!=
Status
.
OK
res
,
ids
=
client
.
add_vector
s
(
**
param
)
assert
res
!=
Status
.
SUCCESS
@
mock
.
patch
.
object
(
MegasearchService
.
Client
,
'SearchVector'
)
def
test_search_vector
(
self
,
SearchVector
,
client
):
SearchVector
.
return_value
=
None
SearchVector
.
return_value
=
None
,
None
param
=
{
'table_name'
:
fake
.
table_name
(),
'query_records'
:
[
row_record_factory
()
for
_
in
range
(
1000
)],
'top_k'
:
random
.
randint
(
0
,
10
)
'query_records'
:
row_records_factory
(
256
),
'query_ranges'
:
ranges_factory
(),
'top_k'
:
random
.
randint
(
0
,
10
)
}
res
,
results
=
client
.
search_vector
(
**
param
)
assert
res
==
Status
.
OK
res
,
results
=
client
.
search_vector
s
(
**
param
)
assert
res
==
Status
.
SUCCESS
def
test_false_vector
(
self
,
client
):
param
=
{
'table_name'
:
fake
.
table_name
(),
'query_records'
:
[
row_record_factory
()
for
_
in
range
(
1000
)],
'top_k'
:
random
.
randint
(
0
,
10
)
'query_records'
:
row_records_factory
(
256
),
'query_ranges'
:
ranges_factory
(),
'top_k'
:
random
.
randint
(
0
,
10
)
}
res
,
results
=
client
.
search_vector
(
**
param
)
assert
res
!=
Status
.
OK
res
,
results
=
client
.
search_vector
s
(
**
param
)
assert
res
!=
Status
.
SUCCESS
@
mock
.
patch
.
object
(
MegasearchService
.
Client
,
'DescribeTable'
)
def
test_describe_table
(
self
,
DescribeTable
,
client
):
...
...
@@ -247,27 +210,38 @@ class TestVector:
table_name
=
fake
.
table_name
()
res
,
table_schema
=
client
.
describe_table
(
table_name
)
assert
res
==
Status
.
OK
assert
isinstance
(
table_schema
,
ttypes
.
TableSchema
)
assert
res
==
Status
.
SUCCESS
assert
isinstance
(
table_schema
,
TableSchema
)
def
test_false_decribe_table
(
self
,
client
):
table_name
=
fake
.
table_name
()
res
,
table_schema
=
client
.
describe_table
(
table_name
)
assert
res
!=
Status
.
OK
assert
res
!=
Status
.
SUCCESS
assert
not
table_schema
@
mock
.
patch
.
object
(
MegasearchService
.
Client
,
'ShowTables'
)
def
test_show_tables
(
self
,
ShowTables
,
client
):
ShowTables
.
return_value
=
[
fake
.
table_name
()
for
_
in
range
(
10
)]
ShowTables
.
return_value
=
[
fake
.
table_name
()
for
_
in
range
(
10
)]
,
None
res
,
tables
=
client
.
show_tables
()
assert
res
==
Status
.
OK
assert
res
==
Status
.
SUCCESS
assert
isinstance
(
tables
,
list
)
def
test_false_show_tables
(
self
,
client
):
res
,
tables
=
client
.
show_tables
()
assert
res
!=
Status
.
OK
assert
res
!=
Status
.
SUCCESS
assert
not
tables
@
mock
.
patch
.
object
(
MegasearchService
.
Client
,
'GetTableRowCount'
)
def
test_get_table_row_count
(
self
,
GetTableRowCount
,
client
):
GetTableRowCount
.
return_value
=
22
,
None
res
,
count
=
client
.
get_table_row_count
(
'fake_table'
)
assert
res
==
Status
.
SUCCESS
def
test_false_get_table_row_count
(
self
,
client
):
res
,
count
=
client
.
get_table_row_count
(
'fake_table'
)
assert
res
!=
Status
.
SUCCESS
assert
not
count
def
test_client_version
(
self
,
client
):
res
=
client
.
client_version
()
assert
res
==
'0.0.1'
...
...
@@ -275,34 +249,13 @@ class TestVector:
class
TestPrepare
:
def
test_column
(
self
):
param
=
{
'name'
:
'test01'
,
'type'
:
ColumnType
.
DATE
}
res
=
Prepare
.
column
(
**
param
)
LOGGER
.
error
(
'{}'
.
format
(
res
))
assert
res
.
name
==
'test01'
assert
res
.
type
==
ColumnType
.
DATE
assert
isinstance
(
res
,
ttypes
.
Column
)
def
test_vector_column
(
self
):
param
=
vector_column_factory
()
res
=
Prepare
.
vector_column
(
**
param
)
LOGGER
.
error
(
'{}'
.
format
(
res
))
assert
isinstance
(
res
,
ttypes
.
VectorColumn
)
def
test_table_schema
(
self
):
vec_params
=
[
vector_column_factory
()
for
i
in
range
(
10
)]
column_params
=
[
column_factory
()
for
i
in
range
(
5
)]
param
=
{
'table_name'
:
'test03'
,
'
vector_columns'
:
[
Prepare
.
vector_column
(
**
pa
)
for
pa
in
vec_params
]
,
'
attribute_columns'
:
[
Prepare
.
column
(
**
pa
)
for
pa
in
column_params
]
,
'
partition_column_names'
:
[
str
(
x
)
for
x
in
range
(
2
)]
'table_name'
:
fake
.
table_name
()
,
'
dimension'
:
random
.
randint
(
0
,
999
)
,
'
index_type'
:
IndexType
.
IDMAP
,
'
store_raw_vector'
:
False
}
res
=
Prepare
.
table_schema
(
**
param
)
assert
isinstance
(
res
,
ttypes
.
TableSchema
)
...
...
@@ -319,39 +272,10 @@ class TestPrepare:
assert
res
.
start_value
==
'200'
assert
res
.
end_value
==
'1000'
def
test_create_table_partition_param
(
self
):
param
=
{
'table_name'
:
fake
.
table_name
(),
'partition_name'
:
fake
.
table_name
(),
'column_name_to_range'
:
{
fake
.
name
():
range_factory
()
for
_
in
range
(
3
)}
}
res
=
Prepare
.
create_table_partition_param
(
**
param
)
LOGGER
.
error
(
'{}'
.
format
(
res
))
assert
isinstance
(
res
,
ttypes
.
CreateTablePartitionParam
)
def
test_delete_table_partition_param
(
self
):
param
=
{
'table_name'
:
fake
.
table_name
(),
'partition_names'
:
[
fake
.
name
()
for
i
in
range
(
5
)]
}
res
=
Prepare
.
delete_table_partition_param
(
**
param
)
assert
isinstance
(
res
,
ttypes
.
DeleteTablePartitionParam
)
def
test_row_record
(
self
):
param
=
{
'column_name_to_vector'
:
{
fake
.
name
():
[
random
.
random
()
for
i
in
range
(
256
)]},
'column_name_to_attribute'
:
{
fake
.
name
():
fake
.
name
()}
}
res
=
Prepare
.
row_record
(
**
param
)
vec
=
[
random
.
random
()
+
random
.
randint
(
0
,
9
)
for
_
in
range
(
256
)]
bin_vec
=
struct
.
pack
(
str
(
256
)
+
"d"
,
*
vec
)
res
=
Prepare
.
row_record
(
bin_vec
)
assert
isinstance
(
res
,
ttypes
.
RowRecord
)
def
test_query_record
(
self
):
param
=
{
'column_name_to_vector'
:
{
fake
.
name
():
[
random
.
random
()
for
i
in
range
(
256
)]},
'selected_columns'
:
[
fake
.
name
()
for
_
in
range
(
10
)],
'name_to_partition_ranges'
:
{
fake
.
name
():
[
range_factory
()
for
_
in
range
(
5
)]}
}
res
=
Prepare
.
query_record
(
**
param
)
assert
isinstance
(
res
,
ttypes
.
QueryRecord
)
assert
isinstance
(
bin_vec
,
bytes
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录