提交 baf8eaee 编写于 作者: Y Yang Xuan

small verify


Former-commit-id: f11084a887b30724e7d31ce89214c03f59c828bc
上级 8623ae24
...@@ -58,12 +58,13 @@ class VectorColumn(Column): ...@@ -58,12 +58,13 @@ class VectorColumn(Column):
""" """
def __init__(self, name, def __init__(self, name,
dimension=0, dimension=0,
index_type=AbstactIndexType.RAW, index_type=None,
store_raw_vector=False): store_raw_vector=False,
type=None):
self.dimension = dimension self.dimension = dimension
self.index_type = index_type self.index_type = index_type
self.store_raw_vector = store_raw_vector self.store_raw_vector = store_raw_vector
super(VectorColumn, self).__init__(name, type=AbstractColumnType.VECTOR) super(VectorColumn, self).__init__(name, type=type)
class TableSchema(object): class TableSchema(object):
......
...@@ -2,6 +2,7 @@ import logging, logging.config ...@@ -2,6 +2,7 @@ import logging, logging.config
from thrift.transport import TSocket from thrift.transport import TSocket
from thrift.transport import TTransport from thrift.transport import TTransport
from thrift.transport.TTransport import TTransportException
from thrift.protocol import TBinaryProtocol, TCompactProtocol, TJSONProtocol from thrift.protocol import TBinaryProtocol, TCompactProtocol, TJSONProtocol
from thrift.Thrift import TException, TApplicationException, TType from thrift.Thrift import TException, TApplicationException, TType
...@@ -38,15 +39,10 @@ class IndexType(AbstactIndexType): ...@@ -38,15 +39,10 @@ class IndexType(AbstactIndexType):
class ColumnType(AbstractColumnType): class ColumnType(AbstractColumnType):
# INVALID = 1
# INT8 = 2
# INT16 = 3
# INT32 = 4
# INT64 = 5
FLOAT32 = 6 FLOAT32 = 6
FLOAT64 = 7 FLOAT64 = 7
DATE = 8 DATE = 8
# VECTOR = 9
INVALID = TType.STOP INVALID = TType.STOP
INT8 = TType.I08 INT8 = TType.I08
...@@ -62,13 +58,12 @@ class Prepare(object): ...@@ -62,13 +58,12 @@ class Prepare(object):
def column(cls, name, type): def column(cls, name, type):
""" """
Table column param Table column param
# todo type
:param type: ColumnType, type of the column :param type: ColumnType, type of the column
:param name: str, name of the column :param name: str, name of the column
:return Column :return Column
""" """
# TODO type in Thrift, may have error
temp_column = Column(name=name, type=type) temp_column = Column(name=name, type=type)
return ttypes.Column(name=temp_column.name, type=temp_column.type) return ttypes.Column(name=temp_column.name, type=temp_column.type)
...@@ -81,7 +76,7 @@ class Prepare(object): ...@@ -81,7 +76,7 @@ class Prepare(object):
:param dimension: int64, vector dimension :param dimension: int64, vector dimension
:param index_type: IndexType :param index_type: IndexType
:param store_raw_vector: Bool, Is vector self stored in the table :param store_raw_vector: Bool
`Column`: `Column`:
:param name: Name of the column :param name: Name of the column
...@@ -124,8 +119,8 @@ class Prepare(object): ...@@ -124,8 +119,8 @@ class Prepare(object):
Name of the column Name of the column
- type: ColumnType, default=ColumnType.VECTOR, can't change - type: ColumnType, default=ColumnType.VECTOR, can't change
:param attribute_columns: List of Columns. Attribute :param attribute_columns: List of Columns. Attribute columns are Columns,
columns are Columns whose type aren't ColumnType.VECTOR whose types aren't ColumnType.VECTOR
`Column`: `Column`:
- name: str - name: str
...@@ -266,7 +261,7 @@ class MegaSearch(ConnectIntf): ...@@ -266,7 +261,7 @@ class MegaSearch(ConnectIntf):
transport = TSocket.TSocket(host=host, port=port) transport = TSocket.TSocket(host=host, port=port)
self.transport = TTransport.TBufferedTransport(transport) self.transport = TTransport.TBufferedTransport(transport)
protocol = TJSONProtocol.TJSONProtocol(transport) protocol = TBinaryProtocol.TBinaryProtocol(transport)
self.client = MegasearchService.Client(protocol) self.client = MegasearchService.Client(protocol)
try: try:
...@@ -312,8 +307,9 @@ class MegaSearch(ConnectIntf): ...@@ -312,8 +307,9 @@ class MegaSearch(ConnectIntf):
raise NotConnectError('Please Connect to the server first!') raise NotConnectError('Please Connect to the server first!')
try: try:
LOGGER.error(param)
self.client.CreateTable(param) self.client.CreateTable(param)
except (TApplicationException, TException) as e: except (TApplicationException, ) as e:
LOGGER.error('Unable to create table') LOGGER.error('Unable to create table')
return Status(Status.INVALID, str(e)) return Status(Status.INVALID, str(e))
return Status(message='Table {} created!'.format(param.table_name)) return Status(message='Table {} created!'.format(param.table_name))
...@@ -463,11 +459,10 @@ class MegaSearch(ConnectIntf): ...@@ -463,11 +459,10 @@ class MegaSearch(ConnectIntf):
# TODO How to get server version # TODO How to get server version
pass pass
def server_status(self, cmd): def server_status(self, cmd=None):
""" """
Provide server status Provide server status
:return: Server status :return: Server status
""" """
self.client.Ping(cmd) return self.client.Ping(cmd)
pass
from client.Client import MegaSearch, Prepare, IndexType, ColumnType from client.Client import MegaSearch, Prepare, IndexType, ColumnType
from client.Status import Status from client.Status import Status
import time
from megasearch.thrift import MegasearchService, ttypes
def main(): def main():
...@@ -13,31 +16,65 @@ def main(): ...@@ -13,31 +16,65 @@ def main():
is_connected = mega.connected is_connected = mega.connected
print('Connect status: {}'.format(is_connected)) print('Connect status: {}'.format(is_connected))
# # Create table with 1 vector column, 1 attribute column and 1 partition column # Create table with 1 vector column, 1 attribute column and 1 partition column
# # 1. prepare table_schema # 1. prepare table_schema
# vector_column = {
# 'name': 'fake_vec_name01', # table_schema = Prepare.table_schema(
# 'store_raw_vector': True, # table_name='fake_table_name' + time.strftime('%H%M%S'),
# 'dimension': 10
# }
# attribute_column = {
# 'name': 'fake_attri_name01',
# 'type': ColumnType.DATE,
# }
# #
# table = { # vector_columns=[Prepare.vector_column(
# 'table_name': 'fake_table_name01', # name='fake_vector_name' + time.strftime('%H%M%S'),
# 'vector_columns': [Prepare.vector_column(**vector_column)], # store_raw_vector=False,
# 'attribute_columns': [Prepare.column(**attribute_column)], # dimension=256)],
# 'partition_column_names': ['fake_attri_name01']
# }
# table_schema = Prepare.table_schema(**table)
# #
# # 2. Create Table # attribute_columns=[],
# create_status = mega.create_table(table_schema) #
# print('Create table status: {}'.format(create_status)) # partition_column_names=[]
# )
# get server version
print(mega.server_status('version'))
# show tables and their description
statu, tables = mega.show_tables()
print(tables)
for table in tables:
s,t = mega.describe_table(table)
print('table: {}'.format(t))
# Create table
# 1. create table schema
table_schema_full = MegasearchService.TableSchema(
table_name='fake' + time.strftime('%H%M%S'),
vector_column_array=[MegasearchService.VectorColumn(
base=MegasearchService.Column(
name='111',
type=ttypes.TType.I32
),
dimension=256,
)],
attribute_column_array=[],
partition_column_name_array=None
)
table_schema_empty = MegasearchService.TableSchema(
table_name='fake' + time.strftime('%H%M%S'),
vector_column_array=[MegasearchService.VectorColumn()],
attribute_column_array=[],
partition_column_name_array=None
)
# 2. Create Table
create_status = mega.create_table(table_schema_full)
print('Create table status: {}'.format(create_status))
mega.server_status('ok!') # add_vector
# Disconnect # Disconnect
discnn_status = mega.disconnect() discnn_status = mega.disconnect()
......
...@@ -14,6 +14,7 @@ from client.Exceptions import ( ...@@ -14,6 +14,7 @@ from client.Exceptions import (
from thrift.transport.TSocket import TSocket from thrift.transport.TSocket import TSocket
from megasearch.thrift import ttypes, MegasearchService from megasearch.thrift import ttypes, MegasearchService
from thrift.transport.TTransport import TTransportException
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
...@@ -38,7 +39,6 @@ def vector_column_factory(): ...@@ -38,7 +39,6 @@ def vector_column_factory():
return { return {
'name': fake.name(), 'name': fake.name(),
'dimension': fake.dim(), 'dimension': fake.dim(),
'index_type': IndexType.IVFFLAT,
'store_raw_vector': True 'store_raw_vector': True
} }
...@@ -46,7 +46,7 @@ def vector_column_factory(): ...@@ -46,7 +46,7 @@ def vector_column_factory():
def column_factory(): def column_factory():
return { return {
'name': fake.table_name(), 'name': fake.table_name(),
'type': IndexType.RAW 'type': ColumnType.INT32
} }
...@@ -146,9 +146,10 @@ class TestTable: ...@@ -146,9 +146,10 @@ class TestTable:
def test_false_create_table(self, client): def test_false_create_table(self, client):
param = table_schema_factory() param = table_schema_factory()
res = client.create_table(param) with pytest.raises(TTransportException):
LOGGER.error('{}'.format(res)) res = client.create_table(param)
assert res != Status.OK LOGGER.error('{}'.format(res))
assert res != Status.OK
@mock.patch.object(MegasearchService.Client, 'DeleteTable') @mock.patch.object(MegasearchService.Client, 'DeleteTable')
def test_delete_table(self, DeleteTable, client): def test_delete_table(self, DeleteTable, client):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册