提交 ff662a9b 编写于 作者: Y Yang Xuan

feat(python): impl sdk, add unittests


Former-commit-id: b617e88e024bc925d45f17e3d3b0a77c6c08a12c
上级 0d7b464c
from enum import IntEnum
from .Exceptions import ConnectParamMissingError
class AbstactIndexType(object):
RAW = 1
IVFFLAT = 2
class AbstractColumnType(object):
INVALID = 1
INT8 = 2
INT16 = 3
INT32 = 4
INT64 = 5
FLOAT32 = 6
FLOAT64 = 7
DATE = 8
VECTOR = 9
class Column(object):
"""
Table column description
:type type: ColumnType
:param type: type of the column
:type name: str
:param name: name of the column
"""
def __init__(self, name=None, type=AbstractColumnType.INVALID):
self.type = type
self.name = name
class VectorColumn(Column):
"""
Table vector column description
:type dimension: int, int64
:param dimension: vector dimension
:type index_type: IndexType
:param index_type: IndexType
:type store_raw_vector: bool
:param store_raw_vector: Is vector self stored in the table
`Column`:
:type name: str
:param name: Name of the column
:type type: ColumnType
:param type: Default type is ColumnType.VECTOR, can't change
"""
def __init__(self, name,
dimension=0,
index_type=AbstactIndexType.RAW,
store_raw_vector=False):
self.dimension = dimension
self.index_type = index_type
self.store_raw_vector = store_raw_vector
super(VectorColumn, self).__init__(name, type=AbstractColumnType.VECTOR)
class TableSchema(object):
"""
Table Schema
:type table_name: str
:param table_name: name of table
:type vector_columns: list[VectorColumn]
:param vector_columns: a list of VectorColumns,
Stores different types of vectors
:type attribute_columns: list[Column]
:param attribute_columns: Columns description
List of `Columns` whose type isn't VECTOR
:type partition_column_names: list[str]
:param partition_column_names: Partition column name
`Partition columns` are `attribute columns`, the number of
partition columns may be less than or equal to attribute columns,
this param only stores `column name`
"""
def __init__(self, table_name, vector_columns,
attribute_columns, partition_column_names, **kwargs):
self.table_name = table_name
self.vector_columns = vector_columns
self.attribute_columns = attribute_columns
self.partition_column_names = partition_column_names
class Range(object):
"""
Range information
:type start: str
:param start: Range start value
:type end: str
:param end: Range end value
"""
def __init__(self, start, end):
self.start = start
self.end = end
class CreateTablePartitionParam(object):
"""
Create table partition parameters
:type table_name: str
:param table_name: Table name,
VECTOR/FLOAT32/FLOAT64 ColumnType is not allowed for partition
:type partition_name: str
:param partition_name: partition name, created partition name
:type column_name_to_range: dict{str : Range}
:param column_name_to_range: Column name to PartitionRange dictionary
"""
# TODO Iterable
def __init__(self, table_name, partition_name, column_name_to_range):
self.table_name = table_name
self.partition_name = partition_name
self.column_name_to_range = column_name_to_range
class DeleteTablePartitionParam(object):
"""
Delete table partition parameters
:type table_name: str
:param table_name: Table name
:type partition_names: iterable, str
:param partition_names: Partition name array
"""
# TODO Iterable
def __init__(self, table_name, partition_names):
self.table_name = table_name
self.partition_names = partition_names
class RowRecord(object):
"""
Record inserted
:type column_name_to_vector: dict{str : list[float]}
:param column_name_to_vector: Column name to vector map
:type column_name_to_attribute: dict{str: str}
:param column_name_to_attribute: Other attribute columns
"""
def __init__(self, column_name_to_vector, column_name_to_attribute):
self.column_name_to_vector = column_name_to_vector
self.column_name_to_attribute = column_name_to_attribute
class QueryRecord(object):
"""
Query record
:type column_name_to_vector: dict{str : list[float]}
:param column_name_to_vector: Query vectors, column name to vector map
:type selected_columns: list[str]
:param selected_columns: Output column array
:type name_to_partition_ranges: dict{str : list[Range]}
:param name_to_partition_ranges: Range used to select partitions
"""
def __init__(self, column_name_to_vector, selected_columns, name_to_partition_ranges):
self.column_name_to_vector = column_name_to_vector
self.selected_columns = selected_columns
self.name_to_partition_ranges = name_to_partition_ranges
class QueryResult(object):
"""
Query result
:type id: int
:param id: Output result
:type score: float
:param score: Vector similarity 0 <= score <= 100
:type column_name_to_attribute: dict{str : str}
:param column_name_to_attribute: Other columns
"""
def __init__(self, id, score, column_name_to_attribute):
self.id = id
self.score = score
self.column_name_to_value = column_name_to_attribute
class TopKQueryResult(object):
"""
TopK query results
:type query_results: list[QueryResult]
:param query_results: TopK query results
"""
def __init__(self, query_results):
self.query_results = query_results
def _abstract():
raise NotImplementedError('You need to override this function')
class ConnectIntf(object):
"""SDK client abstract class
Connection is a abstract class
"""
@staticmethod
def create():
"""Create a connection instance and return it
should be implemented
:return connection: Connection
"""
_abstract()
@staticmethod
def destroy(connection):
"""Destroy the connection instance
should be implemented
:type connection: Connection
:param connection: The connection instance to be destroyed
:return bool, return True if destroy is successful
"""
_abstract()
def connect(self, param=None, uri=None):
"""
Connect method should be called before any operations
Server will be connected after connect return OK
should be implemented
:type param: ConnectParam
:param param: ConnectParam
:type uri: str
:param uri: uri param
:return: Status, indicate if connect is successful
"""
if (not param and not uri) or (param and uri):
raise ConnectParamMissingError('You need to parse exact one param')
_abstract()
def connected(self):
"""
connected, connection status
should be implemented
:return: Status, indicate if connect is successful
"""
_abstract()
def disconnect(self):
"""
Disconnect, server will be disconnected after disconnect return OK
should be implemented
:return: Status, indicate if connect is successful
"""
_abstract()
def create_table(self, param):
"""
Create table
should be implemented
:type param: TableSchema
:param param: provide table information to be created
:return: Status, indicate if connect is successful
"""
_abstract()
def delete_table(self, table_name):
"""
Delete table
should be implemented
:type table_name: str
:param table_name: table_name of the deleting table
:return: Status, indicate if connect is successful
"""
_abstract()
def create_table_partition(self, param):
"""
Create table partition
should be implemented
:type param: CreateTablePartitionParam
:param param: provide partition information
:return: Status, indicate if table partition is created successfully
"""
_abstract()
def delete_table_partition(self, param):
"""
Delete table partition
should be implemented
:type param: DeleteTablePartitionParam
:param param: provide partition information to be deleted
:return: Status, indicate if partition is deleted successfully
"""
_abstract()
def add_vector(self, table_name, records):
"""
Add vectors to table
should be implemented
:type table_name: str
:param table_name: table name been inserted
:type records: list[RowRecord]
:param records: list of vectors been inserted
:returns:
Status : indicate if vectors inserted successfully
ids :list of id, after inserted every vector is given a id
"""
_abstract()
def search_vector(self, table_name, query_records, top_k):
"""
Query vectors in a table
should be implemented
:type table_name: str
:param table_name: table name been queried
:type query_records: list[QueryRecord]
:param query_records: all vectors going to be queried
:type top_k: int
:param top_k: how many similar vectors will be searched
:returns:
Status: indicate if query is successful
query_results: list[TopKQueryResult]
"""
_abstract()
def describe_table(self, table_name):
"""
Show table information
should be implemented
:type table_name: str
:param table_name: which table to be shown
:returns:
Status: indicate if query is successful
table_schema: TableSchema, given when operation is successful
"""
_abstract()
def show_tables(self):
"""
Show all tables in database
should be implemented
:return:
Status: indicate if this operation is successful
tables: list[str], list of table names
"""
_abstract()
def client_version(self):
"""
Provide client version
should be implemented
:return: Client version
"""
_abstract()
pass
def server_version(self):
"""
Provide server version
should be implemented
:return: Server version
"""
def server_status(self, cmd):
"""
Provide server status
should be implemented
# TODO What is cmd
:type cmd
:param cmd
:return: Server status
"""
_abstract()
pass
import logging, logging.config
from thrift.transport import TSocket
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol, TCompactProtocol, TJSONProtocol
from thrift.Thrift import TException, TApplicationException, TType
from megasearch.thrift import MegasearchService
from megasearch.thrift import ttypes
from client.Abstract import (
ConnectIntf, TableSchema,
AbstactIndexType, AbstractColumnType,
Column,
VectorColumn, Range,
CreateTablePartitionParam,
DeleteTablePartitionParam,
RowRecord, QueryRecord,
QueryResult, TopKQueryResult
)
from client.Status import Status
from client.Exceptions import (
RepeatingConnectError, ConnectParamMissingError,
DisconnectNotConnectedClientError,
ParamError, NotConnectError
)
LOGGER = logging.getLogger(__name__)
__VERSION__ = '0.0.1'
__NAME__ = 'Thrift_Client'
class IndexType(AbstactIndexType):
# TODO thrift in IndexType
RAW = 1
IVFFLAT = 2
class ColumnType(AbstractColumnType):
# INVALID = 1
# INT8 = 2
# INT16 = 3
# INT32 = 4
# INT64 = 5
FLOAT32 = 6
FLOAT64 = 7
DATE = 8
# VECTOR = 9
INVALID = TType.STOP
INT8 = TType.I08
INT16 = TType.I16
INT32 = TType.I32
INT64 = TType.I64
VECTOR = TType.LIST
class Prepare(object):
@classmethod
def column(cls, name, type):
"""
Table column param
:param type: ColumnType, type of the column
:param name: str, name of the column
:return Column
"""
# TODO type in Thrift, may have error
temp_column = Column(name=name, type=type)
return ttypes.Column(name=temp_column.name, type=temp_column.type)
@classmethod
def vector_column(cls, name, dimension,
# index_type=IndexType.RAW,
store_raw_vector=False):
"""
Table vector column description
:param dimension: int64, vector dimension
:param index_type: IndexType
:param store_raw_vector: Bool, Is vector self stored in the table
`Column`:
:param name: Name of the column
:param type: Default type is ColumnType.VECTOR, can't change
:return VectorColumn
"""
# temp = VectorColumn(name=name, dimension=dimension,
# index_type=index_type, store_raw_vector=store_raw_vector)
# return ttypes.VectorColumn(base=base, dimension=temp.dimension,
# store_raw_vector=temp.store_raw_vector,
# index_type=temp.index_type)
# Without IndexType
temp = VectorColumn(name=name, dimension=dimension,
store_raw_vector=store_raw_vector)
base = ttypes.Column(name=temp.name, type=ColumnType.VECTOR)
return ttypes.VectorColumn(base=base, dimension=temp.dimension,
store_raw_vector=temp.store_raw_vector)
@classmethod
def table_schema(cls, table_name,
vector_columns,
attribute_columns,
partition_column_names):
"""
:param table_name: Name of the table
:param vector_columns: List of VectorColumns
`VectorColumn`:
- dimension: int, default = 0
Dimension of the vector, different vector_columns'
dimension may vary
- index_type: (optional) IndexType, default=IndexType.RAW
Vector's index type
- store_raw_vector : (optional) bool, default=False
- name: str
Name of the column
- type: ColumnType, default=ColumnType.VECTOR, can't change
:param attribute_columns: List of Columns. Attribute
columns are Columns whose type aren't ColumnType.VECTOR
`Column`:
- name: str
- type: ColumnType, default=ColumnType.INVALID
:param partition_column_names: List of str.
Partition columns name
indicates which attribute columns is used for partition, can
have lots of partition columns as long as:
-> No. partition_column_names <= No. attribute_columns
-> partition_column_names IN attribute_column_names
:return: TableSchema
"""
temp = TableSchema(table_name,vector_columns,
attribute_columns,
partition_column_names)
return ttypes.TableSchema(table_name=temp.table_name,
vector_column_array=temp.vector_columns,
attribute_column_array=temp.attribute_columns,
partition_column_name_array=temp.partition_column_names)
@classmethod
def range(cls, start, end):
"""
:param start: Partition range start value
:param end: Partition range end value
:return Range
"""
temp = Range(start=start, end=end)
return ttypes.Range(start_value=temp.start, end_value=temp.end)
@classmethod
def create_table_partition_param(cls,
table_name,
partition_name,
column_name_to_range):
"""
Create table partition parameters
:param table_name: str, Table name,
VECTOR/FLOAT32/FLOAT64 ColumnType is not allowed for partition
:param partition_name: str partition name, created partition name
:param column_name_to_range: dict, column name to partition range dictionary
:return CreateTablePartitionParam
"""
temp = CreateTablePartitionParam(table_name=table_name,
partition_name=partition_name,
column_name_to_range=column_name_to_range)
return ttypes.CreateTablePartitionParam(table_name=temp.table_name,
partition_name=temp.partition_name,
range_map=temp.column_name_to_range)
@classmethod
def delete_table_partition_param(cls, table_name, partition_names):
"""
Delete table partition parameters
:param table_name: Table name
:param partition_names: List of partition names
:return DeleteTablePartitionParam
"""
temp = DeleteTablePartitionParam(table_name=table_name,
partition_names=partition_names)
return ttypes.DeleteTablePartitionParam(table_name=table_name,
partition_name_array=partition_names)
@classmethod
def row_record(cls, column_name_to_vector, column_name_to_attribute):
"""
:param column_name_to_vector: dict{str : list[float]}
Column name to vector map
:param column_name_to_attribute: dict{str: str}
Other attribute columns
"""
temp = RowRecord(column_name_to_vector=column_name_to_vector,
column_name_to_attribute=column_name_to_attribute)
return ttypes.RowRecord(vector_map=temp.column_name_to_vector,
attribute_map=temp.column_name_to_attribute)
@classmethod
def query_record(cls, column_name_to_vector,
selected_columns, name_to_partition_ranges):
"""
:param column_name_to_vector: dict{str : list[float]}
Query vectors, column name to vector map
:param selected_columns: list[str_column_name]
List of Output columns
:param name_to_partition_ranges: dict{str : list[Range]}
Partition Range used to search
`Range`:
:param start: Partition range start value
:param end: Partition range end value
:return QueryRecord
"""
temp = QueryRecord(column_name_to_vector=column_name_to_vector,
selected_columns=selected_columns,
name_to_partition_ranges=name_to_partition_ranges)
return ttypes.QueryRecord(vector_map=temp.column_name_to_vector,
selected_column_array=temp.selected_columns,
partition_filter_column_map=name_to_partition_ranges)
class MegaSearch(ConnectIntf):
def __init__(self):
self.transport = None
self.client = None
self.status = None
def __repr__(self):
return '{}'.format(self.status)
@staticmethod
def create():
# TODO in python, maybe this method is useless
return MegaSearch()
@staticmethod
def destroy(connection):
"""Destroy the connection instance"""
# TODO in python, maybe this method is useless
pass
def connect(self, host='localhost', port='9090', uri=None):
# TODO URI
if self.status and self.status == Status(message='Connected'):
raise RepeatingConnectError("You have already connected!")
transport = TSocket.TSocket(host=host, port=port)
self.transport = TTransport.TBufferedTransport(transport)
protocol = TJSONProtocol.TJSONProtocol(transport)
self.client = MegasearchService.Client(protocol)
try:
transport.open()
self.status = Status(Status.OK, 'Connected')
LOGGER.info('Connected!')
except (TTransport.TTransportException, TException) as e:
self.status = Status(Status.INVALID, message=str(e))
LOGGER.error('logger.error: {}'.format(self.status))
finally:
return self.status
@property
def connected(self):
return self.status == Status()
def disconnect(self):
if not self.transport:
raise DisconnectNotConnectedClientError('Error')
try:
self.transport.close()
LOGGER.info('Client Disconnected!')
self.status = None
except TException as e:
return Status(Status.INVALID, str(e))
return Status(Status.OK, 'Disconnected')
def create_table(self, param):
"""Create table
:param param: Provide table information to be created,
`Please use Prepare.table_schema generate param`
:return: Status, indicate if operation is successful
"""
if not self.client:
raise NotConnectError('Please Connect to the server first!')
try:
self.client.CreateTable(param)
except (TApplicationException, TException) as e:
LOGGER.error('Unable to create table')
return Status(Status.INVALID, str(e))
return Status(message='Table {} created!'.format(param.table_name))
def delete_table(self, table_name):
"""Delete table
:param table_name: Name of the table being deleted
:return: Status, indicate if operation is successful
"""
try:
self.client.DeleteTable(table_name)
except (TApplicationException, TException) as e:
LOGGER.error('Unable to delete table {}'.format(table_name))
return Status(Status.INVALID, str(e))
return Status(message='Table {} deleted!'.format(table_name))
def create_table_partition(self, param):
"""
Create table partition
:type param: CreateTablePartitionParam, provide partition information
`Please use Prepare.create_table_partition_param generate param`
:return: Status, indicate if table partition is created successfully
"""
try:
self.client.CreateTablePartition(param)
except (TApplicationException, TException) as e:
LOGGER.error('{}'.format(e))
return Status(Status.INVALID, str(e))
return Status(message='Table partition created successfully!')
def delete_table_partition(self, param):
"""
Delete table partition
:type param: DeleteTablePartitionParam
:param param: provide partition information to be deleted
`Please use Prepare.delete_table_partition_param generate param`
:return: Status, indicate if partition is deleted successfully
"""
try:
self.client.DeleteTablePartition(param)
except (TApplicationException, TException) as e:
LOGGER.error('{}'.format(e))
return Status(Status.INVALID, str(e))
return Status(message='Table partition deleted successfully!')
def add_vector(self, table_name, records):
"""
Add vectors to table
:param table_name: table name been inserted
:param records: List[RowRecord], list of vectors been inserted
`Please use Prepare.row_record generate records`
:returns:
Status : indicate if vectors inserted successfully
ids :list of id, after inserted every vector is given a id
"""
try:
ids = self.client.AddVector(table_name=table_name, record_array=records)
except (TApplicationException, TException) as e:
LOGGER.error('{}'.format(e))
return Status(Status.INVALID, str(e)), None
return Status(message='Vector added successfully!'), ids
def search_vector(self, table_name, query_records, top_k):
"""
Query vectors in a table
:param table_name: str, table name been queried
:param query_records: list[QueryRecord], all vectors going to be queried
`Please use Prepare.query_record generate QueryRecord`
:param top_k: int, how many similar vectors will be searched
:returns:
Status: indicate if query is successful
query_results: list[TopKQueryResult], return when operation is successful
"""
# TODO topk_query_results
try:
topk_query_results = self.client.SearchVector(
table_name=table_name, query_record_array=query_records, topk=top_k)
except (TApplicationException, TException) as e:
LOGGER.error('{}'.format(e))
return Status(Status.INVALID, str(e)), None
return Status(message='Success!'), topk_query_results
def describe_table(self, table_name):
"""
Show table information
:param table_name: str, which table to be shown
:returns:
Status: indicate if query is successful
table_schema: TableSchema, return when operation is successful
"""
try:
thrift_table_schema = self.client.DescribeTable(table_name)
except (TApplicationException, TException) as e:
LOGGER.error('{}'.format(e))
return Status(Status.INVALID, str(e)), None
# TODO Table Schema
return Status(message='Success!'), thrift_table_schema
def show_tables(self):
"""
Show all tables in database
:return:
Status: indicate if this operation is successful
tables: list[str], list of table names, return when operation
is successful
"""
try:
tables = self.client.ShowTables()
except (TApplicationException, TException) as e:
LOGGER.error('{}'.format(e))
return Status(Status.INVALID, str(e)), None
return Status(message='Success!'), tables
def client_version(self):
"""
Provide client version
:return: Client version
"""
return __VERSION__
def server_version(self):
"""
Provide server version
:return: Server version
"""
# TODO How to get server version
pass
def server_status(self, cmd):
"""
Provide server status
:return: Server status
"""
self.client.Ping(cmd)
pass
class ParamError(ValueError):
pass
class ConnectParamMissingError(ParamError):
pass
class ConnectError(ValueError):
pass
class NotConnectError(ConnectError):
pass
class RepeatingConnectError(ConnectError):
pass
class DisconnectNotConnectedClientError(ValueError):
pass
class Status(object):
"""
:attribute code : int (optional) default as ok
:attribute message : str (optional) current status message
"""
OK = 0
INVALID = 1
UNKNOWN_ERROR = 2
NOT_SUPPORTED = 3
NOT_CONNECTED = 4
def __init__(self, code=OK, message=None):
self.code = code
self.message = message
def __repr__(self):
L = ['%s=%r' % (key, value)
for key, value in self.__dict__.items()]
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
def __eq__(self, other):
"""Make Status comparable with self by code"""
if isinstance(other, int):
return self.code == other
else:
return isinstance(other, self.__class__) and self.code == other.code
def __ne__(self, other):
return not (self == other)
from megasearch.client.Client import Connection
from client.Client import MegaSearch, Prepare, IndexType, ColumnType
from client.Status import Status
client = Connection()
# param =
# client.connect(param)
\ No newline at end of file
def main():
mega = MegaSearch()
# Connect
param = {'host': '192.168.1.129', 'port': '33001'}
cnn_status = mega.connect(**param)
print('Connect Status: {}'.format(cnn_status))
is_connected = mega.connected
print('Connect status: {}'.format(is_connected))
# # Create table with 1 vector column, 1 attribute column and 1 partition column
# # 1. prepare table_schema
# vector_column = {
# 'name': 'fake_vec_name01',
# 'store_raw_vector': True,
# 'dimension': 10
# }
# attribute_column = {
# 'name': 'fake_attri_name01',
# 'type': ColumnType.DATE,
# }
#
# table = {
# 'table_name': 'fake_table_name01',
# 'vector_columns': [Prepare.vector_column(**vector_column)],
# 'attribute_columns': [Prepare.column(**attribute_column)],
# 'partition_column_names': ['fake_attri_name01']
# }
# table_schema = Prepare.table_schema(**table)
#
# # 2. Create Table
# create_status = mega.create_table(table_schema)
# print('Create table status: {}'.format(create_status))
mega.server_status('ok!')
# Disconnect
discnn_status = mega.disconnect()
print('Disconnect Status{}'.format(discnn_status))
if __name__ == '__main__':
main()
\ No newline at end of file
[pytest]
log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)
log_cli = true
log_level = 30
\ No newline at end of file
import logging
import pytest
import mock
import faker
import random
from faker.providers import BaseProvider
from client.Client import MegaSearch, Prepare, IndexType, ColumnType
from client.Status import Status
from client.Exceptions import (
RepeatingConnectError,
DisconnectNotConnectedClientError
)
from thrift.transport.TSocket import TSocket
from megasearch.thrift import ttypes, MegasearchService
LOGGER = logging.getLogger(__name__)
class FakerProvider(BaseProvider):
def table_name(self):
return 'table_name' + str(random.randint(1000, 9999))
def name(self):
return 'name' + str(random.randint(1000, 9999))
def dim(self):
return random.randint(0, 999)
fake = faker.Faker()
fake.add_provider(FakerProvider)
def vector_column_factory():
return {
'name': fake.name(),
'dimension': fake.dim(),
'index_type': IndexType.IVFFLAT,
'store_raw_vector': True
}
def column_factory():
return {
'name': fake.table_name(),
'type': IndexType.RAW
}
def range_factory():
return {
'start': str(random.randint(1, 10)),
'end': str(random.randint(11, 20)),
}
def table_schema_factory():
vec_params = [vector_column_factory() for i in range(10)]
column_params = [column_factory() for i in range(5)]
param = {
'table_name': fake.table_name(),
'vector_columns': [Prepare.vector_column(**pa) for pa in vec_params],
'attribute_columns': [Prepare.column(**pa) for pa in column_params],
'partition_column_names': [str(x) for x in range(2)]
}
return Prepare.table_schema(**param)
def create_table_partition_param_factory():
param = {
'table_name': fake.table_name(),
'partition_name': fake.table_name(),
'column_name_to_range': {fake.name(): range_factory() for _ in range(3)}
}
return Prepare.create_table_partition_param(**param)
def delete_table_partition_param_factory():
param = {
'table_name': fake.table_name(),
'partition_names': [fake.name() for i in range(5)]
}
return Prepare.delete_table_partition_param(**param)
def row_record_factory():
param = {
'column_name_to_vector': {fake.name(): [random.random() for i in range(256)]},
'column_name_to_attribute': {fake.name(): fake.name()}
}
return Prepare.row_record(**param)
class TestConnection:
param = {'host':'localhost', 'port': '5000'}
@mock.patch.object(TSocket, 'open')
def test_true_connect(self, open):
open.return_value = None
cnn = MegaSearch()
cnn.connect(**self.param)
assert cnn.status == Status.OK
assert cnn.connected
assert isinstance(cnn.client, MegasearchService.Client)
with pytest.raises(RepeatingConnectError):
cnn.connect(**self.param)
cnn.connect()
def test_false_connect(self):
cnn = MegaSearch()
cnn.connect(self.param)
assert cnn.status != Status.OK
def test_disconnected_error(self):
cnn = MegaSearch()
cnn.connect_status = Status(Status.INVALID)
with pytest.raises(DisconnectNotConnectedClientError):
cnn.disconnect()
class TestTable:
@pytest.fixture
@mock.patch.object(TSocket, 'open')
def client(self, open):
param = {'host': 'localhost', 'port': '5000'}
open.return_value = None
cnn = MegaSearch()
cnn.connect(**param)
return cnn
@mock.patch.object(MegasearchService.Client, 'CreateTable')
def test_create_table(self, CreateTable, client):
CreateTable.return_value = None
param = table_schema_factory()
res = client.create_table(param)
assert res == Status.OK
def test_false_create_table(self, client):
param = table_schema_factory()
res = client.create_table(param)
LOGGER.error('{}'.format(res))
assert res != Status.OK
@mock.patch.object(MegasearchService.Client, 'DeleteTable')
def test_delete_table(self, DeleteTable, client):
DeleteTable.return_value = None
table_name = 'fake_table_name'
res = client.delete_table(table_name)
assert res == Status.OK
def test_false_delete_table(self, client):
table_name = 'fake_table_name'
res = client.delete_table(table_name)
assert res != Status.OK
class TestVector:
@pytest.fixture
@mock.patch.object(TSocket, 'open')
def client(self, open):
param = {'host': 'localhost', 'port': '5000'}
open.return_value = None
cnn = MegaSearch()
cnn.connect(**param)
return cnn
@mock.patch.object(MegasearchService.Client, 'CreateTablePartition')
def test_create_table_partition(self, CreateTablePartition, client):
CreateTablePartition.return_value = None
param = create_table_partition_param_factory()
res = client.create_table_partition(param)
assert res == Status.OK
def test_false_table_partition(self, client):
param = create_table_partition_param_factory()
res = client.create_table_partition(param)
assert res != Status.OK
@mock.patch.object(MegasearchService.Client, 'DeleteTablePartition')
def test_delete_table_partition(self, DeleteTablePartition, client):
DeleteTablePartition.return_value = None
param = delete_table_partition_param_factory()
res = client.delete_table_partition(param)
assert res == Status.OK
def test_false_delete_table_partition(self, client):
param = delete_table_partition_param_factory()
res = client.delete_table_partition(param)
assert res != Status.OK
@mock.patch.object(MegasearchService.Client, 'AddVector')
def test_add_vector(self, AddVector, client):
AddVector.return_value = None
param ={
'table_name': fake.table_name(),
'records': [row_record_factory() for _ in range(1000)]
}
res, ids = client.add_vector(**param)
assert res == Status.OK
def test_false_add_vector(self, client):
param ={
'table_name': fake.table_name(),
'records': [row_record_factory() for _ in range(1000)]
}
res, ids = client.add_vector(**param)
assert res != Status.OK
@mock.patch.object(MegasearchService.Client, 'SearchVector')
def test_search_vector(self, SearchVector, client):
SearchVector.return_value = None
param = {
'table_name': fake.table_name(),
'query_records': [row_record_factory() for _ in range(1000)],
'top_k': random.randint(0,10)
}
res, results = client.search_vector(**param)
assert res == Status.OK
def test_false_vector(self, client):
param = {
'table_name': fake.table_name(),
'query_records': [row_record_factory() for _ in range(1000)],
'top_k': random.randint(0,10)
}
res, results = client.search_vector(**param)
assert res != Status.OK
@mock.patch.object(MegasearchService.Client, 'DescribeTable')
def test_describe_table(self, DescribeTable, client):
DescribeTable.return_value = table_schema_factory()
table_name = fake.table_name()
res, table_schema = client.describe_table(table_name)
assert res == Status.OK
assert isinstance(table_schema, ttypes.TableSchema)
def test_false_decribe_table(self, client):
table_name = fake.table_name()
res, table_schema = client.describe_table(table_name)
assert res != Status.OK
assert not table_schema
@mock.patch.object(MegasearchService.Client, 'ShowTables')
def test_show_tables(self, ShowTables, client):
ShowTables.return_value = [fake.table_name() for _ in range(10)]
res, tables = client.show_tables()
assert res == Status.OK
assert isinstance(tables, list)
def test_false_show_tables(self, client):
res, tables = client.show_tables()
assert res != Status.OK
assert not tables
def test_client_version(self, client):
res = client.client_version()
assert res == '0.0.1'
class TestPrepare:
def test_column(self):
param = {
'name': 'test01',
'type': ColumnType.DATE
}
res = Prepare.column(**param)
LOGGER.error('{}'.format(res))
assert res.name == 'test01'
assert res.type == ColumnType.DATE
assert isinstance(res, ttypes.Column)
def test_vector_column(self):
param = vector_column_factory()
res = Prepare.vector_column(**param)
LOGGER.error('{}'.format(res))
assert isinstance(res, ttypes.VectorColumn)
def test_table_schema(self):
vec_params = [vector_column_factory() for i in range(10)]
column_params = [column_factory() for i in range(5)]
param = {
'table_name': 'test03',
'vector_columns': [Prepare.vector_column(**pa) for pa in vec_params],
'attribute_columns': [Prepare.column(**pa) for pa in column_params],
'partition_column_names': [str(x) for x in range(2)]
}
res = Prepare.table_schema(**param)
assert isinstance(res, ttypes.TableSchema)
def test_range(self):
param = {
'start': '200',
'end': '1000'
}
res = Prepare.range(**param)
LOGGER.error('{}'.format(res))
assert isinstance(res, ttypes.Range)
assert res.start_value == '200'
assert res.end_value == '1000'
def test_create_table_partition_param(self):
param = {
'table_name': fake.table_name(),
'partition_name': fake.table_name(),
'column_name_to_range': {fake.name(): range_factory() for _ in range(3)}
}
res = Prepare.create_table_partition_param(**param)
LOGGER.error('{}'.format(res))
assert isinstance(res, ttypes.CreateTablePartitionParam)
def test_delete_table_partition_param(self):
param = {
'table_name': fake.table_name(),
'partition_names': [fake.name() for i in range(5)]
}
res = Prepare.delete_table_partition_param(**param)
assert isinstance(res, ttypes.DeleteTablePartitionParam)
def test_row_record(self):
param={
'column_name_to_vector': {fake.name(): [random.random() for i in range(256)]},
'column_name_to_attribute': {fake.name(): fake.name()}
}
res = Prepare.row_record(**param)
assert isinstance(res, ttypes.RowRecord)
def test_query_record(self):
param = {
'column_name_to_vector': {fake.name(): [random.random() for i in range(256)]},
'selected_columns': [fake.name() for _ in range(10)],
'name_to_partition_ranges': {fake.name(): [range_factory() for _ in range(5)]}
}
res = Prepare.query_record(**param)
assert isinstance(res, ttypes.QueryRecord)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册