提交 508131e4 编写于 作者: Y Yang Xuan

feat(python): complete sdk 0.0.1


Former-commit-id: 713b3a629fcf86291d1e35c1d1b038b6fc599cb8
上级 964d3f37
from enum import IntEnum from enum import IntEnum
from .Exceptions import ConnectParamMissingError
class AbstactIndexType(object): class IndexType(IntEnum):
RAW = '1' INVALIDE = 0
IVFFLAT = '2' IDMAP = 1
IVFLAT = 2
class AbstractColumnType(object):
INVALID = 1
INT8 = 2
INT16 = 3
INT32 = 4
INT64 = 5
FLOAT32 = 6
FLOAT64 = 7
DATE = 8
VECTOR = 9
class Column(object):
"""
Table column description
:type type: ColumnType
:param type: (Required) type of the column
:type name: str
:param name: (Required) name of the column
"""
def __init__(self, name=None, type=AbstractColumnType.INVALID):
self.type = type
self.name = name
class VectorColumn(Column):
"""
Table vector column description
:type dimension: int, int64
:param dimension: (Required) vector dimension
:type index_type: string IndexType
:param index_type: (Required) IndexType
:type store_raw_vector: bool
:param store_raw_vector: (Required) Is vector self stored in the table
`Column`:
:type name: str
:param name: (Required) Name of the column
:type type: ColumnType
:param type: (Required) Default type is ColumnType.VECTOR, can't change
"""
def __init__(self, name,
dimension=0,
index_type=None,
store_raw_vector=False,
type=None):
self.dimension = dimension
self.index_type = index_type
self.store_raw_vector = store_raw_vector
super(VectorColumn, self).__init__(name, type=type)
class TableSchema(object): class TableSchema(object):
...@@ -74,30 +14,26 @@ class TableSchema(object): ...@@ -74,30 +14,26 @@ class TableSchema(object):
:type table_name: str :type table_name: str
:param table_name: (Required) name of table :param table_name: (Required) name of table
:type vector_columns: list[VectorColumn] :type index_type: IndexType
:param vector_columns: (Required) a list of VectorColumns, :param index_type: (Optional) index type, default = 0
Stores different types of vectors
:type attribute_columns: list[Column]
:param attribute_columns: (Optional) Columns description
List of `Columns` whose type isn't VECTOR `IndexType`: 0-invalid, 1-idmap, 2-ivflat
:type partition_column_names: list[str] :type dimension: int64
:param partition_column_names: (Optional) Partition column name :param dimension: (Required) dimension of vector
`Partition columns` are `attribute columns`, the number of :type store_raw_vector: bool
partition columns may be less than or equal to attribute columns, :param store_raw_vector: (Optional) default = False
this param only stores `column name`
""" """
def __init__(self, table_name, vector_columns, def __init__(self, table_name,
attribute_columns, partition_column_names, **kwargs): dimension=0,
index_type=IndexType.INVALIDE,
store_raw_vector=False):
self.table_name = table_name self.table_name = table_name
self.vector_columns = vector_columns self.index_type = index_type
self.attribute_columns = attribute_columns self.dimension = dimension
self.partition_column_names = partition_column_names self.store_raw_vector = store_raw_vector
class Range(object): class Range(object):
...@@ -105,10 +41,10 @@ class Range(object): ...@@ -105,10 +41,10 @@ class Range(object):
Range information Range information
:type start: str :type start: str
:param start: (Required) Range start value :param start: Range start value
:type end: str :type end: str
:param end: (Required) Range end value :param end: Range end value
""" """
def __init__(self, start, end): def __init__(self, start, end):
...@@ -116,97 +52,37 @@ class Range(object): ...@@ -116,97 +52,37 @@ class Range(object):
self.end = end self.end = end
class CreateTablePartitionParam(object):
"""
Create table partition parameters
:type table_name: str
:param table_name: (Required) Table name,
VECTOR/FLOAT32/FLOAT64 ColumnType is not allowed for partition
:type partition_name: str
:param partition_name: (Required) partition name, created partition name
:type column_name_to_range: dict{str : Range}
:param column_name_to_range: (Required) Column name to PartitionRange dictionary
"""
# TODO Iterable
def __init__(self, table_name, partition_name, column_name_to_range):
self.table_name = table_name
self.partition_name = partition_name
self.column_name_to_range = column_name_to_range
class DeleteTablePartitionParam(object):
"""
Delete table partition parameters
:type table_name: str
:param table_name: (Required) Table name
:type partition_names: iterable, str
:param partition_names: (Required) Partition name array
"""
# TODO Iterable
def __init__(self, table_name, partition_names):
self.table_name = table_name
self.partition_names = partition_names
class RowRecord(object): class RowRecord(object):
""" """
Record inserted Record inserted
:type column_name_to_vector: dict{str : list[float]} :type vector_data: binary str
:param column_name_to_vector: (Required) Column name to vector map :param vector_data: (Required) a vector
:type column_name_to_attribute: dict{str: str}
:param column_name_to_attribute: (Optional) Other attribute columns
""" """
def __init__(self, column_name_to_vector, column_name_to_attribute): def __init__(self, vector_data):
self.column_name_to_vector = column_name_to_vector self.vector_data = vector_data
self.column_name_to_attribute = column_name_to_attribute
class QueryRecord(object):
"""
Query record
:type column_name_to_vector: (Required) dict{str : list[float]}
:param column_name_to_vector: Query vectors, column name to vector map
:type selected_columns: list[str]
:param selected_columns: (Optional) Output column array
:type name_to_partition_ranges: dict{str : list[Range]}
:param name_to_partition_ranges: (Optional) Range used to select partitions
"""
def __init__(self, column_name_to_vector, selected_columns, name_to_partition_ranges):
self.column_name_to_vector = column_name_to_vector
self.selected_columns = selected_columns
self.name_to_partition_ranges = name_to_partition_ranges
class QueryResult(object): class QueryResult(object):
""" """
Query result Query result
:type id: int :type id: int64
:param id: Output result :param id: id of the vector
:type score: float :type score: float
:param score: Vector similarity 0 <= score <= 100 :param score: Vector similarity 0 <= score <= 100
:type column_name_to_attribute: dict{str : str}
:param column_name_to_attribute: Other columns
""" """
def __init__(self, id, score, column_name_to_attribute): def __init__(self, id, score):
self.id = id self.id = id
self.score = score self.score = score
self.column_name_to_value = column_name_to_attribute
def __repr__(self):
L = ['%s=%r' % (key, value)
for key, value in self.__dict__.items()]
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
class TopKQueryResult(object): class TopKQueryResult(object):
...@@ -220,6 +96,12 @@ class TopKQueryResult(object): ...@@ -220,6 +96,12 @@ class TopKQueryResult(object):
def __init__(self, query_results): def __init__(self, query_results):
self.query_results = query_results self.query_results = query_results
def __repr__(self):
L = ['%s=%r' % (key, value)
for key, value in self.__dict__.items()]
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
def _abstract(): def _abstract():
raise NotImplementedError('You need to override this function') raise NotImplementedError('You need to override this function')
...@@ -232,114 +114,71 @@ class ConnectIntf(object): ...@@ -232,114 +114,71 @@ class ConnectIntf(object):
""" """
@staticmethod def connect(self, host=None, port=None, uri=None):
def create():
"""Create a connection instance and return it
should be implemented
:return connection: Connection
"""
_abstract()
@staticmethod
def destroy(connection):
"""Destroy the connection instance
should be implemented
:type connection: Connection
:param connection: The connection instance to be destroyed
:return bool, return True if destroy is successful
"""
_abstract()
def connect(self, param=None, uri=None):
""" """
Connect method should be called before any operations Connect method should be called before any operations
Server will be connected after connect return OK Server will be connected after connect return OK
should be implemented Should be implemented
:type host: str
:param host: host
:type param: ConnectParam :type port: str
:param param: ConnectParam :param port: port
:type uri: str :type uri: str
:param uri: uri param :param uri: (Optional) uri
:return: Status, indicate if connect is successful :return Status, indicate if connect is successful
""" """
if (not param and not uri) or (param and uri):
raise ConnectParamMissingError('You need to parse exact one param')
_abstract() _abstract()
def connected(self): def connected(self):
""" """
connected, connection status connected, connection status
should be implemented Should be implemented
:return: Status, indicate if connect is successful :return Status, indicate if connect is successful
""" """
_abstract() _abstract()
def disconnect(self): def disconnect(self):
""" """
Disconnect, server will be disconnected after disconnect return OK Disconnect, server will be disconnected after disconnect return SUCCESS
should be implemented Should be implemented
:return: Status, indicate if connect is successful :return Status, indicate if connect is successful
""" """
_abstract() _abstract()
def create_table(self, param): def create_table(self, param):
""" """
Create table Create table
should be implemented Should be implemented
:type param: TableSchema :type param: TableSchema
:param param: provide table information to be created :param param: provide table information to be created
:return: Status, indicate if connect is successful :return Status, indicate if connect is successful
""" """
_abstract() _abstract()
def delete_table(self, table_name): def delete_table(self, table_name):
""" """
Delete table Delete table
should be implemented Should be implemented
:type table_name: str :type table_name: str
:param table_name: table_name of the deleting table :param table_name: table_name of the deleting table
:return: Status, indicate if connect is successful :return Status, indicate if connect is successful
""" """
_abstract() _abstract()
def create_table_partition(self, param): def add_vectors(self, table_name, records):
"""
Create table partition
should be implemented
:type param: CreateTablePartitionParam
:param param: provide partition information
:return: Status, indicate if table partition is created successfully
"""
_abstract()
def delete_table_partition(self, param):
"""
Delete table partition
should be implemented
:type param: DeleteTablePartitionParam
:param param: provide partition information to be deleted
:return: Status, indicate if partition is deleted successfully
"""
_abstract()
def add_vector(self, table_name, records):
""" """
Add vectors to table Add vectors to table
should be implemented Should be implemented
:type table_name: str :type table_name: str
:param table_name: table name been inserted :param table_name: table name been inserted
...@@ -347,27 +186,31 @@ class ConnectIntf(object): ...@@ -347,27 +186,31 @@ class ConnectIntf(object):
:type records: list[RowRecord] :type records: list[RowRecord]
:param records: list of vectors been inserted :param records: list of vectors been inserted
:returns: :returns
Status : indicate if vectors inserted successfully Status : indicate if vectors inserted successfully
ids :list of id, after inserted every vector is given a id ids :list of id, after inserted every vector is given a id
""" """
_abstract() _abstract()
def search_vector(self, table_name, query_records, top_k): def search_vectors(self, table_name, query_records, query_ranges, top_k):
""" """
Query vectors in a table Query vectors in a table
should be implemented Should be implemented
:type table_name: str :type table_name: str
:param table_name: table name been queried :param table_name: table name been queried
:type query_records: list[QueryRecord] :type query_records: list[RowRecord]
:param query_records: all vectors going to be queried :param query_records: all vectors going to be queried
:type query_ranges: list[Range]
:param query_ranges: Optional ranges for conditional search.
If not specified, search whole table
:type top_k: int :type top_k: int
:param top_k: how many similar vectors will be searched :param top_k: how many similar vectors will be searched
:returns: :returns
Status: indicate if query is successful Status: indicate if query is successful
query_results: list[TopKQueryResult] query_results: list[TopKQueryResult]
""" """
...@@ -376,23 +219,37 @@ class ConnectIntf(object): ...@@ -376,23 +219,37 @@ class ConnectIntf(object):
def describe_table(self, table_name): def describe_table(self, table_name):
""" """
Show table information Show table information
should be implemented Should be implemented
:type table_name: str :type table_name: str
:param table_name: which table to be shown :param table_name: which table to be shown
:returns: :returns
Status: indicate if query is successful Status: indicate if query is successful
table_schema: TableSchema, given when operation is successful table_schema: TableSchema, given when operation is successful
""" """
_abstract() _abstract()
def get_table_row_count(self, table_name):
"""
Get table row count
Should be implemented
:type table_name, str
:param table_name, target table name.
:returns
Status: indicate if operation is successful
count: int, table row count
"""
_abstract()
def show_tables(self): def show_tables(self):
""" """
Show all tables in database Show all tables in database
should be implemented should be implemented
:return: :return
Status: indicate if this operation is successful Status: indicate if this operation is successful
tables: list[str], list of table names tables: list[str], list of table names
""" """
...@@ -403,31 +260,28 @@ class ConnectIntf(object): ...@@ -403,31 +260,28 @@ class ConnectIntf(object):
Provide client version Provide client version
should be implemented should be implemented
:return: Client version :return: str, client version
""" """
_abstract() _abstract()
pass
def server_version(self): def server_version(self):
""" """
Provide server version Provide server version
should be implemented should be implemented
:return: Server version :return: str, server version
""" """
_abstract()
def server_status(self, cmd): def server_status(self, cmd):
""" """
Provide server status Provide server status
should be implemented should be implemented
# TODO What is cmd :type cmd, str
:type cmd
:param cmd
:return: Server status :return: str, server status
""" """
_abstract() _abstract()
pass
......
此差异已折叠。
...@@ -2,10 +2,6 @@ class ParamError(ValueError): ...@@ -2,10 +2,6 @@ class ParamError(ValueError):
pass pass
class ConnectParamMissingError(ParamError):
pass
class ConnectError(ValueError): class ConnectError(ValueError):
pass pass
......
...@@ -3,13 +3,15 @@ class Status(object): ...@@ -3,13 +3,15 @@ class Status(object):
:attribute code : int (optional) default as ok :attribute code : int (optional) default as ok
:attribute message : str (optional) current status message :attribute message : str (optional) current status message
""" """
OK = 0 SUCCESS = 0
INVALID = 1 CONNECT_FAILED = 1
UNKNOWN_ERROR = 2 PERMISSION_DENIED = 2
NOT_SUPPORTED = 3 TABLE_NOT_EXISTS = 3
NOT_CONNECTED = 4 ILLEGAL_ARGUMENT = 4
ILLEGAL_RANGE = 5
ILLEGAL_DIMENSION = 6
def __init__(self, code=OK, message=None): def __init__(self, code=SUCCESS, message=None):
self.code = code self.code = code
self.message = message self.message = message
......
from client.Client import MegaSearch, Prepare, IndexType, ColumnType from client.Client import MegaSearch, Prepare, IndexType
from client.Status import Status from client.Status import Status
import time import time
import random
import struct
from pprint import pprint
from megasearch.thrift import MegasearchService, ttypes from megasearch.thrift import MegasearchService, ttypes
def main(): def main():
# Get client version
mega = MegaSearch() mega = MegaSearch()
print(mega.client_version()) print('# Client version: {}'.format(mega.client_version()))
# Connect # Connect
param = {'host': '192.168.1.129', 'port': '33001'} param = {'host': '192.168.1.129', 'port': '33001'}
cnn_status = mega.connect(**param) cnn_status = mega.connect(**param)
print('Connect Status: {}'.format(cnn_status)) print('# Connect Status: {}'.format(cnn_status))
# Check if connected
is_connected = mega.connected is_connected = mega.connected
print('Connect status: {}'.format(is_connected)) print('# Is connected: {}'.format(is_connected))
# Create table with 1 vector column, 1 attribute column and 1 partition column
# 1. prepare table_schema
# table_schema = Prepare.table_schema(
# table_name='fake_table_name' + time.strftime('%H%M%S'),
#
# vector_columns=[Prepare.vector_column(
# name='fake_vector_name' + time.strftime('%H%M%S'),
# store_raw_vector=False,
# dimension=256)],
#
# attribute_columns=[],
#
# partition_column_names=[]
# )
# get server version
print(mega.server_status('version'))
print(mega.client.Ping('version'))
# show tables and their description
statu, tables = mega.show_tables()
print(tables)
for table in tables:
s,t = mega.describe_table(table)
print('table: {}'.format(t))
# Create table # Get server version
# 1. create table schema print('# Server version: {}'.format(mega.server_version()))
table_schema_full = MegasearchService.TableSchema(
table_name='fake' + time.strftime('%H%M%S'),
vector_column_array=[MegasearchService.VectorColumn(
base=MegasearchService.Column(
name='111',
type=ttypes.TType.LIST
),
index_type="aaa",
dimension=256,
store_raw_vector=False,
)],
attribute_column_array=[],
partition_column_name_array=[] # Show tables and their description
) status, tables = mega.show_tables()
print('# Show tables: {}'.format(tables))
# 2. Create Table # Create table
create_status = mega.client.CreateTable(param=table_schema_full) # 01.Prepare data
print('Create table status: {}'.format(create_status)) param = {
'table_name': 'test'+ str(random.randint(0,999)),
# add_vector 'dimension': 256,
'index_type': IndexType.IDMAP,
'store_raw_vector': False
}
# 02.Create table
res_status = mega.create_table(Prepare.table_schema(**param))
print('# Create table status: {}'.format(res_status))
# Describe table
table_name = 'test01'
res_status, table = mega.describe_table(table_name)
print('# Describe table status: {}'.format(res_status))
print('# Describe table:{}'.format(table))
# Add vectors to table 'test01'
# 01. Prepare data
dim = 256
# list of binary vectors
vectors = [Prepare.row_record(struct.pack(str(dim)+'d',
*[random.random()for _ in range(dim)]))
for _ in range(20)]
# 02. Add vectors
status, ids = mega.add_vectors(table_name=table_name, records=vectors)
print('# Add vector status: {}'.format(status))
pprint(ids)
# Search vectors
q_records = [Prepare.row_record(struct.pack(str(dim) + 'd',
*[random.random() for _ in range(dim)]))
for _ in range(5)]
param = {
'table_name': 'test01',
'query_records': q_records,
'top_k': 10,
# 'query_ranges': None # Optional
}
sta, results = mega.search_vectors(**param)
print('# Search vectors status: {}'.format(sta))
pprint(results)
# Get table row count
sta, result = mega.get_table_row_count(table_name)
print('# Status: {}'.format(sta))
print('# Count: {}'.format(result))
# Delete table 'test01'
res_status = mega.delete_table(table_name)
print('# Delete table status: {}'.format(res_status))
# Disconnect # Disconnect
discnn_status = mega.disconnect() discnn_status = mega.disconnect()
print('Disconnect Status{}'.format(discnn_status)) print('# Disconnect Status: {}'.format(discnn_status))
if __name__ == '__main__': if __name__ == '__main__':
main() main()
\ No newline at end of file
...@@ -3,17 +3,20 @@ import pytest ...@@ -3,17 +3,20 @@ import pytest
import mock import mock
import faker import faker
import random import random
import struct
from faker.providers import BaseProvider from faker.providers import BaseProvider
from client.Client import MegaSearch, Prepare, IndexType, ColumnType from client.Client import MegaSearch, Prepare
from client.Abstract import IndexType, TableSchema
from client.Status import Status from client.Status import Status
from client.Exceptions import ( from client.Exceptions import (
RepeatingConnectError, RepeatingConnectError,
DisconnectNotConnectedClientError DisconnectNotConnectedClientError
) )
from megasearch.thrift import ttypes, MegasearchService
from thrift.transport.TSocket import TSocket from thrift.transport.TSocket import TSocket
from megasearch.thrift import ttypes, MegasearchService from thrift.transport import TTransport
from thrift.transport.TTransport import TTransportException from thrift.transport.TTransport import TTransportException
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
...@@ -35,63 +38,37 @@ fake = faker.Faker() ...@@ -35,63 +38,37 @@ fake = faker.Faker()
fake.add_provider(FakerProvider) fake.add_provider(FakerProvider)
def vector_column_factory():
return {
'name': fake.name(),
'dimension': fake.dim(),
'store_raw_vector': True
}
def column_factory():
return {
'name': fake.table_name(),
'type': ColumnType.INT32
}
def range_factory(): def range_factory():
return { param = {
'start': str(random.randint(1, 10)), 'start': str(random.randint(1, 10)),
'end': str(random.randint(11, 20)), 'end': str(random.randint(11, 20)),
} }
return Prepare.range(**param)
def ranges_factory():
return [range_factory() for _ in range(5)]
def table_schema_factory(): def table_schema_factory():
vec_params = [vector_column_factory() for i in range(10)]
column_params = [column_factory() for i in range(5)]
param = { param = {
'table_name': fake.table_name(), 'table_name': fake.table_name(),
'vector_columns': [Prepare.vector_column(**pa) for pa in vec_params], 'dimension': random.randint(0, 999),
'attribute_columns': [Prepare.column(**pa) for pa in column_params], 'index_type': IndexType.IDMAP,
'partition_column_names': [str(x) for x in range(2)] 'store_raw_vector': False
} }
return Prepare.table_schema(**param) return Prepare.table_schema(**param)
def create_table_partition_param_factory(): def row_record_factory(dimension):
param = { vec = [random.random() + random.randint(0,9) for _ in range(dimension)]
'table_name': fake.table_name(), bin_vec = struct.pack(str(dimension) + "d", *vec)
'partition_name': fake.table_name(),
'column_name_to_range': {fake.name(): range_factory() for _ in range(3)}
}
return Prepare.create_table_partition_param(**param)
return Prepare.row_record(vector_data=bin_vec)
def delete_table_partition_param_factory():
param = {
'table_name': fake.table_name(),
'partition_names': [fake.name() for i in range(5)]
}
return Prepare.delete_table_partition_param(**param)
def row_record_factory(): def row_records_factory(dimension):
param = { return [row_record_factory(dimension) for _ in range(20)]
'column_name_to_vector': {fake.name(): [random.random() for i in range(256)]},
'column_name_to_attribute': {fake.name(): fake.name()}
}
return Prepare.row_record(**param)
class TestConnection: class TestConnection:
...@@ -103,9 +80,8 @@ class TestConnection: ...@@ -103,9 +80,8 @@ class TestConnection:
cnn = MegaSearch() cnn = MegaSearch()
cnn.connect(**self.param) cnn.connect(**self.param)
assert cnn.status == Status.OK assert cnn.status == Status.SUCCESS
assert cnn.connected assert cnn.connected
assert isinstance(cnn.client, MegasearchService.Client)
with pytest.raises(RepeatingConnectError): with pytest.raises(RepeatingConnectError):
cnn.connect(**self.param) cnn.connect(**self.param)
...@@ -114,12 +90,23 @@ class TestConnection: ...@@ -114,12 +90,23 @@ class TestConnection:
def test_false_connect(self): def test_false_connect(self):
cnn = MegaSearch() cnn = MegaSearch()
cnn.connect(self.param) cnn.connect(**self.param)
assert cnn.status != Status.OK assert cnn.status != Status.SUCCESS
@mock.patch.object(TTransport.TBufferedTransport, 'close')
@mock.patch.object(TSocket, 'open')
def test_disconnected(self, close, open):
close.return_value = None
open.return_value = None
cnn = MegaSearch()
cnn.connect(**self.param)
assert cnn.disconnect() == Status.SUCCESS
def test_disconnected_error(self): def test_disconnected_error(self):
cnn = MegaSearch() cnn = MegaSearch()
cnn.connect_status = Status(Status.INVALID) cnn.connect_status = Status(Status.PERMISSION_DENIED)
with pytest.raises(DisconnectNotConnectedClientError): with pytest.raises(DisconnectNotConnectedClientError):
cnn.disconnect() cnn.disconnect()
...@@ -142,26 +129,26 @@ class TestTable: ...@@ -142,26 +129,26 @@ class TestTable:
param = table_schema_factory() param = table_schema_factory()
res = client.create_table(param) res = client.create_table(param)
assert res == Status.OK assert res == Status.SUCCESS
def test_false_create_table(self, client): def test_false_create_table(self, client):
param = table_schema_factory() param = table_schema_factory()
with pytest.raises(TTransportException): with pytest.raises(TTransportException):
res = client.create_table(param) res = client.create_table(param)
LOGGER.error('{}'.format(res)) LOGGER.error('{}'.format(res))
assert res != Status.OK assert res != Status.SUCCESS
@mock.patch.object(MegasearchService.Client, 'DeleteTable') @mock.patch.object(MegasearchService.Client, 'DeleteTable')
def test_delete_table(self, DeleteTable, client): def test_delete_table(self, DeleteTable, client):
DeleteTable.return_value = None DeleteTable.return_value = None
table_name = 'fake_table_name' table_name = 'fake_table_name'
res = client.delete_table(table_name) res = client.delete_table(table_name)
assert res == Status.OK assert res == Status.SUCCESS
def test_false_delete_table(self, client): def test_false_delete_table(self, client):
table_name = 'fake_table_name' table_name = 'fake_table_name'
res = client.delete_table(table_name) res = client.delete_table(table_name)
assert res != Status.OK assert res != Status.SUCCESS
class TestVector: class TestVector:
...@@ -176,70 +163,46 @@ class TestVector: ...@@ -176,70 +163,46 @@ class TestVector:
cnn.connect(**param) cnn.connect(**param)
return cnn return cnn
@mock.patch.object(MegasearchService.Client, 'CreateTablePartition')
def test_create_table_partition(self, CreateTablePartition, client):
CreateTablePartition.return_value = None
param = create_table_partition_param_factory()
res = client.create_table_partition(param)
assert res == Status.OK
def test_false_table_partition(self, client):
param = create_table_partition_param_factory()
res = client.create_table_partition(param)
assert res != Status.OK
@mock.patch.object(MegasearchService.Client, 'DeleteTablePartition')
def test_delete_table_partition(self, DeleteTablePartition, client):
DeleteTablePartition.return_value = None
param = delete_table_partition_param_factory()
res = client.delete_table_partition(param)
assert res == Status.OK
def test_false_delete_table_partition(self, client):
param = delete_table_partition_param_factory()
res = client.delete_table_partition(param)
assert res != Status.OK
@mock.patch.object(MegasearchService.Client, 'AddVector') @mock.patch.object(MegasearchService.Client, 'AddVector')
def test_add_vector(self, AddVector, client): def test_add_vector(self, AddVector, client):
AddVector.return_value = None AddVector.return_value = None
param ={ param ={
'table_name': fake.table_name(), 'table_name': fake.table_name(),
'records': [row_record_factory() for _ in range(1000)] 'records': row_records_factory(256)
} }
res, ids = client.add_vector(**param) res, ids = client.add_vectors(**param)
assert res == Status.OK assert res == Status.SUCCESS
def test_false_add_vector(self, client): def test_false_add_vector(self, client):
param ={ param ={
'table_name': fake.table_name(), 'table_name': fake.table_name(),
'records': [row_record_factory() for _ in range(1000)] 'records': row_records_factory(256)
} }
res, ids = client.add_vector(**param) res, ids = client.add_vectors(**param)
assert res != Status.OK assert res != Status.SUCCESS
@mock.patch.object(MegasearchService.Client, 'SearchVector') @mock.patch.object(MegasearchService.Client, 'SearchVector')
def test_search_vector(self, SearchVector, client): def test_search_vector(self, SearchVector, client):
SearchVector.return_value = None SearchVector.return_value = None, None
param = { param = {
'table_name': fake.table_name(), 'table_name': fake.table_name(),
'query_records': [row_record_factory() for _ in range(1000)], 'query_records': row_records_factory(256),
'top_k': random.randint(0,10) 'query_ranges': ranges_factory(),
'top_k': random.randint(0, 10)
} }
res, results = client.search_vector(**param) res, results = client.search_vectors(**param)
assert res == Status.OK assert res == Status.SUCCESS
def test_false_vector(self, client): def test_false_vector(self, client):
param = { param = {
'table_name': fake.table_name(), 'table_name': fake.table_name(),
'query_records': [row_record_factory() for _ in range(1000)], 'query_records': row_records_factory(256),
'top_k': random.randint(0,10) 'query_ranges': ranges_factory(),
'top_k': random.randint(0, 10)
} }
res, results = client.search_vector(**param) res, results = client.search_vectors(**param)
assert res != Status.OK assert res != Status.SUCCESS
@mock.patch.object(MegasearchService.Client, 'DescribeTable') @mock.patch.object(MegasearchService.Client, 'DescribeTable')
def test_describe_table(self, DescribeTable, client): def test_describe_table(self, DescribeTable, client):
...@@ -247,27 +210,38 @@ class TestVector: ...@@ -247,27 +210,38 @@ class TestVector:
table_name = fake.table_name() table_name = fake.table_name()
res, table_schema = client.describe_table(table_name) res, table_schema = client.describe_table(table_name)
assert res == Status.OK assert res == Status.SUCCESS
assert isinstance(table_schema, ttypes.TableSchema) assert isinstance(table_schema, TableSchema)
def test_false_decribe_table(self, client): def test_false_decribe_table(self, client):
table_name = fake.table_name() table_name = fake.table_name()
res, table_schema = client.describe_table(table_name) res, table_schema = client.describe_table(table_name)
assert res != Status.OK assert res != Status.SUCCESS
assert not table_schema assert not table_schema
@mock.patch.object(MegasearchService.Client, 'ShowTables') @mock.patch.object(MegasearchService.Client, 'ShowTables')
def test_show_tables(self, ShowTables, client): def test_show_tables(self, ShowTables, client):
ShowTables.return_value = [fake.table_name() for _ in range(10)] ShowTables.return_value = [fake.table_name() for _ in range(10)], None
res, tables = client.show_tables() res, tables = client.show_tables()
assert res == Status.OK assert res == Status.SUCCESS
assert isinstance(tables, list) assert isinstance(tables, list)
def test_false_show_tables(self, client): def test_false_show_tables(self, client):
res, tables = client.show_tables() res, tables = client.show_tables()
assert res != Status.OK assert res != Status.SUCCESS
assert not tables assert not tables
@mock.patch.object(MegasearchService.Client, 'GetTableRowCount')
def test_get_table_row_count(self, GetTableRowCount, client):
GetTableRowCount.return_value = 22, None
res, count = client.get_table_row_count('fake_table')
assert res == Status.SUCCESS
def test_false_get_table_row_count(self, client):
res,count = client.get_table_row_count('fake_table')
assert res != Status.SUCCESS
assert not count
def test_client_version(self, client): def test_client_version(self, client):
res = client.client_version() res = client.client_version()
assert res == '0.0.1' assert res == '0.0.1'
...@@ -275,34 +249,13 @@ class TestVector: ...@@ -275,34 +249,13 @@ class TestVector:
class TestPrepare: class TestPrepare:
def test_column(self):
param = {
'name': 'test01',
'type': ColumnType.DATE
}
res = Prepare.column(**param)
LOGGER.error('{}'.format(res))
assert res.name == 'test01'
assert res.type == ColumnType.DATE
assert isinstance(res, ttypes.Column)
def test_vector_column(self):
param = vector_column_factory()
res = Prepare.vector_column(**param)
LOGGER.error('{}'.format(res))
assert isinstance(res, ttypes.VectorColumn)
def test_table_schema(self): def test_table_schema(self):
vec_params = [vector_column_factory() for i in range(10)]
column_params = [column_factory() for i in range(5)]
param = { param = {
'table_name': 'test03', 'table_name': fake.table_name(),
'vector_columns': [Prepare.vector_column(**pa) for pa in vec_params], 'dimension': random.randint(0, 999),
'attribute_columns': [Prepare.column(**pa) for pa in column_params], 'index_type': IndexType.IDMAP,
'partition_column_names': [str(x) for x in range(2)] 'store_raw_vector': False
} }
res = Prepare.table_schema(**param) res = Prepare.table_schema(**param)
assert isinstance(res, ttypes.TableSchema) assert isinstance(res, ttypes.TableSchema)
...@@ -319,39 +272,10 @@ class TestPrepare: ...@@ -319,39 +272,10 @@ class TestPrepare:
assert res.start_value == '200' assert res.start_value == '200'
assert res.end_value == '1000' assert res.end_value == '1000'
def test_create_table_partition_param(self):
param = {
'table_name': fake.table_name(),
'partition_name': fake.table_name(),
'column_name_to_range': {fake.name(): range_factory() for _ in range(3)}
}
res = Prepare.create_table_partition_param(**param)
LOGGER.error('{}'.format(res))
assert isinstance(res, ttypes.CreateTablePartitionParam)
def test_delete_table_partition_param(self):
param = {
'table_name': fake.table_name(),
'partition_names': [fake.name() for i in range(5)]
}
res = Prepare.delete_table_partition_param(**param)
assert isinstance(res, ttypes.DeleteTablePartitionParam)
def test_row_record(self): def test_row_record(self):
param={ vec = [random.random() + random.randint(0, 9) for _ in range(256)]
'column_name_to_vector': {fake.name(): [random.random() for i in range(256)]}, bin_vec = struct.pack(str(256) + "d", *vec)
'column_name_to_attribute': {fake.name(): fake.name()} res = Prepare.row_record(bin_vec)
}
res = Prepare.row_record(**param)
assert isinstance(res, ttypes.RowRecord) assert isinstance(res, ttypes.RowRecord)
assert isinstance(bin_vec, bytes)
def test_query_record(self):
param = {
'column_name_to_vector': {fake.name(): [random.random() for i in range(256)]},
'selected_columns': [fake.name() for _ in range(10)],
'name_to_partition_ranges': {fake.name(): [range_factory() for _ in range(5)]}
}
res = Prepare.query_record(**param)
assert isinstance(res, ttypes.QueryRecord)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册