diff --git a/python/sdk/client/Abstract.py b/python/sdk/client/Abstract.py index 93719f34415e55ed73655fdeed46d49b4a3ff6be..5a113d3114b626bfda04c189ab628e681f385056 100644 --- a/python/sdk/client/Abstract.py +++ b/python/sdk/client/Abstract.py @@ -1,70 +1,10 @@ from enum import IntEnum -from .Exceptions import ConnectParamMissingError -class AbstactIndexType(object): - RAW = '1' - IVFFLAT = '2' - - -class AbstractColumnType(object): - INVALID = 1 - INT8 = 2 - INT16 = 3 - INT32 = 4 - INT64 = 5 - FLOAT32 = 6 - FLOAT64 = 7 - DATE = 8 - VECTOR = 9 - - -class Column(object): - """ - Table column description - - :type type: ColumnType - :param type: (Required) type of the column - - :type name: str - :param name: (Required) name of the column - - """ - def __init__(self, name=None, type=AbstractColumnType.INVALID): - self.type = type - self.name = name - - -class VectorColumn(Column): - """ - Table vector column description - - :type dimension: int, int64 - :param dimension: (Required) vector dimension - - :type index_type: string IndexType - :param index_type: (Required) IndexType - - :type store_raw_vector: bool - :param store_raw_vector: (Required) Is vector self stored in the table - - `Column`: - :type name: str - :param name: (Required) Name of the column - - :type type: ColumnType - :param type: (Required) Default type is ColumnType.VECTOR, can't change - - """ - def __init__(self, name, - dimension=0, - index_type=None, - store_raw_vector=False, - type=None): - self.dimension = dimension - self.index_type = index_type - self.store_raw_vector = store_raw_vector - super(VectorColumn, self).__init__(name, type=type) +class IndexType(IntEnum): + INVALIDE = 0 + IDMAP = 1 + IVFLAT = 2 class TableSchema(object): @@ -74,30 +14,26 @@ class TableSchema(object): :type table_name: str :param table_name: (Required) name of table - :type vector_columns: list[VectorColumn] - :param vector_columns: (Required) a list of VectorColumns, - - Stores different types of vectors - - :type attribute_columns: list[Column] - :param attribute_columns: (Optional) Columns description + :type index_type: IndexType + :param index_type: (Optional) index type, default = 0 - List of `Columns` whose type isn't VECTOR + `IndexType`: 0-invalid, 1-idmap, 2-ivflat - :type partition_column_names: list[str] - :param partition_column_names: (Optional) Partition column name + :type dimension: int64 + :param dimension: (Required) dimension of vector - `Partition columns` are `attribute columns`, the number of - partition columns may be less than or equal to attribute columns, - this param only stores `column name` + :type store_raw_vector: bool + :param store_raw_vector: (Optional) default = False """ - def __init__(self, table_name, vector_columns, - attribute_columns, partition_column_names, **kwargs): + def __init__(self, table_name, + dimension=0, + index_type=IndexType.INVALIDE, + store_raw_vector=False): self.table_name = table_name - self.vector_columns = vector_columns - self.attribute_columns = attribute_columns - self.partition_column_names = partition_column_names + self.index_type = index_type + self.dimension = dimension + self.store_raw_vector = store_raw_vector class Range(object): @@ -105,10 +41,10 @@ class Range(object): Range information :type start: str - :param start: (Required) Range start value + :param start: Range start value :type end: str - :param end: (Required) Range end value + :param end: Range end value """ def __init__(self, start, end): @@ -116,97 +52,37 @@ class Range(object): self.end = end -class CreateTablePartitionParam(object): - """ - Create table partition parameters - - :type table_name: str - :param table_name: (Required) Table name, - VECTOR/FLOAT32/FLOAT64 ColumnType is not allowed for partition - - :type partition_name: str - :param partition_name: (Required) partition name, created partition name - - :type column_name_to_range: dict{str : Range} - :param column_name_to_range: (Required) Column name to PartitionRange dictionary - """ - # TODO Iterable - def __init__(self, table_name, partition_name, column_name_to_range): - self.table_name = table_name - self.partition_name = partition_name - self.column_name_to_range = column_name_to_range - - -class DeleteTablePartitionParam(object): - """ - Delete table partition parameters - - :type table_name: str - :param table_name: (Required) Table name - - :type partition_names: iterable, str - :param partition_names: (Required) Partition name array - - """ - # TODO Iterable - def __init__(self, table_name, partition_names): - self.table_name = table_name - self.partition_names = partition_names - - class RowRecord(object): """ Record inserted - :type column_name_to_vector: dict{str : list[float]} - :param column_name_to_vector: (Required) Column name to vector map + :type vector_data: binary str + :param vector_data: (Required) a vector - :type column_name_to_attribute: dict{str: str} - :param column_name_to_attribute: (Optional) Other attribute columns """ - def __init__(self, column_name_to_vector, column_name_to_attribute): - self.column_name_to_vector = column_name_to_vector - self.column_name_to_attribute = column_name_to_attribute - - -class QueryRecord(object): - """ - Query record - - :type column_name_to_vector: (Required) dict{str : list[float]} - :param column_name_to_vector: Query vectors, column name to vector map - - :type selected_columns: list[str] - :param selected_columns: (Optional) Output column array - - :type name_to_partition_ranges: dict{str : list[Range]} - :param name_to_partition_ranges: (Optional) Range used to select partitions - - """ - def __init__(self, column_name_to_vector, selected_columns, name_to_partition_ranges): - self.column_name_to_vector = column_name_to_vector - self.selected_columns = selected_columns - self.name_to_partition_ranges = name_to_partition_ranges + def __init__(self, vector_data): + self.vector_data = vector_data class QueryResult(object): """ Query result - :type id: int - :param id: Output result + :type id: int64 + :param id: id of the vector :type score: float :param score: Vector similarity 0 <= score <= 100 - :type column_name_to_attribute: dict{str : str} - :param column_name_to_attribute: Other columns - """ - def __init__(self, id, score, column_name_to_attribute): + def __init__(self, id, score): self.id = id self.score = score - self.column_name_to_value = column_name_to_attribute + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) class TopKQueryResult(object): @@ -220,6 +96,12 @@ class TopKQueryResult(object): def __init__(self, query_results): self.query_results = query_results + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def _abstract(): raise NotImplementedError('You need to override this function') @@ -232,114 +114,71 @@ class ConnectIntf(object): """ - @staticmethod - def create(): - """Create a connection instance and return it - should be implemented - - :return connection: Connection - """ - _abstract() - - @staticmethod - def destroy(connection): - """Destroy the connection instance - should be implemented - - :type connection: Connection - :param connection: The connection instance to be destroyed - - :return bool, return True if destroy is successful - """ - _abstract() - - def connect(self, param=None, uri=None): + def connect(self, host=None, port=None, uri=None): """ Connect method should be called before any operations Server will be connected after connect return OK - should be implemented + Should be implemented + + :type host: str + :param host: host - :type param: ConnectParam - :param param: ConnectParam + :type port: str + :param port: port :type uri: str - :param uri: uri param + :param uri: (Optional) uri - :return: Status, indicate if connect is successful + :return Status, indicate if connect is successful """ - if (not param and not uri) or (param and uri): - raise ConnectParamMissingError('You need to parse exact one param') _abstract() def connected(self): """ connected, connection status - should be implemented + Should be implemented - :return: Status, indicate if connect is successful + :return Status, indicate if connect is successful """ _abstract() def disconnect(self): """ - Disconnect, server will be disconnected after disconnect return OK - should be implemented + Disconnect, server will be disconnected after disconnect return SUCCESS + Should be implemented - :return: Status, indicate if connect is successful + :return Status, indicate if connect is successful """ _abstract() def create_table(self, param): """ Create table - should be implemented + Should be implemented :type param: TableSchema :param param: provide table information to be created - :return: Status, indicate if connect is successful + :return Status, indicate if connect is successful """ _abstract() def delete_table(self, table_name): """ Delete table - should be implemented + Should be implemented :type table_name: str :param table_name: table_name of the deleting table - :return: Status, indicate if connect is successful + :return Status, indicate if connect is successful """ _abstract() - def create_table_partition(self, param): - """ - Create table partition - should be implemented - - :type param: CreateTablePartitionParam - :param param: provide partition information - - :return: Status, indicate if table partition is created successfully - """ - _abstract() - - def delete_table_partition(self, param): - """ - Delete table partition - should be implemented - - :type param: DeleteTablePartitionParam - :param param: provide partition information to be deleted - :return: Status, indicate if partition is deleted successfully - """ - _abstract() - - def add_vector(self, table_name, records): + def add_vectors(self, table_name, records): """ Add vectors to table - should be implemented + Should be implemented :type table_name: str :param table_name: table name been inserted @@ -347,27 +186,31 @@ class ConnectIntf(object): :type records: list[RowRecord] :param records: list of vectors been inserted - :returns: + :returns Status : indicate if vectors inserted successfully ids :list of id, after inserted every vector is given a id """ _abstract() - def search_vector(self, table_name, query_records, top_k): + def search_vectors(self, table_name, query_records, query_ranges, top_k): """ Query vectors in a table - should be implemented + Should be implemented :type table_name: str :param table_name: table name been queried - :type query_records: list[QueryRecord] + :type query_records: list[RowRecord] :param query_records: all vectors going to be queried + :type query_ranges: list[Range] + :param query_ranges: Optional ranges for conditional search. + If not specified, search whole table + :type top_k: int :param top_k: how many similar vectors will be searched - :returns: + :returns Status: indicate if query is successful query_results: list[TopKQueryResult] """ @@ -376,23 +219,37 @@ class ConnectIntf(object): def describe_table(self, table_name): """ Show table information - should be implemented + Should be implemented :type table_name: str :param table_name: which table to be shown - :returns: + :returns Status: indicate if query is successful table_schema: TableSchema, given when operation is successful """ _abstract() + def get_table_row_count(self, table_name): + """ + Get table row count + Should be implemented + + :type table_name, str + :param table_name, target table name. + + :returns + Status: indicate if operation is successful + count: int, table row count + """ + _abstract() + def show_tables(self): """ Show all tables in database should be implemented - :return: + :return Status: indicate if this operation is successful tables: list[str], list of table names """ @@ -403,31 +260,28 @@ class ConnectIntf(object): Provide client version should be implemented - :return: Client version + :return: str, client version """ _abstract() - pass def server_version(self): """ Provide server version should be implemented - :return: Server version + :return: str, server version """ + _abstract() def server_status(self, cmd): """ Provide server status should be implemented - # TODO What is cmd - :type cmd - :param cmd + :type cmd, str - :return: Server status + :return: str, server status """ _abstract() - pass diff --git a/python/sdk/client/Client.py b/python/sdk/client/Client.py index 7b72902ba2b379d46ae985fabd3606cc0f9581a9..86b414a9d10741c60a85f19a248d67595f212d63 100644 --- a/python/sdk/client/Client.py +++ b/python/sdk/client/Client.py @@ -9,21 +9,20 @@ from thrift.Thrift import TException, TApplicationException, TType from megasearch.thrift import MegasearchService from megasearch.thrift import ttypes from client.Abstract import ( - ConnectIntf, TableSchema, - AbstactIndexType, AbstractColumnType, - Column, - VectorColumn, Range, - CreateTablePartitionParam, - DeleteTablePartitionParam, - RowRecord, QueryRecord, - QueryResult, TopKQueryResult + ConnectIntf, + TableSchema, + IndexType, + Range, + RowRecord, + QueryResult, + TopKQueryResult ) from client.Status import Status from client.Exceptions import ( - RepeatingConnectError, ConnectParamMissingError, + RepeatingConnectError, DisconnectNotConnectedClientError, - ParamError, NotConnectError + NotConnectError ) LOGGER = logging.getLogger(__name__) @@ -32,125 +31,35 @@ __VERSION__ = '0.0.1' __NAME__ = 'Thrift_Client' -class IndexType(AbstactIndexType): - # TODO thrift in IndexType - RAW = '1' - IVFFLAT = '2' - - -class ColumnType(AbstractColumnType): - - FLOAT32 = 6 - FLOAT64 = 7 - DATE = 8 - - INVALID = TType.STOP - INT8 = TType.I08 - INT16 = TType.I16 - INT32 = TType.I32 - INT64 = TType.I64 - VECTOR = TType.LIST - -# TODO Required and Optional -# TODO Examples -# TODO ORM class Prepare(object): @classmethod - def column(cls, name, type): - """ - Table column param - :param type: (Required) ColumnType, type of the column - :param name: (Required) str, name of the column - - :return Column - """ - temp_column = Column(name=name, type=type) - return ttypes.Column(name=temp_column.name, type=temp_column.type) - - @classmethod - def vector_column(cls, name, dimension, - index_type=IndexType.RAW, - store_raw_vector=False): + def table_schema(cls, + table_name, *, + dimension, + index_type, + store_raw_vector): """ - Table vector column description - - :param dimension: (Required) int64, vector dimension - :param index_type: (Required) IndexType - :param store_raw_vector: (Required) Bool - - `Column`: - :param name: (Required) Name of the column - :param type: (Required) Default type is ColumnType.VECTOR, can't change - - :return VectorColumn - """ - - temp = VectorColumn(name=name, dimension=dimension, - index_type=index_type, store_raw_vector=store_raw_vector) - base = ttypes.Column(name=temp.name, type=ColumnType.VECTOR) - return ttypes.VectorColumn(base=base, dimension=temp.dimension, - store_raw_vector=temp.store_raw_vector, - index_type=temp.index_type) - - # Without IndexType - # temp = VectorColumn(name=name, dimension=dimension, - # store_raw_vector=store_raw_vector) - # return ttypes.VectorColumn(base=base, dimension=temp.dimension, - # store_raw_vector=temp.store_raw_vector) - - @classmethod - def table_schema(cls, table_name, - vector_columns, - attribute_columns, - partition_column_names): - """ - - :param table_name: (Required) Name of the table - :param vector_columns: (Required) List of VectorColumns - `VectorColumn`: - - dimension: int, default = 0 - Dimension of the vector, different vector_columns' - dimension may vary - - index_type: (optional) IndexType, default=IndexType.RAW - Vector's index type - - store_raw_vector : (optional) bool, default=False - - name: str - Name of the column - - type: ColumnType, default=ColumnType.VECTOR, can't change - - :param attribute_columns: (Optional) List of Columns. Attribute columns are Columns, - whose types aren't ColumnType.VECTOR - - `Column`: - - name: str - - type: ColumnType, default=ColumnType.INVALID - - :param partition_column_names: (Optional) List of str. - - Partition columns name - indicates which attribute columns is used for partition, can - have lots of partition columns as long as: - -> No. partition_column_names <= No. attribute_columns - -> partition_column_names IN attribute_column_names + :param table_name: str, (Required) name of table + :param index_type: IndexType, (Required) index type, default = IndexType.INVALID + :param dimension: int64, (Optional) dimension of the table + :param store_raw_vector: bool, (Optional) default = False :return: TableSchema """ - temp = TableSchema(table_name,vector_columns, - attribute_columns, - partition_column_names) + temp = TableSchema(table_name,dimension, index_type, store_raw_vector) return ttypes.TableSchema(table_name=temp.table_name, - vector_column_array=temp.vector_columns, - attribute_column_array=temp.attribute_columns, - partition_column_name_array=temp.partition_column_names) + dimension=dimension, + index_type=index_type, + store_raw_vector=store_raw_vector) @classmethod def range(cls, start, end): """ - :param start: (Required) Partition range start value - :param end: (Required) Partition range end value + :param start: str, (Required) range start + :param end: str (Required) range end :return Range """ @@ -158,142 +67,66 @@ class Prepare(object): return ttypes.Range(start_value=temp.start, end_value=temp.end) @classmethod - def create_table_partition_param(cls, - table_name, - partition_name, - column_name_to_range): - """ - Create table partition parameters - :param table_name: (Required) str, Table name, - VECTOR/FLOAT32/FLOAT64 ColumnType is not allowed for partition - :param partition_name: (Required) str partition name, created partition name - :param column_name_to_range: (Required) dict, column name to partition range dictionary - - :return CreateTablePartitionParam - """ - temp = CreateTablePartitionParam(table_name=table_name, - partition_name=partition_name, - column_name_to_range=column_name_to_range) - return ttypes.CreateTablePartitionParam(table_name=temp.table_name, - partition_name=temp.partition_name, - range_map=temp.column_name_to_range) - - @classmethod - def delete_table_partition_param(cls, table_name, partition_names): - """ - Delete table partition parameters - :param table_name: (Required) Table name - :param partition_names: (Required) List of partition names - - :return DeleteTablePartitionParam - """ - temp = DeleteTablePartitionParam(table_name=table_name, - partition_names=partition_names) - return ttypes.DeleteTablePartitionParam(table_name=table_name, - partition_name_array=partition_names) - - @classmethod - def row_record(cls, column_name_to_vector, column_name_to_attribute): - """ - :param column_name_to_vector: (Required) dict{str : list[float]} - Column name to vector map - - :param column_name_to_attribute: (Optional) dict{str: str} - Other attribute columns - """ - temp = RowRecord(column_name_to_vector=column_name_to_vector, - column_name_to_attribute=column_name_to_attribute) - return ttypes.RowRecord(vector_map=temp.column_name_to_vector, - attribute_map=temp.column_name_to_attribute) - - @classmethod - def query_record(cls, column_name_to_vector, - selected_columns, name_to_partition_ranges): + def row_record(cls, vector_data): """ - :param column_name_to_vector: (Required) dict{str : list[float]} - Query vectors, column name to vector map + Record inserted - :param selected_columns: (Optional) list[str_column_name] - List of Output columns + :param vector_data: float binary str, (Required) a binary str - :param name_to_partition_ranges: (Optional) dict{str : list[Range]} - Partition Range used to search - - `Range`: - :param start: Partition range start value - :param end: Partition range end value - - :return QueryRecord """ - temp = QueryRecord(column_name_to_vector=column_name_to_vector, - selected_columns=selected_columns, - name_to_partition_ranges=name_to_partition_ranges) - return ttypes.QueryRecord(vector_map=temp.column_name_to_vector, - selected_column_array=temp.selected_columns, - partition_filter_column_map=name_to_partition_ranges) + temp = RowRecord(vector_data) + return ttypes.RowRecord(vector_data=temp.vector_data) class MegaSearch(ConnectIntf): def __init__(self): - self.transport = None - self.client = None self.status = None + self._transport = None + self._client = None def __repr__(self): return '{}'.format(self.status) - @staticmethod - def create(): - # TODO in python, maybe this method is useless - return MegaSearch() - - @staticmethod - def destroy(connection): - """Destroy the connection instance""" - # TODO in python, maybe this method is useless - - pass - def connect(self, host='localhost', port='9090', uri=None): # TODO URI - if self.status and self.status == Status(message='Connected'): + if self.status and self.status == Status.SUCCESS: raise RepeatingConnectError("You have already connected!") transport = TSocket.TSocket(host=host, port=port) - self.transport = TTransport.TBufferedTransport(transport) + self._transport = TTransport.TBufferedTransport(transport) protocol = TBinaryProtocol.TBinaryProtocol(transport) - self.client = MegasearchService.Client(protocol) + self._client = MegasearchService.Client(protocol) try: transport.open() - self.status = Status(Status.OK, 'Connected') + self.status = Status(Status.SUCCESS, 'Connected') LOGGER.info('Connected!') except (TTransport.TTransportException, TException) as e: - self.status = Status(Status.INVALID, message=str(e)) + self.status = Status(Status.CONNECT_FAILED, message=str(e)) LOGGER.error('logger.error: {}'.format(self.status)) finally: return self.status @property def connected(self): - return self.status == Status() + return self.status == Status.SUCCESS def disconnect(self): - if not self.transport: + if not self._transport: raise DisconnectNotConnectedClientError('Error') try: - self.transport.close() + self._transport.close() LOGGER.info('Client Disconnected!') self.status = None except TException as e: - return Status(Status.INVALID, str(e)) - return Status(Status.OK, 'Disconnected') + return Status(Status.PERMISSION_DENIED, str(e)) + return Status(Status.SUCCESS, 'Disconnected') def create_table(self, param): """Create table @@ -304,15 +137,14 @@ class MegaSearch(ConnectIntf): :return: Status, indicate if operation is successful """ - if not self.client: + if not self._client: raise NotConnectError('Please Connect to the server first!') try: - LOGGER.error(param) - self.client.CreateTable(param) + self._client.CreateTable(param) except (TApplicationException, ) as e: LOGGER.error('Unable to create table') - return Status(Status.INVALID, str(e)) + return Status(Status.PERMISSION_DENIED, str(e)) return Status(message='Table {} created!'.format(param.table_name)) def delete_table(self, table_name): @@ -323,48 +155,13 @@ class MegaSearch(ConnectIntf): :return: Status, indicate if operation is successful """ try: - self.client.DeleteTable(table_name) + self._client.DeleteTable(table_name) except (TApplicationException, TException) as e: LOGGER.error('Unable to delete table {}'.format(table_name)) - return Status(Status.INVALID, str(e)) + return Status(Status.PERMISSION_DENIED, str(e)) return Status(message='Table {} deleted!'.format(table_name)) - def create_table_partition(self, param): - """ - Create table partition - - :type param: CreateTablePartitionParam, provide partition information - - `Please use Prepare.create_table_partition_param generate param` - - :return: Status, indicate if table partition is created successfully - """ - try: - self.client.CreateTablePartition(param) - except (TApplicationException, TException) as e: - LOGGER.error('{}'.format(e)) - return Status(Status.INVALID, str(e)) - return Status(message='Table partition created successfully!') - - def delete_table_partition(self, param): - """ - Delete table partition - - :type param: DeleteTablePartitionParam - :param param: provide partition information to be deleted - - `Please use Prepare.delete_table_partition_param generate param` - - :return: Status, indicate if partition is deleted successfully - """ - try: - self.client.DeleteTablePartition(param) - except (TApplicationException, TException) as e: - LOGGER.error('{}'.format(e)) - return Status(Status.INVALID, str(e)) - return Status(message='Table partition deleted successfully!') - - def add_vector(self, table_name, records): + def add_vectors(self, table_name, records): """ Add vectors to table @@ -378,13 +175,13 @@ class MegaSearch(ConnectIntf): ids :list of id, after inserted every vector is given a id """ try: - ids = self.client.AddVector(table_name=table_name, record_array=records) + ids = self._client.AddVector(table_name=table_name, record_array=records) except (TApplicationException, TException) as e: LOGGER.error('{}'.format(e)) - return Status(Status.INVALID, str(e)), None - return Status(message='Vector added successfully!'), ids + return Status(Status.PERMISSION_DENIED, str(e)), None + return Status(message='Vectors added successfully!'), ids - def search_vector(self, table_name, query_records, top_k): + def search_vectors(self, table_name, top_k, query_records, query_ranges=None): """ Query vectors in a table @@ -394,20 +191,29 @@ class MegaSearch(ConnectIntf): `Please use Prepare.query_record generate QueryRecord` :param top_k: int, how many similar vectors will be searched + :param query_ranges, (Optional) list[Range], search range :returns: Status: indicate if query is successful - query_results: list[TopKQueryResult], return when operation is successful + res: list[TopKQueryResult], return when operation is successful """ - # TODO topk_query_results + res = [] try: - topk_query_results = self.client.SearchVector( - table_name=table_name, query_record_array=query_records, topk=top_k) + top_k_query_results = self._client.SearchVector( + table_name=table_name, + query_record_array=query_records, + query_range_array=query_ranges, + topk=top_k) + + if top_k_query_results: + for top_k in top_k_query_results: + res.append(TopKQueryResult([QueryResult(qr.id, qr.score) + for qr in top_k.query_result_arrays])) except (TApplicationException, TException) as e: LOGGER.error('{}'.format(e)) - return Status(Status.INVALID, str(e)), None - return Status(message='Success!'), topk_query_results + return Status(Status.PERMISSION_DENIED, str(e)), None + return Status(message='Success!'), res def describe_table(self, table_name): """ @@ -420,12 +226,14 @@ class MegaSearch(ConnectIntf): table_schema: TableSchema, return when operation is successful """ try: - thrift_table_schema = self.client.DescribeTable(table_name) + temp = self._client.DescribeTable(table_name) + + # res = TableSchema(table_name=temp.table_name, dimension=temp.dimension, + # index_type=temp.index_type, store_raw_vector=temp.store_raw_vector) except (TApplicationException, TException) as e: LOGGER.error('{}'.format(e)) - return Status(Status.INVALID, str(e)), None - # TODO Table Schema - return Status(message='Success!'), thrift_table_schema + return Status(Status.PERMISSION_DENIED, str(e)), None + return Status(message='Success!'), temp def show_tables(self): """ @@ -437,12 +245,36 @@ class MegaSearch(ConnectIntf): is successful """ try: - tables = self.client.ShowTables() + res = self._client.ShowTables() + tables = [] + if res: + tables, _ = res + except (TApplicationException, TException) as e: LOGGER.error('{}'.format(e)) - return Status(Status.INVALID, str(e)), None + return Status(Status.PERMISSION_DENIED, str(e)), None return Status(message='Success!'), tables + def get_table_row_count(self, table_name): + """ + Get table row count + + :type table_name, str + :param table_name, target table name. + + :returns: + Status: indicate if operation is successful + res: int, table row count + + """ + try: + count, _ = self._client.GetTableRowCount(table_name) + + except (TApplicationException, TException) as e: + LOGGER.error('{}'.format(e)) + return Status(Status.PERMISSION_DENIED, str(e)), None + return Status(message='Success'), count + def client_version(self): """ Provide client version @@ -457,8 +289,10 @@ class MegaSearch(ConnectIntf): :return: Server version """ - # TODO How to get server version - pass + if not self.connected: + raise NotConnectError('You have to connect first') + + return self._client.Ping('version') def server_status(self, cmd=None): """ @@ -466,4 +300,7 @@ class MegaSearch(ConnectIntf): :return: Server status """ - return self.client.Ping(cmd) + if not self.connected: + raise NotConnectError('You have to connect first') + + return self._client.Ping(cmd) diff --git a/python/sdk/client/Exceptions.py b/python/sdk/client/Exceptions.py index 88ced39b25d39488460ec5fb3747b12dd1a6f90d..30be65b5eb74658046acccef0c444b9dc0613bfc 100644 --- a/python/sdk/client/Exceptions.py +++ b/python/sdk/client/Exceptions.py @@ -2,10 +2,6 @@ class ParamError(ValueError): pass -class ConnectParamMissingError(ParamError): - pass - - class ConnectError(ValueError): pass diff --git a/python/sdk/client/Status.py b/python/sdk/client/Status.py index 7d5b6062054f746f2f045b2196525a9853821799..b60c48aa981c6d342e5cb99a3a2d459305411e9c 100644 --- a/python/sdk/client/Status.py +++ b/python/sdk/client/Status.py @@ -3,13 +3,15 @@ class Status(object): :attribute code : int (optional) default as ok :attribute message : str (optional) current status message """ - OK = 0 - INVALID = 1 - UNKNOWN_ERROR = 2 - NOT_SUPPORTED = 3 - NOT_CONNECTED = 4 + SUCCESS = 0 + CONNECT_FAILED = 1 + PERMISSION_DENIED = 2 + TABLE_NOT_EXISTS = 3 + ILLEGAL_ARGUMENT = 4 + ILLEGAL_RANGE = 5 + ILLEGAL_DIMENSION = 6 - def __init__(self, code=OK, message=None): + def __init__(self, code=SUCCESS, message=None): self.code = code self.message = message diff --git a/python/sdk/examples/connection_exp.py b/python/sdk/examples/connection_exp.py index e7d139f14717275966be954e443f559cd2a009b9..05c2cffd5a68dfa695ee821271b229d80348b241 100644 --- a/python/sdk/examples/connection_exp.py +++ b/python/sdk/examples/connection_exp.py @@ -1,79 +1,92 @@ -from client.Client import MegaSearch, Prepare, IndexType, ColumnType +from client.Client import MegaSearch, Prepare, IndexType from client.Status import Status import time +import random +import struct +from pprint import pprint from megasearch.thrift import MegasearchService, ttypes def main(): + # Get client version mega = MegaSearch() - print(mega.client_version()) + print('# Client version: {}'.format(mega.client_version())) # Connect param = {'host': '192.168.1.129', 'port': '33001'} cnn_status = mega.connect(**param) - print('Connect Status: {}'.format(cnn_status)) + print('# Connect Status: {}'.format(cnn_status)) + # Check if connected is_connected = mega.connected - print('Connect status: {}'.format(is_connected)) - - # Create table with 1 vector column, 1 attribute column and 1 partition column - # 1. prepare table_schema - - # table_schema = Prepare.table_schema( - # table_name='fake_table_name' + time.strftime('%H%M%S'), - # - # vector_columns=[Prepare.vector_column( - # name='fake_vector_name' + time.strftime('%H%M%S'), - # store_raw_vector=False, - # dimension=256)], - # - # attribute_columns=[], - # - # partition_column_names=[] - # ) - - # get server version - print(mega.server_status('version')) - print(mega.client.Ping('version')) - # show tables and their description - statu, tables = mega.show_tables() - print(tables) - - for table in tables: - s,t = mega.describe_table(table) - print('table: {}'.format(t)) + print('# Is connected: {}'.format(is_connected)) - # Create table - # 1. create table schema - table_schema_full = MegasearchService.TableSchema( - table_name='fake' + time.strftime('%H%M%S'), - - vector_column_array=[MegasearchService.VectorColumn( - base=MegasearchService.Column( - name='111', - type=ttypes.TType.LIST - ), - index_type="aaa", - dimension=256, - store_raw_vector=False, - )], - - attribute_column_array=[], + # Get server version + print('# Server version: {}'.format(mega.server_version())) - partition_column_name_array=[] - ) + # Show tables and their description + status, tables = mega.show_tables() + print('# Show tables: {}'.format(tables)) - # 2. Create Table - create_status = mega.client.CreateTable(param=table_schema_full) - print('Create table status: {}'.format(create_status)) - - # add_vector + # Create table + # 01.Prepare data + param = { + 'table_name': 'test'+ str(random.randint(0,999)), + 'dimension': 256, + 'index_type': IndexType.IDMAP, + 'store_raw_vector': False + } + + # 02.Create table + res_status = mega.create_table(Prepare.table_schema(**param)) + print('# Create table status: {}'.format(res_status)) + + # Describe table + table_name = 'test01' + res_status, table = mega.describe_table(table_name) + print('# Describe table status: {}'.format(res_status)) + print('# Describe table:{}'.format(table)) + + # Add vectors to table 'test01' + # 01. Prepare data + dim = 256 + # list of binary vectors + vectors = [Prepare.row_record(struct.pack(str(dim)+'d', + *[random.random()for _ in range(dim)])) + for _ in range(20)] + # 02. Add vectors + status, ids = mega.add_vectors(table_name=table_name, records=vectors) + print('# Add vector status: {}'.format(status)) + pprint(ids) + + # Search vectors + q_records = [Prepare.row_record(struct.pack(str(dim) + 'd', + *[random.random() for _ in range(dim)])) + for _ in range(5)] + param = { + 'table_name': 'test01', + 'query_records': q_records, + 'top_k': 10, + # 'query_ranges': None # Optional + } + sta, results = mega.search_vectors(**param) + print('# Search vectors status: {}'.format(sta)) + pprint(results) + + # Get table row count + sta, result = mega.get_table_row_count(table_name) + print('# Status: {}'.format(sta)) + print('# Count: {}'.format(result)) + + # Delete table 'test01' + res_status = mega.delete_table(table_name) + print('# Delete table status: {}'.format(res_status)) # Disconnect discnn_status = mega.disconnect() - print('Disconnect Status{}'.format(discnn_status)) + print('# Disconnect Status: {}'.format(discnn_status)) if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/python/sdk/tests/TestClient.py b/python/sdk/tests/TestClient.py index 3ff6bf9e41a5829b1054fdbf06a5a9396cb6bdc0..bc6dd55018031f0ad6bcf8db6ce4d91c3b88823c 100644 --- a/python/sdk/tests/TestClient.py +++ b/python/sdk/tests/TestClient.py @@ -3,17 +3,20 @@ import pytest import mock import faker import random +import struct from faker.providers import BaseProvider -from client.Client import MegaSearch, Prepare, IndexType, ColumnType +from client.Client import MegaSearch, Prepare +from client.Abstract import IndexType, TableSchema from client.Status import Status from client.Exceptions import ( RepeatingConnectError, DisconnectNotConnectedClientError ) +from megasearch.thrift import ttypes, MegasearchService from thrift.transport.TSocket import TSocket -from megasearch.thrift import ttypes, MegasearchService +from thrift.transport import TTransport from thrift.transport.TTransport import TTransportException LOGGER = logging.getLogger(__name__) @@ -35,63 +38,37 @@ fake = faker.Faker() fake.add_provider(FakerProvider) -def vector_column_factory(): - return { - 'name': fake.name(), - 'dimension': fake.dim(), - 'store_raw_vector': True - } - - -def column_factory(): - return { - 'name': fake.table_name(), - 'type': ColumnType.INT32 - } - - def range_factory(): - return { + param = { 'start': str(random.randint(1, 10)), 'end': str(random.randint(11, 20)), } + return Prepare.range(**param) + + +def ranges_factory(): + return [range_factory() for _ in range(5)] def table_schema_factory(): - vec_params = [vector_column_factory() for i in range(10)] - column_params = [column_factory() for i in range(5)] param = { 'table_name': fake.table_name(), - 'vector_columns': [Prepare.vector_column(**pa) for pa in vec_params], - 'attribute_columns': [Prepare.column(**pa) for pa in column_params], - 'partition_column_names': [str(x) for x in range(2)] + 'dimension': random.randint(0, 999), + 'index_type': IndexType.IDMAP, + 'store_raw_vector': False } return Prepare.table_schema(**param) -def create_table_partition_param_factory(): - param = { - 'table_name': fake.table_name(), - 'partition_name': fake.table_name(), - 'column_name_to_range': {fake.name(): range_factory() for _ in range(3)} - } - return Prepare.create_table_partition_param(**param) +def row_record_factory(dimension): + vec = [random.random() + random.randint(0,9) for _ in range(dimension)] + bin_vec = struct.pack(str(dimension) + "d", *vec) - -def delete_table_partition_param_factory(): - param = { - 'table_name': fake.table_name(), - 'partition_names': [fake.name() for i in range(5)] - } - return Prepare.delete_table_partition_param(**param) + return Prepare.row_record(vector_data=bin_vec) -def row_record_factory(): - param = { - 'column_name_to_vector': {fake.name(): [random.random() for i in range(256)]}, - 'column_name_to_attribute': {fake.name(): fake.name()} - } - return Prepare.row_record(**param) +def row_records_factory(dimension): + return [row_record_factory(dimension) for _ in range(20)] class TestConnection: @@ -103,9 +80,8 @@ class TestConnection: cnn = MegaSearch() cnn.connect(**self.param) - assert cnn.status == Status.OK + assert cnn.status == Status.SUCCESS assert cnn.connected - assert isinstance(cnn.client, MegasearchService.Client) with pytest.raises(RepeatingConnectError): cnn.connect(**self.param) @@ -114,12 +90,23 @@ class TestConnection: def test_false_connect(self): cnn = MegaSearch() - cnn.connect(self.param) - assert cnn.status != Status.OK + cnn.connect(**self.param) + assert cnn.status != Status.SUCCESS + + @mock.patch.object(TTransport.TBufferedTransport, 'close') + @mock.patch.object(TSocket, 'open') + def test_disconnected(self, close, open): + close.return_value = None + open.return_value = None + + cnn = MegaSearch() + cnn.connect(**self.param) + + assert cnn.disconnect() == Status.SUCCESS def test_disconnected_error(self): cnn = MegaSearch() - cnn.connect_status = Status(Status.INVALID) + cnn.connect_status = Status(Status.PERMISSION_DENIED) with pytest.raises(DisconnectNotConnectedClientError): cnn.disconnect() @@ -142,26 +129,26 @@ class TestTable: param = table_schema_factory() res = client.create_table(param) - assert res == Status.OK + assert res == Status.SUCCESS def test_false_create_table(self, client): param = table_schema_factory() with pytest.raises(TTransportException): res = client.create_table(param) LOGGER.error('{}'.format(res)) - assert res != Status.OK + assert res != Status.SUCCESS @mock.patch.object(MegasearchService.Client, 'DeleteTable') def test_delete_table(self, DeleteTable, client): DeleteTable.return_value = None table_name = 'fake_table_name' res = client.delete_table(table_name) - assert res == Status.OK + assert res == Status.SUCCESS def test_false_delete_table(self, client): table_name = 'fake_table_name' res = client.delete_table(table_name) - assert res != Status.OK + assert res != Status.SUCCESS class TestVector: @@ -176,70 +163,46 @@ class TestVector: cnn.connect(**param) return cnn - @mock.patch.object(MegasearchService.Client, 'CreateTablePartition') - def test_create_table_partition(self, CreateTablePartition, client): - CreateTablePartition.return_value = None - - param = create_table_partition_param_factory() - res = client.create_table_partition(param) - assert res == Status.OK - - def test_false_table_partition(self, client): - param = create_table_partition_param_factory() - res = client.create_table_partition(param) - assert res != Status.OK - - @mock.patch.object(MegasearchService.Client, 'DeleteTablePartition') - def test_delete_table_partition(self, DeleteTablePartition, client): - DeleteTablePartition.return_value = None - - param = delete_table_partition_param_factory() - res = client.delete_table_partition(param) - assert res == Status.OK - - def test_false_delete_table_partition(self, client): - param = delete_table_partition_param_factory() - res = client.delete_table_partition(param) - assert res != Status.OK - @mock.patch.object(MegasearchService.Client, 'AddVector') def test_add_vector(self, AddVector, client): AddVector.return_value = None param ={ 'table_name': fake.table_name(), - 'records': [row_record_factory() for _ in range(1000)] + 'records': row_records_factory(256) } - res, ids = client.add_vector(**param) - assert res == Status.OK + res, ids = client.add_vectors(**param) + assert res == Status.SUCCESS def test_false_add_vector(self, client): param ={ 'table_name': fake.table_name(), - 'records': [row_record_factory() for _ in range(1000)] + 'records': row_records_factory(256) } - res, ids = client.add_vector(**param) - assert res != Status.OK + res, ids = client.add_vectors(**param) + assert res != Status.SUCCESS @mock.patch.object(MegasearchService.Client, 'SearchVector') def test_search_vector(self, SearchVector, client): - SearchVector.return_value = None + SearchVector.return_value = None, None param = { 'table_name': fake.table_name(), - 'query_records': [row_record_factory() for _ in range(1000)], - 'top_k': random.randint(0,10) + 'query_records': row_records_factory(256), + 'query_ranges': ranges_factory(), + 'top_k': random.randint(0, 10) } - res, results = client.search_vector(**param) - assert res == Status.OK + res, results = client.search_vectors(**param) + assert res == Status.SUCCESS def test_false_vector(self, client): param = { 'table_name': fake.table_name(), - 'query_records': [row_record_factory() for _ in range(1000)], - 'top_k': random.randint(0,10) + 'query_records': row_records_factory(256), + 'query_ranges': ranges_factory(), + 'top_k': random.randint(0, 10) } - res, results = client.search_vector(**param) - assert res != Status.OK + res, results = client.search_vectors(**param) + assert res != Status.SUCCESS @mock.patch.object(MegasearchService.Client, 'DescribeTable') def test_describe_table(self, DescribeTable, client): @@ -247,27 +210,38 @@ class TestVector: table_name = fake.table_name() res, table_schema = client.describe_table(table_name) - assert res == Status.OK - assert isinstance(table_schema, ttypes.TableSchema) + assert res == Status.SUCCESS + assert isinstance(table_schema, TableSchema) def test_false_decribe_table(self, client): table_name = fake.table_name() res, table_schema = client.describe_table(table_name) - assert res != Status.OK + assert res != Status.SUCCESS assert not table_schema @mock.patch.object(MegasearchService.Client, 'ShowTables') def test_show_tables(self, ShowTables, client): - ShowTables.return_value = [fake.table_name() for _ in range(10)] + ShowTables.return_value = [fake.table_name() for _ in range(10)], None res, tables = client.show_tables() - assert res == Status.OK + assert res == Status.SUCCESS assert isinstance(tables, list) def test_false_show_tables(self, client): res, tables = client.show_tables() - assert res != Status.OK + assert res != Status.SUCCESS assert not tables + @mock.patch.object(MegasearchService.Client, 'GetTableRowCount') + def test_get_table_row_count(self, GetTableRowCount, client): + GetTableRowCount.return_value = 22, None + res, count = client.get_table_row_count('fake_table') + assert res == Status.SUCCESS + + def test_false_get_table_row_count(self, client): + res,count = client.get_table_row_count('fake_table') + assert res != Status.SUCCESS + assert not count + def test_client_version(self, client): res = client.client_version() assert res == '0.0.1' @@ -275,34 +249,13 @@ class TestVector: class TestPrepare: - def test_column(self): - param = { - 'name': 'test01', - 'type': ColumnType.DATE - } - res = Prepare.column(**param) - LOGGER.error('{}'.format(res)) - assert res.name == 'test01' - assert res.type == ColumnType.DATE - assert isinstance(res, ttypes.Column) - - def test_vector_column(self): - param = vector_column_factory() - - res = Prepare.vector_column(**param) - LOGGER.error('{}'.format(res)) - assert isinstance(res, ttypes.VectorColumn) - def test_table_schema(self): - vec_params = [vector_column_factory() for i in range(10)] - column_params = [column_factory() for i in range(5)] - param = { - 'table_name': 'test03', - 'vector_columns': [Prepare.vector_column(**pa) for pa in vec_params], - 'attribute_columns': [Prepare.column(**pa) for pa in column_params], - 'partition_column_names': [str(x) for x in range(2)] + 'table_name': fake.table_name(), + 'dimension': random.randint(0, 999), + 'index_type': IndexType.IDMAP, + 'store_raw_vector': False } res = Prepare.table_schema(**param) assert isinstance(res, ttypes.TableSchema) @@ -319,39 +272,10 @@ class TestPrepare: assert res.start_value == '200' assert res.end_value == '1000' - def test_create_table_partition_param(self): - param = { - 'table_name': fake.table_name(), - 'partition_name': fake.table_name(), - 'column_name_to_range': {fake.name(): range_factory() for _ in range(3)} - } - res = Prepare.create_table_partition_param(**param) - LOGGER.error('{}'.format(res)) - assert isinstance(res, ttypes.CreateTablePartitionParam) - - def test_delete_table_partition_param(self): - param = { - 'table_name': fake.table_name(), - 'partition_names': [fake.name() for i in range(5)] - } - res = Prepare.delete_table_partition_param(**param) - assert isinstance(res, ttypes.DeleteTablePartitionParam) - def test_row_record(self): - param={ - 'column_name_to_vector': {fake.name(): [random.random() for i in range(256)]}, - 'column_name_to_attribute': {fake.name(): fake.name()} - } - res = Prepare.row_record(**param) + vec = [random.random() + random.randint(0, 9) for _ in range(256)] + bin_vec = struct.pack(str(256) + "d", *vec) + res = Prepare.row_record(bin_vec) assert isinstance(res, ttypes.RowRecord) - - def test_query_record(self): - param = { - 'column_name_to_vector': {fake.name(): [random.random() for i in range(256)]}, - 'selected_columns': [fake.name() for _ in range(10)], - 'name_to_partition_ranges': {fake.name(): [range_factory() for _ in range(5)]} - } - res = Prepare.query_record(**param) - assert isinstance(res, ttypes.QueryRecord) - + assert isinstance(bin_vec, bytes)