Commit b47addca authored by: peng.xu

Merge from xupeng branch 'to_merge' into 0.6.0

@@ -26,3 +26,6 @@ cmake_build
*.lo
*.tar.gz
*.log
.coverage
*.pyc
cov_html/
FROM python:3.6
RUN apt update && apt install -y \
less \
telnet
RUN mkdir /source
WORKDIR /source
COPY requirements.txt ./
RUN pip install -r requirements.txt
COPY . .
CMD python mishards/main.py
#!/bin/bash
BOLD=`tput bold`
NORMAL=`tput sgr0`
YELLOW='\033[1;33m'
ENDC='\033[0m'
echo -e "${BOLD}MISHARDS_REGISTRY=${MISHARDS_REGISTRY}${ENDC}"
function build_image() {
    dockerfile=$1
    tagged=$2
    buildcmd="docker build -t ${tagged} -f ${dockerfile} ."
    echo -e "${BOLD}$buildcmd${NORMAL}"
    $buildcmd
    pushcmd="docker push ${tagged}"
    echo -e "${BOLD}$pushcmd${NORMAL}"
    $pushcmd
    echo -e "${YELLOW}${BOLD}Image: ${tagged}${ENDC}"
}
case "$1" in
all)
[[ -z $MISHARDS_REGISTRY ]] && {
echo -e "${YELLOW}Error: Please set docker registry first:${ENDC}\n\t${BOLD}export MISHARDS_REGISTRY=xxxx\n${ENDC}"
exit 1
}
version=""
[[ ! -z $2 ]] && version=":${2}"
build_image "Dockerfile" "${MISHARDS_REGISTRY}${version}" "${MISHARDS_REGISTRY}"
;;
*)
echo "Usage: [option...] {base | apps}"
echo "all, Usage: build.sh all [tagname|] => {docker_registry}:\${tagname}"
;;
esac
import logging
import pytest
import grpc
from mishards import settings, db, create_app
logger = logging.getLogger(__name__)
@pytest.fixture
def app(request):
app = create_app(settings.TestingConfig)
db.drop_all()
db.create_all()
yield app
db.drop_all()
@pytest.fixture
def started_app(app):
app.on_pre_run()
app.start(settings.SERVER_TEST_PORT)
yield app
app.stop()
import fire
from mishards import db
from sqlalchemy import and_
class DBHandler:
@classmethod
def create_all(cls):
db.create_all()
@classmethod
def drop_all(cls):
db.drop_all()
@classmethod
def fun(cls, tid):
from mishards.factories import TablesFactory, TableFilesFactory, Tables
f = db.Session.query(Tables).filter(and_(
Tables.table_id == tid,
Tables.state != Tables.TO_DELETE)
).first()
print(f)
# f1 = TableFilesFactory()
if __name__ == '__main__':
fire.Fire(DBHandler)
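# Usage note (illustrative): fire.Fire(DBHandler) exposes the classmethods
# above as subcommands. Assuming this script were saved as manager.py (the
# filename is not shown in this diff), it would be invoked as:
#
#   python manager.py create_all
#   python manager.py drop_all
#   python manager.py fun <table_id>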
DEBUG=True
WOSERVER=tcp://127.0.0.1:19530
SERVER_PORT=19532
SERVER_TEST_PORT=19888
SD_PROVIDER=Static
SD_NAMESPACE=xp
SD_IN_CLUSTER=False
SD_POLL_INTERVAL=5
SD_ROSERVER_POD_PATT=.*-ro-servers-.*
SD_LABEL_SELECTOR=tier=ro-servers
SD_STATIC_HOSTS=127.0.0.1
SD_STATIC_PORT=19530
#SQLALCHEMY_DATABASE_URI=mysql+pymysql://root:root@127.0.0.1:3306/milvus?charset=utf8mb4
SQLALCHEMY_DATABASE_URI=sqlite:////tmp/milvus/db/meta.sqlite?check_same_thread=False
SQL_ECHO=True
#SQLALCHEMY_DATABASE_TEST_URI=mysql+pymysql://root:root@127.0.0.1:3306/milvus?charset=utf8mb4
SQLALCHEMY_DATABASE_TEST_URI=sqlite:////tmp/milvus/db/meta.sqlite?check_same_thread=False
SQL_TEST_ECHO=False
# TRACING_TEST_TYPE=jaeger
TRACING_TYPE=jaeger
TRACING_SERVICE_NAME=fortest
TRACING_SAMPLER_TYPE=const
TRACING_SAMPLER_PARAM=1
TRACING_LOG_PAYLOAD=True
#TRACING_SAMPLER_TYPE=probabilistic
#TRACING_SAMPLER_PARAM=0.5
import logging
from mishards import settings
logger = logging.getLogger()
from mishards.db_base import DB
db = DB()
from mishards.server import Server
grpc_server = Server()
def create_app(testing_config=None):
config = testing_config if testing_config else settings.DefaultConfig
db.init_db(uri=config.SQLALCHEMY_DATABASE_URI, echo=config.SQL_ECHO)
from mishards.connections import ConnectionMgr
connect_mgr = ConnectionMgr()
from sd import ProviderManager
    sd_provider_class = ProviderManager.get_provider(settings.SD_PROVIDER)
    discover = sd_provider_class(settings=settings.SD_PROVIDER_SETTINGS, conn_mgr=connect_mgr)
from tracing.factory import TracerFactory
from mishards.grpc_utils import GrpcSpanDecorator
tracer = TracerFactory.new_tracer(config.TRACING_TYPE, settings.TracingConfig,
span_decorator=GrpcSpanDecorator())
from mishards.routings import RouterFactory
router = RouterFactory.new_router(config.ROUTER_CLASS_NAME, connect_mgr)
grpc_server.init_app(conn_mgr=connect_mgr, tracer=tracer, router=router, discover=discover)
from mishards import exception_handlers
return grpc_server
import logging
import threading
from functools import wraps
from milvus import Milvus
from mishards import (settings, exceptions)
from utils import singleton
logger = logging.getLogger(__name__)
class Connection:
def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs):
self.name = name
self.uri = uri
self.max_retry = max_retry
self.retried = 0
self.conn = Milvus()
self.error_handlers = [] if not error_handlers else error_handlers
self.on_retry_func = kwargs.get('on_retry_func', None)
# self._connect()
def __str__(self):
return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri)
def _connect(self, metadata=None):
try:
self.conn.connect(uri=self.uri)
except Exception as e:
if not self.error_handlers:
raise exceptions.ConnectionConnectError(message=str(e), metadata=metadata)
for handler in self.error_handlers:
handler(e, metadata=metadata)
@property
def can_retry(self):
return self.retried < self.max_retry
@property
def connected(self):
return self.conn.connected()
def on_retry(self):
if self.on_retry_func:
self.on_retry_func(self)
else:
            if self.retried > 1:
                logger.warning('{} is retrying {}'.format(self, self.retried))
def on_connect(self, metadata=None):
while not self.connected and self.can_retry:
self.retried += 1
self.on_retry()
self._connect(metadata=metadata)
if not self.can_retry and not self.connected:
            raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry),
                                                    metadata=metadata)
self.retried = 0
def connect(self, func, exception_handler=None):
@wraps(func)
def inner(*args, **kwargs):
self.on_connect()
try:
return func(*args, **kwargs)
except Exception as e:
if exception_handler:
exception_handler(e)
else:
raise e
return inner
@singleton
class ConnectionMgr:
def __init__(self):
self.metas = {}
self.conns = {}
@property
def conn_names(self):
return set(self.metas.keys()) - set(['WOSERVER'])
def conn(self, name, metadata, throw=False):
c = self.conns.get(name, None)
if not c:
url = self.metas.get(name, None)
if not url:
if not throw:
return None
raise exceptions.ConnectionNotFoundError(message='Connection {} not found'.format(name),
metadata=metadata)
this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
threaded = {
threading.get_ident(): this_conn
}
self.conns[name] = threaded
return this_conn
tid = threading.get_ident()
rconn = c.get(tid, None)
if not rconn:
url = self.metas.get(name, None)
if not url:
if not throw:
return None
raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name),
metadata=metadata)
this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
c[tid] = this_conn
return this_conn
return rconn
def on_new_meta(self, name, url):
logger.info('Register Connection: name={};url={}'.format(name, url))
self.metas[name] = url
def on_duplicate_meta(self, name, url):
if self.metas[name] == url:
return self.on_same_meta(name, url)
return self.on_diff_meta(name, url)
def on_same_meta(self, name, url):
# logger.warning('Register same meta: {}:{}'.format(name, url))
pass
def on_diff_meta(self, name, url):
logger.warning('Received {} with diff url={}'.format(name, url))
self.metas[name] = url
self.conns[name] = {}
def on_unregister_meta(self, name, url):
logger.info('Unregister name={};url={}'.format(name, url))
self.conns.pop(name, None)
def on_nonexisted_meta(self, name):
        logger.warning('Non-existent meta: {}'.format(name))
def register(self, name, url):
meta = self.metas.get(name)
if not meta:
return self.on_new_meta(name, url)
else:
return self.on_duplicate_meta(name, url)
def unregister(self, name):
logger.info('Unregister Connection: name={}'.format(name))
url = self.metas.pop(name, None)
if url is None:
return self.on_nonexisted_meta(name)
return self.on_unregister_meta(name, url)
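# Illustrative walkthrough of the registration semantics above; the names and
# URIs are made up:
#
#   mgr = ConnectionMgr()
#   mgr.register('pod1', 'tcp://127.0.0.1:19530')   # on_new_meta: meta recorded
#   mgr.register('pod1', 'tcp://127.0.0.1:19530')   # on_same_meta: no-op
#   mgr.register('pod1', 'tcp://127.0.0.1:19531')   # on_diff_meta: url replaced, conns reset
#   conn = mgr.conn('pod1', metadata=None)          # lazily builds a per-thread Connection
#   mgr.unregister('pod1')                          # on_unregister_meta: conns dropped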
import logging
from sqlalchemy import create_engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.orm.session import Session as SessionBase
logger = logging.getLogger(__name__)
class LocalSession(SessionBase):
def __init__(self, db, autocommit=False, autoflush=True, **options):
self.db = db
bind = options.pop('bind', None) or db.engine
SessionBase.__init__(self, autocommit=autocommit, autoflush=autoflush, bind=bind, **options)
class DB:
Model = declarative_base()
def __init__(self, uri=None, echo=False):
self.echo = echo
        if uri:
            self.init_db(uri, echo)
self.session_factory = scoped_session(sessionmaker(class_=LocalSession, db=self))
def init_db(self, uri, echo=False):
url = make_url(uri)
if url.get_backend_name() == 'sqlite':
self.engine = create_engine(url)
else:
self.engine = create_engine(uri, pool_size=100, pool_recycle=5, pool_timeout=30,
pool_pre_ping=True,
echo=echo,
max_overflow=0)
self.uri = uri
self.url = url
def __str__(self):
return '<DB: backend={};database={}>'.format(self.url.get_backend_name(), self.url.database)
@property
def Session(self):
return self.session_factory()
def remove_session(self):
self.session_factory.remove()
def drop_all(self):
self.Model.metadata.drop_all(self.engine)
def create_all(self):
self.Model.metadata.create_all(self.engine)
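# Minimal standalone sketch of the DB wrapper above (the in-memory SQLite URI
# here is an assumption for illustration):
#
#   db = DB(uri='sqlite://')
#   db.create_all()                 # creates tables for all registered models
#   session = db.Session            # scoped_session: one LocalSession per thread
#   ...
#   db.remove_session()             # release the thread-local session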
INVALID_CODE = -1
CONNECT_ERROR_CODE = 10001
CONNECTION_NOT_FOUND_CODE = 10002
DB_ERROR_CODE = 10003
TABLE_NOT_FOUND_CODE = 20001
INVALID_ARGUMENT_CODE = 20002
INVALID_DATE_RANGE_CODE = 20003
INVALID_TOPK_CODE = 20004
import logging
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from mishards import grpc_server as server, exceptions
logger = logging.getLogger(__name__)
def resp_handler(err, error_code):
if not isinstance(err, exceptions.BaseException):
return status_pb2.Status(error_code=error_code, reason=str(err))
status = status_pb2.Status(error_code=error_code, reason=err.message)
if err.metadata is None:
return status
resp_class = err.metadata.get('resp_class', None)
if not resp_class:
return status
if resp_class == milvus_pb2.BoolReply:
return resp_class(status=status, bool_reply=False)
if resp_class == milvus_pb2.VectorIds:
return resp_class(status=status, vector_id_array=[])
if resp_class == milvus_pb2.TopKQueryResultList:
return resp_class(status=status, topk_query_result=[])
if resp_class == milvus_pb2.TableRowCount:
return resp_class(status=status, table_row_count=-1)
if resp_class == milvus_pb2.TableName:
        return resp_class(status=status, table_name='')
if resp_class == milvus_pb2.StringReply:
return resp_class(status=status, string_reply='')
if resp_class == milvus_pb2.TableSchema:
return milvus_pb2.TableSchema(
status=status
)
if resp_class == milvus_pb2.IndexParam:
return milvus_pb2.IndexParam(
table_name=milvus_pb2.TableName(
status=status
)
)
status.error_code = status_pb2.UNEXPECTED_ERROR
return status
@server.errorhandler(exceptions.TableNotFoundError)
def TableNotFoundErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.TABLE_NOT_EXISTS)
@server.errorhandler(exceptions.InvalidTopKError)
def InvalidTopKErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.ILLEGAL_TOPK)
@server.errorhandler(exceptions.InvalidArgumentError)
def InvalidArgumentErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.ILLEGAL_ARGUMENT)
@server.errorhandler(exceptions.DBError)
def DBErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.UNEXPECTED_ERROR)
@server.errorhandler(exceptions.InvalidRangeError)
def InvalidRangeErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.ILLEGAL_RANGE)
import mishards.exception_codes as codes
class BaseException(Exception):
code = codes.INVALID_CODE
message = 'BaseException'
def __init__(self, message='', metadata=None):
self.message = self.__class__.__name__ if not message else message
self.metadata = metadata
class ConnectionConnectError(BaseException):
code = codes.CONNECT_ERROR_CODE
class ConnectionNotFoundError(BaseException):
    code = codes.CONNECTION_NOT_FOUND_CODE
class DBError(BaseException):
code = codes.DB_ERROR_CODE
class TableNotFoundError(BaseException):
code = codes.TABLE_NOT_FOUND_CODE
class InvalidTopKError(BaseException):
code = codes.INVALID_TOPK_CODE
class InvalidArgumentError(BaseException):
code = codes.INVALID_ARGUMENT_CODE
class InvalidRangeError(BaseException):
code = codes.INVALID_DATE_RANGE_CODE
import time
import datetime
import random
import factory
from factory.alchemy import SQLAlchemyModelFactory
from faker import Faker
from faker.providers import BaseProvider
from milvus.client.types import MetricType
from mishards import db
from mishards.models import Tables, TableFiles
class FakerProvider(BaseProvider):
def this_date(self):
t = datetime.datetime.today()
return (t.year - 1900) * 10000 + (t.month - 1) * 100 + t.day
factory.Faker.add_provider(FakerProvider)
class TablesFactory(SQLAlchemyModelFactory):
class Meta:
model = Tables
sqlalchemy_session = db.session_factory
sqlalchemy_session_persistence = 'commit'
id = factory.Faker('random_number', digits=16, fix_len=True)
table_id = factory.Faker('uuid4')
state = factory.Faker('random_element', elements=(0, 1))
dimension = factory.Faker('random_element', elements=(256, 512))
created_on = int(time.time())
index_file_size = 0
engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3))
metric_type = factory.Faker('random_element', elements=(MetricType.L2, MetricType.IP))
nlist = 16384
class TableFilesFactory(SQLAlchemyModelFactory):
class Meta:
model = TableFiles
sqlalchemy_session = db.session_factory
sqlalchemy_session_persistence = 'commit'
id = factory.Faker('random_number', digits=16, fix_len=True)
table = factory.SubFactory(TablesFactory)
engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3))
file_id = factory.Faker('uuid4')
file_type = factory.Faker('random_element', elements=(0, 1, 2, 3, 4))
file_size = factory.Faker('random_number')
updated_time = int(time.time())
created_on = int(time.time())
date = factory.Faker('this_date')
from grpc_opentracing import SpanDecorator
from milvus.grpc_gen import status_pb2
class GrpcSpanDecorator(SpanDecorator):
def __call__(self, span, rpc_info):
status = None
if not rpc_info.response:
return
if isinstance(rpc_info.response, status_pb2.Status):
status = rpc_info.response
else:
try:
status = rpc_info.response.status
except Exception as e:
status = status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR,
reason='Should not happen')
if status.error_code == 0:
return
error_log = {'event': 'error',
'request': rpc_info.request,
'response': rpc_info.response
}
span.set_tag('error', True)
span.log_kv(error_log)
def mark_grpc_method(func):
setattr(func, 'grpc_method', True)
return func
def is_grpc_method(func):
if not func:
return False
return getattr(func, 'grpc_method', False)
from milvus import Status
from functools import wraps
def error_status(func):
@wraps(func)
def inner(*args, **kwargs):
try:
results = func(*args, **kwargs)
except Exception as e:
return Status(code=Status.UNEXPECTED_ERROR, message=str(e)), None
return Status(code=0, message="Success"), results
return inner
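# Illustration of the (Status, result) contract that error_status enforces;
# the decorated functions here are hypothetical:
#
#   @error_status
#   def parse_ok():
#       return 42
#
#   @error_status
#   def parse_boom():
#       raise ValueError('bad field')
#
#   parse_ok()    # -> (Status(code=0, "Success"), 42)
#   parse_boom()  # -> (Status(UNEXPECTED_ERROR, 'bad field'), None)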
class GrpcArgsParser(object):
@classmethod
@error_status
def parse_proto_TableSchema(cls, param):
_table_schema = {
'status': param.status,
'table_name': param.table_name,
'dimension': param.dimension,
'index_file_size': param.index_file_size,
'metric_type': param.metric_type
}
return _table_schema
@classmethod
@error_status
def parse_proto_TableName(cls, param):
return param.table_name
@classmethod
@error_status
def parse_proto_Index(cls, param):
_index = {
'index_type': param.index_type,
'nlist': param.nlist
}
return _index
@classmethod
@error_status
def parse_proto_IndexParam(cls, param):
_table_name = param.table_name
_status, _index = cls.parse_proto_Index(param.index)
if not _status.OK():
raise Exception("Argument parse error")
return _table_name, _index
@classmethod
@error_status
def parse_proto_Command(cls, param):
_cmd = param.cmd
return _cmd
@classmethod
@error_status
def parse_proto_Range(cls, param):
_start_value = param.start_value
_end_value = param.end_value
return _start_value, _end_value
@classmethod
@error_status
def parse_proto_RowRecord(cls, param):
return list(param.vector_data)
@classmethod
@error_status
def parse_proto_SearchParam(cls, param):
_table_name = param.table_name
_topk = param.topk
_nprobe = param.nprobe
_status, _range = cls.parse_proto_Range(param.query_range_array)
if not _status.OK():
raise Exception("Argument parse error")
_row_record = param.query_record_array
return _table_name, _row_record, _range, _topk
@classmethod
@error_status
def parse_proto_DeleteByRangeParam(cls, param):
_table_name = param.table_name
_range = param.range
_start_value = _range.start_value
_end_value = _range.end_value
return _table_name, _start_value, _end_value
# class GrpcArgsWrapper(object):
# @classmethod
# def proto_TableName(cls):
import logging
import opentracing
from mishards.grpc_utils import GrpcSpanDecorator, is_grpc_method
from milvus.grpc_gen import status_pb2, milvus_pb2
logger = logging.getLogger(__name__)
class FakeTracer(opentracing.Tracer):
pass
class FakeSpan(opentracing.Span):
def __init__(self, context, tracer, **kwargs):
super(FakeSpan, self).__init__(tracer, context)
self.reset()
def set_tag(self, key, value):
self.tags.append({key: value})
def log_kv(self, key_values, timestamp=None):
self.logs.append(key_values)
def reset(self):
self.tags = []
self.logs = []
class FakeRpcInfo:
def __init__(self, request, response):
self.request = request
self.response = response
class TestGrpcUtils:
def test_span_deco(self):
request = 'request'
OK = status_pb2.Status(error_code=status_pb2.SUCCESS, reason='Success')
response = OK
rpc_info = FakeRpcInfo(request=request, response=response)
span = FakeSpan(context=None, tracer=FakeTracer())
span_deco = GrpcSpanDecorator()
span_deco(span, rpc_info)
assert len(span.logs) == 0
assert len(span.tags) == 0
response = milvus_pb2.BoolReply(status=OK, bool_reply=False)
rpc_info = FakeRpcInfo(request=request, response=response)
span = FakeSpan(context=None, tracer=FakeTracer())
span_deco = GrpcSpanDecorator()
span_deco(span, rpc_info)
assert len(span.logs) == 0
assert len(span.tags) == 0
response = 1
rpc_info = FakeRpcInfo(request=request, response=response)
span = FakeSpan(context=None, tracer=FakeTracer())
span_deco = GrpcSpanDecorator()
span_deco(span, rpc_info)
assert len(span.logs) == 1
assert len(span.tags) == 1
response = 0
rpc_info = FakeRpcInfo(request=request, response=response)
span = FakeSpan(context=None, tracer=FakeTracer())
span_deco = GrpcSpanDecorator()
span_deco(span, rpc_info)
assert len(span.logs) == 0
assert len(span.tags) == 0
def test_is_grpc_method(self):
target = 1
assert not is_grpc_method(target)
target = None
assert not is_grpc_method(target)
import math
from bisect import bisect
import hashlib

md5_constructor = hashlib.md5
class HashRing(object):
def __init__(self, nodes=None, weights=None):
"""`nodes` is a list of objects that have a proper __str__ representation.
`weights` is dictionary that sets weights to the nodes. The default
weight is that all nodes are equal.
"""
self.ring = dict()
self._sorted_keys = []
self.nodes = nodes
if not weights:
weights = {}
self.weights = weights
self._generate_circle()
def _generate_circle(self):
"""Generates the circle.
"""
total_weight = 0
for node in self.nodes:
total_weight += self.weights.get(node, 1)
for node in self.nodes:
weight = 1
if node in self.weights:
weight = self.weights.get(node)
factor = math.floor((40 * len(self.nodes) * weight) / total_weight)
for j in range(0, int(factor)):
b_key = self._hash_digest('%s-%s' % (node, j))
for i in range(0, 3):
key = self._hash_val(b_key, lambda x: x + i * 4)
self.ring[key] = node
self._sorted_keys.append(key)
self._sorted_keys.sort()
def get_node(self, string_key):
"""Given a string key a corresponding node in the hash ring is returned.
If the hash ring is empty, `None` is returned.
"""
pos = self.get_node_pos(string_key)
if pos is None:
return None
return self.ring[self._sorted_keys[pos]]
def get_node_pos(self, string_key):
"""Given a string key a corresponding node in the hash ring is returned
along with it's position in the ring.
If the hash ring is empty, (`None`, `None`) is returned.
"""
if not self.ring:
return None
key = self.gen_key(string_key)
nodes = self._sorted_keys
pos = bisect(nodes, key)
if pos == len(nodes):
return 0
else:
return pos
def iterate_nodes(self, string_key, distinct=True):
"""Given a string key it returns the nodes as a generator that can hold the key.
The generator iterates one time through the ring
starting at the correct position.
if `distinct` is set, then the nodes returned will be unique,
i.e. no virtual copies will be returned.
"""
if not self.ring:
yield None, None
returned_values = set()
def distinct_filter(value):
if str(value) not in returned_values:
returned_values.add(str(value))
return value
pos = self.get_node_pos(string_key)
for key in self._sorted_keys[pos:]:
val = distinct_filter(self.ring[key])
if val:
yield val
for i, key in enumerate(self._sorted_keys):
if i < pos:
val = distinct_filter(self.ring[key])
if val:
yield val
def gen_key(self, key):
"""Given a string key it returns a long value,
this long value represents a place on the hash ring.
md5 is currently used because it mixes well.
"""
b_key = self._hash_digest(key)
return self._hash_val(b_key, lambda x: x)
def _hash_val(self, b_key, entry_fn):
return (b_key[entry_fn(3)] << 24) | (b_key[entry_fn(2)] << 16) | (
b_key[entry_fn(1)] << 8) | b_key[entry_fn(0)]
def _hash_digest(self, key):
m = md5_constructor()
key = key.encode()
m.update(key)
return m.digest()
if __name__ == '__main__':
from collections import defaultdict
servers = [
'192.168.0.246:11212', '192.168.0.247:11212', '192.168.0.248:11212',
'192.168.0.249:11212'
]
ring = HashRing(servers)
keys = ['{}'.format(i) for i in range(100)]
mapped = defaultdict(list)
for k in keys:
server = ring.get_node(k)
mapped[server].append(k)
for k, v in mapped.items():
print(k, v)
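    # The weights path can be exercised the same way; a hedged sketch: giving
    # one server weight 2 roughly doubles its share of virtual points on the
    # ring, since factor = floor((40 * len(nodes) * weight) / total_weight).
    #
    #   weighted = HashRing(servers, weights={'192.168.0.246:11212': 2})
    #   weighted.get_node('some-key')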
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from mishards import (settings, create_app)
def main():
server = create_app(settings.DefaultConfig)
server.run(port=settings.SERVER_PORT)
return 0
if __name__ == '__main__':
sys.exit(main())
import logging
from sqlalchemy import (Integer, Boolean, Text,
String, BigInteger, and_, or_,
Column)
from sqlalchemy.orm import relationship, backref
from mishards import db
logger = logging.getLogger(__name__)
class TableFiles(db.Model):
FILE_TYPE_NEW = 0
FILE_TYPE_RAW = 1
FILE_TYPE_TO_INDEX = 2
FILE_TYPE_INDEX = 3
FILE_TYPE_TO_DELETE = 4
FILE_TYPE_NEW_MERGE = 5
FILE_TYPE_NEW_INDEX = 6
FILE_TYPE_BACKUP = 7
__tablename__ = 'TableFiles'
id = Column(BigInteger, primary_key=True, autoincrement=True)
table_id = Column(String(50))
engine_type = Column(Integer)
file_id = Column(String(50))
file_type = Column(Integer)
file_size = Column(Integer, default=0)
row_count = Column(Integer, default=0)
updated_time = Column(BigInteger)
created_on = Column(BigInteger)
date = Column(Integer)
table = relationship(
'Tables',
primaryjoin='and_(foreign(TableFiles.table_id) == Tables.table_id)',
backref=backref('files', uselist=True, lazy='dynamic')
)
class Tables(db.Model):
TO_DELETE = 1
NORMAL = 0
__tablename__ = 'Tables'
id = Column(BigInteger, primary_key=True, autoincrement=True)
table_id = Column(String(50), unique=True)
state = Column(Integer)
dimension = Column(Integer)
created_on = Column(Integer)
flag = Column(Integer, default=0)
index_file_size = Column(Integer)
engine_type = Column(Integer)
nlist = Column(Integer)
metric_type = Column(Integer)
def files_to_search(self, date_range=None):
cond = or_(
TableFiles.file_type == TableFiles.FILE_TYPE_RAW,
TableFiles.file_type == TableFiles.FILE_TYPE_TO_INDEX,
TableFiles.file_type == TableFiles.FILE_TYPE_INDEX,
)
if date_range:
cond = and_(
cond,
or_(
and_(TableFiles.date >= d[0], TableFiles.date < d[1]) for d in date_range
)
)
files = self.files.filter(cond)
logger.debug('DATE_RANGE: {}'.format(date_range))
return files
import logging
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy import and_
from mishards import exceptions, db
from mishards.hash_ring import HashRing
from mishards.models import Tables
logger = logging.getLogger(__name__)
class RouteManager:
ROUTER_CLASSES = {}
@classmethod
def register_router_class(cls, target):
name = target.__dict__.get('NAME', None)
        name = name if name else target.__name__
cls.ROUTER_CLASSES[name] = target
return target
@classmethod
def get_router_class(cls, name):
return cls.ROUTER_CLASSES.get(name, None)
class RouterFactory:
@classmethod
def new_router(cls, name, conn_mgr, **kwargs):
router_class = RouteManager.get_router_class(name)
assert router_class
return router_class(conn_mgr, **kwargs)
class RouterMixin:
def __init__(self, conn_mgr):
self.conn_mgr = conn_mgr
def routing(self, table_name, metadata=None, **kwargs):
        raise NotImplementedError()
def connection(self, metadata=None):
conn = self.conn_mgr.conn('WOSERVER', metadata=metadata)
if conn:
conn.on_connect(metadata=metadata)
return conn.conn
def query_conn(self, name, metadata=None):
conn = self.conn_mgr.conn(name, metadata=metadata)
if not conn:
raise exceptions.ConnectionNotFoundError(name, metadata=metadata)
conn.on_connect(metadata=metadata)
return conn.conn
@RouteManager.register_router_class
class FileBasedHashRingRouter(RouterMixin):
NAME = 'FileBasedHashRingRouter'
def __init__(self, conn_mgr, **kwargs):
super(FileBasedHashRingRouter, self).__init__(conn_mgr)
def routing(self, table_name, metadata=None, **kwargs):
range_array = kwargs.pop('range_array', None)
return self._route(table_name, range_array, metadata, **kwargs)
def _route(self, table_name, range_array, metadata=None, **kwargs):
# PXU TODO: Implement Thread-local Context
# PXU TODO: Session life mgt
try:
table = db.Session.query(Tables).filter(
and_(Tables.table_id == table_name,
Tables.state != Tables.TO_DELETE)).first()
except sqlalchemy_exc.SQLAlchemyError as e:
raise exceptions.DBError(message=str(e), metadata=metadata)
if not table:
raise exceptions.TableNotFoundError(table_name, metadata=metadata)
files = table.files_to_search(range_array)
db.remove_session()
servers = self.conn_mgr.conn_names
logger.info('Available servers: {}'.format(servers))
ring = HashRing(servers)
routing = {}
for f in files:
target_host = ring.get_node(str(f.id))
sub = routing.get(target_host, None)
if not sub:
routing[target_host] = {'table_id': table_name, 'file_ids': []}
routing[target_host]['file_ids'].append(str(f.id))
return routing
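# Shape of the routing result returned above (addresses and ids are made up):
#
#   {
#       '192.168.1.10:19530': {'table_id': 'tbl', 'file_ids': ['1', '7']},
#       '192.168.1.11:19530': {'table_id': 'tbl', 'file_ids': ['3']},
#   }
#
# Each file id is hashed onto the ring, so a given file is always served by
# the same host as long as the set of available servers is unchanged.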
import logging
import grpc
import time
import socket
import inspect
from urllib.parse import urlparse
from functools import wraps
from concurrent import futures
from grpc._cython import cygrpc
from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server
from mishards.grpc_utils import is_grpc_method
from mishards.service_handler import ServiceHandler
from mishards import settings
logger = logging.getLogger(__name__)
class Server:
def __init__(self):
self.pre_run_handlers = set()
self.grpc_methods = set()
self.error_handlers = {}
self.exit_flag = False
def init_app(self,
conn_mgr,
tracer,
router,
discover,
port=19530,
max_workers=10,
**kwargs):
self.port = int(port)
self.conn_mgr = conn_mgr
self.tracer = tracer
self.router = router
self.discover = discover
self.server_impl = grpc.server(
thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
(cygrpc.ChannelArgKey.max_receive_message_length, -1)])
self.server_impl = self.tracer.decorate(self.server_impl)
self.register_pre_run_handler(self.pre_run_handler)
def pre_run_handler(self):
woserver = settings.WOSERVER
url = urlparse(woserver)
ip = socket.gethostbyname(url.hostname)
socket.inet_pton(socket.AF_INET, ip)
self.conn_mgr.register(
'WOSERVER', '{}://{}:{}'.format(url.scheme, ip, url.port or 80))
def register_pre_run_handler(self, func):
        logger.info('Registering {} into server pre_run_handlers'.format(func))
self.pre_run_handlers.add(func)
return func
def wrap_method_with_errorhandler(self, func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
if e.__class__ in self.error_handlers:
return self.error_handlers[e.__class__](e)
raise
return wrapper
def errorhandler(self, exception):
if inspect.isclass(exception) and issubclass(exception, Exception):
def wrapper(func):
self.error_handlers[exception] = func
return func
return wrapper
return exception
def on_pre_run(self):
for handler in self.pre_run_handlers:
handler()
self.discover.start()
def start(self, port=None):
handler_class = self.decorate_handler(ServiceHandler)
add_MilvusServiceServicer_to_server(
handler_class(tracer=self.tracer,
router=self.router), self.server_impl)
self.server_impl.add_insecure_port("[::]:{}".format(
str(port or self.port)))
self.server_impl.start()
def run(self, port):
        logger.info('Milvus server starting ......')
port = port or self.port
self.on_pre_run()
self.start(port)
logger.info('Listening on port {}'.format(port))
try:
while not self.exit_flag:
time.sleep(5)
except KeyboardInterrupt:
self.stop()
def stop(self):
        logger.info('Server is shutting down ......')
self.exit_flag = True
self.server_impl.stop(0)
self.tracer.close()
logger.info('Server is closed')
def decorate_handler(self, handler):
for key, attr in handler.__dict__.items():
if is_grpc_method(attr):
setattr(handler, key, self.wrap_method_with_errorhandler(attr))
return handler
import logging
import time
import datetime
from collections import defaultdict
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from milvus.grpc_gen.milvus_pb2 import TopKQueryResult
from milvus.client.abstract import Range
from milvus.client import types as Types
from mishards import (db, settings, exceptions)
from mishards.grpc_utils import mark_grpc_method
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
from mishards import utilities
logger = logging.getLogger(__name__)
class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
MAX_NPROBE = 2048
MAX_TOPK = 2048
def __init__(self, tracer, router, max_workers=multiprocessing.cpu_count(), **kwargs):
self.table_meta = {}
self.error_handlers = {}
self.tracer = tracer
self.router = router
self.max_workers = max_workers
def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
status = status_pb2.Status(error_code=status_pb2.SUCCESS,
reason="Success")
if not files_n_topk_results:
return status, []
request_results = defaultdict(list)
calc_time = time.time()
for files_collection in files_n_topk_results:
if isinstance(files_collection, tuple):
status, _ = files_collection
return status, []
for request_pos, each_request_results in enumerate(
files_collection.topk_query_result):
request_results[request_pos].extend(
each_request_results.query_result_arrays)
request_results[request_pos] = sorted(
request_results[request_pos],
key=lambda x: x.distance,
reverse=reverse)[:topk]
calc_time = time.time() - calc_time
logger.info('Merge takes {}'.format(calc_time))
results = sorted(request_results.items())
topk_query_result = []
for result in results:
query_result = TopKQueryResult(query_result_arrays=result[1])
topk_query_result.append(query_result)
return status, topk_query_result
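    # Illustrative merge for a single request with topk=3 (distances made up):
    # server A returns [0.1, 0.4, 0.9] and server B returns [0.2, 0.3, 0.8];
    # after extend + sort (ascending for L2; reverse=True for IP) the merged
    # top-3 is [0.1, 0.2, 0.3].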
def _do_query(self,
context,
table_id,
table_meta,
vectors,
topk,
nprobe,
range_array=None,
**kwargs):
metadata = kwargs.get('metadata', None)
range_array = [
utilities.range_to_date(r, metadata=metadata) for r in range_array
] if range_array else None
routing = {}
        p_span = None if self.tracer.empty else context.get_active_span().context
with self.tracer.start_span('get_routing', child_of=p_span):
routing = self.router.routing(table_id,
range_array=range_array,
metadata=metadata)
logger.info('Routing: {}'.format(routing))
metadata = kwargs.get('metadata', None)
rs = []
all_topk_results = []
def search(addr, query_params, vectors, topk, nprobe, **kwargs):
logger.info(
'Send Search Request: addr={};params={};nq={};topk={};nprobe={}'
.format(addr, query_params, len(vectors), topk, nprobe))
conn = self.router.query_conn(addr, metadata=metadata)
start = time.time()
span = kwargs.get('span', None)
span = span if span else (None if self.tracer.empty else
context.get_active_span().context)
with self.tracer.start_span('search_{}'.format(addr),
child_of=span):
ret = conn.search_vectors_in_files(
table_name=query_params['table_id'],
file_ids=query_params['file_ids'],
query_records=vectors,
top_k=topk,
nprobe=nprobe,
lazy_=True)
end = time.time()
logger.info('search_vectors_in_files takes: {}'.format(end - start))
all_topk_results.append(ret)
with self.tracer.start_span('do_search', child_of=p_span) as span:
with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
for addr, params in routing.items():
res = pool.submit(search,
addr,
params,
vectors,
topk,
nprobe,
span=span)
rs.append(res)
for res in rs:
res.result()
reverse = table_meta.metric_type == Types.MetricType.IP
with self.tracer.start_span('do_merge', child_of=p_span):
return self._do_merge(all_topk_results,
topk,
reverse=reverse,
metadata=metadata)
def _create_table(self, table_schema):
return self.router.connection().create_table(table_schema)
@mark_grpc_method
def CreateTable(self, request, context):
_status, _table_schema = Parser.parse_proto_TableSchema(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
logger.info('CreateTable {}'.format(_table_schema['table_name']))
_status = self._create_table(_table_schema)
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
def _has_table(self, table_name, metadata=None):
return self.router.connection(metadata=metadata).has_table(table_name)
@mark_grpc_method
def HasTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return milvus_pb2.BoolReply(status=status_pb2.Status(
error_code=_status.code, reason=_status.message),
bool_reply=False)
logger.info('HasTable {}'.format(_table_name))
_status, _bool = self._has_table(_table_name,
metadata={'resp_class': milvus_pb2.BoolReply})
return milvus_pb2.BoolReply(status=status_pb2.Status(
error_code=_status.code, reason=_status.message),
bool_reply=_bool)
def _delete_table(self, table_name):
return self.router.connection().delete_table(table_name)
@mark_grpc_method
def DropTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
logger.info('DropTable {}'.format(_table_name))
_status = self._delete_table(_table_name)
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
def _create_index(self, table_name, index):
return self.router.connection().create_index(table_name, index)
@mark_grpc_method
def CreateIndex(self, request, context):
_status, unpacks = Parser.parse_proto_IndexParam(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
_table_name, _index = unpacks
logger.info('CreateIndex {}'.format(_table_name))
        # TODO: interface create_table incomplete
_status = self._create_index(_table_name, _index)
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
def _add_vectors(self, param, metadata=None):
return self.router.connection(metadata=metadata).add_vectors(
None, None, insert_param=param)
@mark_grpc_method
def Insert(self, request, context):
logger.info('Insert')
        # TODO: the SDK interface add_vectors() should be updated to accept a 'row_id_array' key
_status, _ids = self._add_vectors(
metadata={'resp_class': milvus_pb2.VectorIds}, param=request)
return milvus_pb2.VectorIds(status=status_pb2.Status(
error_code=_status.code, reason=_status.message),
vector_id_array=_ids)
@mark_grpc_method
def Search(self, request, context):
table_name = request.table_name
topk = request.topk
nprobe = request.nprobe
logger.info('Search {}: topk={} nprobe={}'.format(
table_name, topk, nprobe))
metadata = {'resp_class': milvus_pb2.TopKQueryResultList}
if nprobe > self.MAX_NPROBE or nprobe <= 0:
raise exceptions.InvalidArgumentError(
message='Invalid nprobe: {}'.format(nprobe), metadata=metadata)
if topk > self.MAX_TOPK or topk <= 0:
raise exceptions.InvalidTopKError(
message='Invalid topk: {}'.format(topk), metadata=metadata)
table_meta = self.table_meta.get(table_name, None)
if not table_meta:
status, info = self.router.connection(
metadata=metadata).describe_table(table_name)
if not status.OK():
raise exceptions.TableNotFoundError(table_name,
metadata=metadata)
self.table_meta[table_name] = info
table_meta = info
start = time.time()
query_record_array = []
for query_record in request.query_record_array:
query_record_array.append(list(query_record.vector_data))
query_range_array = []
for query_range in request.query_range_array:
query_range_array.append(
Range(query_range.start_value, query_range.end_value))
status, results = self._do_query(context,
table_name,
table_meta,
query_record_array,
topk,
nprobe,
query_range_array,
metadata=metadata)
now = time.time()
logger.info('SearchVector takes: {}'.format(now - start))
topk_result_list = milvus_pb2.TopKQueryResultList(
status=status_pb2.Status(error_code=status.error_code,
reason=status.reason),
topk_query_result=results)
return topk_result_list
@mark_grpc_method
def SearchInFiles(self, request, context):
        raise NotImplementedError()
def _describe_table(self, table_name, metadata=None):
return self.router.connection(metadata=metadata).describe_table(table_name)
@mark_grpc_method
def DescribeTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return milvus_pb2.TableSchema(status=status_pb2.Status(
error_code=_status.code, reason=_status.message), )
metadata = {'resp_class': milvus_pb2.TableSchema}
logger.info('DescribeTable {}'.format(_table_name))
_status, _table = self._describe_table(metadata=metadata,
table_name=_table_name)
if _status.OK():
return milvus_pb2.TableSchema(
table_name=_table_name,
index_file_size=_table.index_file_size,
dimension=_table.dimension,
metric_type=_table.metric_type,
status=status_pb2.Status(error_code=_status.code,
reason=_status.message),
)
return milvus_pb2.TableSchema(
table_name=_table_name,
status=status_pb2.Status(error_code=_status.code,
reason=_status.message),
)
def _count_table(self, table_name, metadata=None):
return self.router.connection(
metadata=metadata).get_table_row_count(table_name)
@mark_grpc_method
def CountTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
status = status_pb2.Status(error_code=_status.code,
reason=_status.message)
return milvus_pb2.TableRowCount(status=status)
logger.info('CountTable {}'.format(_table_name))
metadata = {'resp_class': milvus_pb2.TableRowCount}
_status, _count = self._count_table(_table_name, metadata=metadata)
return milvus_pb2.TableRowCount(
status=status_pb2.Status(error_code=_status.code,
reason=_status.message),
table_row_count=_count if isinstance(_count, int) else -1)
def _get_server_version(self, metadata=None):
return self.router.connection(metadata=metadata).server_version()
@mark_grpc_method
def Cmd(self, request, context):
_status, _cmd = Parser.parse_proto_Command(request)
logger.info('Cmd: {}'.format(_cmd))
if not _status.OK():
return milvus_pb2.StringReply(status=status_pb2.Status(
error_code=_status.code, reason=_status.message))
metadata = {'resp_class': milvus_pb2.StringReply}
if _cmd == 'version':
_status, _reply = self._get_server_version(metadata=metadata)
else:
_status, _reply = self.router.connection(
metadata=metadata).server_status()
return milvus_pb2.StringReply(status=status_pb2.Status(
error_code=_status.code, reason=_status.message),
string_reply=_reply)
def _show_tables(self, metadata=None):
return self.router.connection(metadata=metadata).show_tables()
@mark_grpc_method
def ShowTables(self, request, context):
logger.info('ShowTables')
metadata = {'resp_class': milvus_pb2.TableName}
_status, _results = self._show_tables(metadata=metadata)
return milvus_pb2.TableNameList(status=status_pb2.Status(
error_code=_status.code, reason=_status.message),
table_names=_results)
def _delete_by_range(self, table_name, start_date, end_date):
return self.router.connection().delete_vectors_by_range(table_name,
start_date,
end_date)
@mark_grpc_method
def DeleteByRange(self, request, context):
_status, unpacks = \
Parser.parse_proto_DeleteByRangeParam(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
_table_name, _start_date, _end_date = unpacks
logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date,
_end_date))
_status = self._delete_by_range(_table_name, _start_date, _end_date)
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
def _preload_table(self, table_name):
return self.router.connection().preload_table(table_name)
@mark_grpc_method
def PreloadTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
logger.info('PreloadTable {}'.format(_table_name))
_status = self._preload_table(_table_name)
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
def _describe_index(self, table_name, metadata=None):
return self.router.connection(metadata=metadata).describe_index(table_name)
@mark_grpc_method
def DescribeIndex(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return milvus_pb2.IndexParam(status=status_pb2.Status(
error_code=_status.code, reason=_status.message))
metadata = {'resp_class': milvus_pb2.IndexParam}
logger.info('DescribeIndex {}'.format(_table_name))
_status, _index_param = self._describe_index(table_name=_table_name,
metadata=metadata)
if not _index_param:
return milvus_pb2.IndexParam(status=status_pb2.Status(
error_code=_status.code, reason=_status.message))
_index = milvus_pb2.Index(index_type=_index_param._index_type,
nlist=_index_param._nlist)
return milvus_pb2.IndexParam(status=status_pb2.Status(
error_code=_status.code, reason=_status.message),
table_name=_table_name,
index=_index)
def _drop_index(self, table_name):
return self.router.connection().drop_index(table_name)
@mark_grpc_method
def DropIndex(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
logger.info('DropIndex {}'.format(_table_name))
_status = self._drop_index(_table_name)
return status_pb2.Status(error_code=_status.code,
reason=_status.message)
import sys
import os
from environs import Env
env = Env()
FROM_EXAMPLE = env.bool('FROM_EXAMPLE', False)
if FROM_EXAMPLE:
from dotenv import load_dotenv
load_dotenv('./mishards/.env.example')
else:
env.read_env()
DEBUG = env.bool('DEBUG', False)
LOG_LEVEL = env.str('LOG_LEVEL', 'DEBUG' if DEBUG else 'INFO')
LOG_PATH = env.str('LOG_PATH', '/tmp/mishards')
LOG_NAME = env.str('LOG_NAME', 'logfile')
TIMEZONE = env.str('TIMEZONE', 'UTC')
from utils.logger_helper import config
config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE)
TIMEOUT = env.int('TIMEOUT', 60)
MAX_RETRY = env.int('MAX_RETRY', 3)
SERVER_PORT = env.int('SERVER_PORT', 19530)
SERVER_TEST_PORT = env.int('SERVER_TEST_PORT', 19530)
WOSERVER = env.str('WOSERVER')
SD_PROVIDER_SETTINGS = None
SD_PROVIDER = env.str('SD_PROVIDER', 'Kubernetes')
if SD_PROVIDER == 'Kubernetes':
from sd.kubernetes_provider import KubernetesProviderSettings
SD_PROVIDER_SETTINGS = KubernetesProviderSettings(
namespace=env.str('SD_NAMESPACE', ''),
in_cluster=env.bool('SD_IN_CLUSTER', False),
poll_interval=env.int('SD_POLL_INTERVAL', 5),
pod_patt=env.str('SD_ROSERVER_POD_PATT', ''),
label_selector=env.str('SD_LABEL_SELECTOR', ''),
port=env.int('SD_PORT', 19530))
elif SD_PROVIDER == 'Static':
from sd.static_provider import StaticProviderSettings
SD_PROVIDER_SETTINGS = StaticProviderSettings(
hosts=env.list('SD_STATIC_HOSTS', []),
port=env.int('SD_STATIC_PORT', 19530))
# TESTING_WOSERVER = env.str('TESTING_WOSERVER', 'tcp://127.0.0.1:19530')
class TracingConfig:
TRACING_SERVICE_NAME = env.str('TRACING_SERVICE_NAME', 'mishards')
TRACING_VALIDATE = env.bool('TRACING_VALIDATE', True)
TRACING_LOG_PAYLOAD = env.bool('TRACING_LOG_PAYLOAD', False)
TRACING_CONFIG = {
'sampler': {
'type': env.str('TRACING_SAMPLER_TYPE', 'const'),
'param': env.str('TRACING_SAMPLER_PARAM', "1"),
},
'local_agent': {
'reporting_host': env.str('TRACING_REPORTING_HOST', '127.0.0.1'),
'reporting_port': env.str('TRACING_REPORTING_PORT', '5775')
},
'logging': env.bool('TRACING_LOGGING', True)
}
DEFAULT_TRACING_CONFIG = {
'sampler': {
'type': env.str('TRACING_SAMPLER_TYPE', 'const'),
'param': env.str('TRACING_SAMPLER_PARAM', "0"),
}
}
class DefaultConfig:
SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI')
SQL_ECHO = env.bool('SQL_ECHO', False)
TRACING_TYPE = env.str('TRACING_TYPE', '')
ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_NAME', 'FileBasedHashRingRouter')
class TestingConfig(DefaultConfig):
SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_TEST_URI', '')
SQL_ECHO = env.bool('SQL_TEST_ECHO', False)
TRACING_TYPE = env.str('TRACING_TEST_TYPE', '')
ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_TEST_NAME', 'FileBasedHashRingRouter')
if __name__ == '__main__':
import logging
logger = logging.getLogger(__name__)
logger.debug('DEBUG')
logger.info('INFO')
    logger.warning('WARN')
logger.error('ERROR')
import logging
import pytest
import mock
from milvus import Milvus
from mishards.connections import (ConnectionMgr, Connection)
from mishards import exceptions
logger = logging.getLogger(__name__)
@pytest.mark.usefixtures('app')
class TestConnection:
def test_manager(self):
mgr = ConnectionMgr()
mgr.register('pod1', '111')
mgr.register('pod2', '222')
mgr.register('pod2', '222')
mgr.register('pod2', '2222')
assert len(mgr.conn_names) == 2
mgr.unregister('pod1')
assert len(mgr.conn_names) == 1
mgr.unregister('pod2')
assert len(mgr.conn_names) == 0
mgr.register('WOSERVER', 'xxxx')
assert len(mgr.conn_names) == 0
assert not mgr.conn('XXXX', None)
with pytest.raises(exceptions.ConnectionNotFoundError):
mgr.conn('XXXX', None, True)
mgr.conn('WOSERVER', None)
def test_connection(self):
class Conn:
def __init__(self, state):
self.state = state
def connect(self, uri):
return self.state
def connected(self):
return self.state
FAIL_CONN = Conn(False)
PASS_CONN = Conn(True)
class Retry:
def __init__(self):
self.times = 0
def __call__(self, conn):
self.times += 1
logger.info('Retrying {}'.format(self.times))
class Func():
def __init__(self):
self.executed = False
def __call__(self):
self.executed = True
max_retry = 3
RetryObj = Retry()
c = Connection('client',
uri='xx',
max_retry=max_retry,
on_retry_func=RetryObj)
c.conn = FAIL_CONN
ff = Func()
this_connect = c.connect(func=ff)
with pytest.raises(exceptions.ConnectionConnectError):
this_connect()
assert RetryObj.times == max_retry
assert not ff.executed
RetryObj = Retry()
c.conn = PASS_CONN
this_connect = c.connect(func=ff)
this_connect()
assert ff.executed
assert RetryObj.times == 0
this_connect = c.connect(func=None)
with pytest.raises(TypeError):
this_connect()
errors = []
def error_handler(err):
errors.append(err)
this_connect = c.connect(func=None, exception_handler=error_handler)
this_connect()
assert len(errors) == 1
import logging
import pytest
from mishards.factories import TableFiles, Tables, TableFilesFactory, TablesFactory
from mishards import db, create_app, settings
logger = logging.getLogger(__name__)
@pytest.mark.usefixtures('app')
class TestModels:
def test_files_to_search(self):
table = TablesFactory()
new_files_cnt = 5
to_index_cnt = 10
raw_cnt = 20
backup_cnt = 12
to_delete_cnt = 9
index_cnt = 8
new_index_cnt = 6
new_merge_cnt = 11
new_files = TableFilesFactory.create_batch(new_files_cnt, table=table, file_type=TableFiles.FILE_TYPE_NEW, date=110)
to_index_files = TableFilesFactory.create_batch(to_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_INDEX, date=110)
raw_files = TableFilesFactory.create_batch(raw_cnt, table=table, file_type=TableFiles.FILE_TYPE_RAW, date=120)
backup_files = TableFilesFactory.create_batch(backup_cnt, table=table, file_type=TableFiles.FILE_TYPE_BACKUP, date=110)
index_files = TableFilesFactory.create_batch(index_cnt, table=table, file_type=TableFiles.FILE_TYPE_INDEX, date=110)
new_index_files = TableFilesFactory.create_batch(new_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_NEW_INDEX, date=110)
new_merge_files = TableFilesFactory.create_batch(new_merge_cnt, table=table, file_type=TableFiles.FILE_TYPE_NEW_MERGE, date=110)
to_delete_files = TableFilesFactory.create_batch(to_delete_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_DELETE, date=110)
assert table.files_to_search().count() == raw_cnt + index_cnt + to_index_cnt
assert table.files_to_search([(100, 115)]).count() == index_cnt + to_index_cnt
assert table.files_to_search([(111, 120)]).count() == 0
assert table.files_to_search([(111, 121)]).count() == raw_cnt
assert table.files_to_search([(110, 121)]).count() == raw_cnt + index_cnt + to_index_cnt
import logging
import pytest
import mock
import datetime
import random
import faker
import inspect
from milvus import Milvus
from milvus.client.types import Status, IndexType, MetricType
from milvus.client.abstract import IndexParam, TableSchema
from milvus.grpc_gen import status_pb2, milvus_pb2
from mishards import db, create_app, settings
from mishards.service_handler import ServiceHandler
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
from mishards.factories import TableFilesFactory, TablesFactory, TableFiles, Tables
from mishards.routings import RouterMixin
logger = logging.getLogger(__name__)
OK = Status(code=Status.SUCCESS, message='Success')
BAD = Status(code=Status.PERMISSION_DENIED, message='Fail')
@pytest.mark.usefixtures('started_app')
class TestServer:
@property
def client(self):
m = Milvus()
m.connect(host='localhost', port=settings.SERVER_TEST_PORT)
return m
def test_server_start(self, started_app):
assert started_app.conn_mgr.metas.get('WOSERVER') == settings.WOSERVER
def test_cmd(self, started_app):
ServiceHandler._get_server_version = mock.MagicMock(return_value=(OK,
''))
status, _ = self.client.server_version()
assert status.OK()
Parser.parse_proto_Command = mock.MagicMock(return_value=(BAD, 'cmd'))
status, _ = self.client.server_version()
assert not status.OK()
def test_drop_index(self, started_app):
table_name = inspect.currentframe().f_code.co_name
ServiceHandler._drop_index = mock.MagicMock(return_value=OK)
status = self.client.drop_index(table_name)
assert status.OK()
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(BAD, table_name))
status = self.client.drop_index(table_name)
assert not status.OK()
def test_describe_index(self, started_app):
table_name = inspect.currentframe().f_code.co_name
index_type = IndexType.FLAT
nlist = 1
index_param = IndexParam(table_name=table_name,
index_type=index_type,
nlist=nlist)
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(OK, table_name))
ServiceHandler._describe_index = mock.MagicMock(
return_value=(OK, index_param))
status, ret = self.client.describe_index(table_name)
assert status.OK()
assert ret._table_name == index_param._table_name
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(BAD, table_name))
status, _ = self.client.describe_index(table_name)
assert not status.OK()
def test_preload(self, started_app):
table_name = inspect.currentframe().f_code.co_name
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(OK, table_name))
ServiceHandler._preload_table = mock.MagicMock(return_value=OK)
status = self.client.preload_table(table_name)
assert status.OK()
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(BAD, table_name))
status = self.client.preload_table(table_name)
assert not status.OK()
@pytest.mark.skip
def test_delete_by_range(self, started_app):
table_name = inspect.currentframe().f_code.co_name
        unpacked = table_name, datetime.datetime.today(), datetime.datetime.today()
Parser.parse_proto_DeleteByRangeParam = mock.MagicMock(
return_value=(OK, unpacked))
ServiceHandler._delete_by_range = mock.MagicMock(return_value=OK)
status = self.client.delete_vectors_by_range(
*unpacked)
assert status.OK()
Parser.parse_proto_DeleteByRangeParam = mock.MagicMock(
return_value=(BAD, unpacked))
status = self.client.delete_vectors_by_range(
*unpacked)
assert not status.OK()
def test_count_table(self, started_app):
table_name = inspect.currentframe().f_code.co_name
count = random.randint(100, 200)
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(OK, table_name))
ServiceHandler._count_table = mock.MagicMock(return_value=(OK, count))
status, ret = self.client.get_table_row_count(table_name)
assert status.OK()
assert ret == count
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(BAD, table_name))
status, _ = self.client.get_table_row_count(table_name)
assert not status.OK()
def test_show_tables(self, started_app):
tables = ['t1', 't2']
ServiceHandler._show_tables = mock.MagicMock(return_value=(OK, tables))
status, ret = self.client.show_tables()
assert status.OK()
assert ret == tables
def test_describe_table(self, started_app):
table_name = inspect.currentframe().f_code.co_name
dimension = 128
nlist = 1
table_schema = TableSchema(table_name=table_name,
index_file_size=100,
metric_type=MetricType.L2,
dimension=dimension)
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(OK, table_schema.table_name))
ServiceHandler._describe_table = mock.MagicMock(
return_value=(OK, table_schema))
status, _ = self.client.describe_table(table_name)
assert status.OK()
ServiceHandler._describe_table = mock.MagicMock(
return_value=(BAD, table_schema))
status, _ = self.client.describe_table(table_name)
assert not status.OK()
Parser.parse_proto_TableName = mock.MagicMock(return_value=(BAD,
'cmd'))
status, ret = self.client.describe_table(table_name)
assert not status.OK()
def test_insert(self, started_app):
table_name = inspect.currentframe().f_code.co_name
vectors = [[random.random() for _ in range(16)] for _ in range(10)]
ids = [random.randint(1000000, 20000000) for _ in range(10)]
ServiceHandler._add_vectors = mock.MagicMock(return_value=(OK, ids))
status, ret = self.client.add_vectors(
table_name=table_name, records=vectors)
assert status.OK()
assert ids == ret
def test_create_index(self, started_app):
table_name = inspect.currentframe().f_code.co_name
unpacks = table_name, None
Parser.parse_proto_IndexParam = mock.MagicMock(return_value=(OK,
unpacks))
ServiceHandler._create_index = mock.MagicMock(return_value=OK)
status = self.client.create_index(table_name=table_name)
assert status.OK()
Parser.parse_proto_IndexParam = mock.MagicMock(return_value=(BAD,
None))
status = self.client.create_index(table_name=table_name)
assert not status.OK()
def test_drop_table(self, started_app):
table_name = inspect.currentframe().f_code.co_name
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(OK, table_name))
ServiceHandler._delete_table = mock.MagicMock(return_value=OK)
status = self.client.delete_table(table_name=table_name)
assert status.OK()
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(BAD, table_name))
status = self.client.delete_table(table_name=table_name)
assert not status.OK()
def test_has_table(self, started_app):
table_name = inspect.currentframe().f_code.co_name
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(OK, table_name))
ServiceHandler._has_table = mock.MagicMock(return_value=(OK, True))
has = self.client.has_table(table_name=table_name)
assert has
Parser.parse_proto_TableName = mock.MagicMock(
return_value=(BAD, table_name))
status, has = self.client.has_table(table_name=table_name)
assert not status.OK()
assert not has
def test_create_table(self, started_app):
table_name = inspect.currentframe().f_code.co_name
dimension = 128
table_schema = dict(table_name=table_name,
index_file_size=100,
metric_type=MetricType.L2,
dimension=dimension)
ServiceHandler._create_table = mock.MagicMock(return_value=OK)
status = self.client.create_table(table_schema)
assert status.OK()
Parser.parse_proto_TableSchema = mock.MagicMock(return_value=(BAD,
None))
status = self.client.create_table(table_schema)
assert not status.OK()
def random_data(self, n, dimension):
return [[random.random() for _ in range(dimension)] for _ in range(n)]
def test_search(self, started_app):
table_name = inspect.currentframe().f_code.co_name
to_index_cnt = random.randint(10, 20)
table = TablesFactory(table_id=table_name, state=Tables.NORMAL)
to_index_files = TableFilesFactory.create_batch(
to_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_INDEX)
topk = random.randint(5, 10)
nq = random.randint(5, 10)
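        # nprobe starts above the 2048 cap so the first search below fails
        # fast with Status.ILLEGAL_ARGUMENT; it is lowered to 2048 afterwards.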
param = {
'table_name': table_name,
'query_records': self.random_data(nq, table.dimension),
'top_k': topk,
'nprobe': 2049
}
result = [
milvus_pb2.TopKQueryResult(query_result_arrays=[
milvus_pb2.QueryResult(id=i, distance=random.random())
for i in range(topk)
]) for i in range(nq)
]
mock_results = milvus_pb2.TopKQueryResultList(status=status_pb2.Status(
error_code=status_pb2.SUCCESS, reason="Success"),
topk_query_result=result)
table_schema = TableSchema(table_name=table_name,
index_file_size=table.index_file_size,
metric_type=table.metric_type,
dimension=table.dimension)
status, _ = self.client.search_vectors(**param)
assert status.code == Status.ILLEGAL_ARGUMENT
param['nprobe'] = 2048
RouterMixin.connection = mock.MagicMock(return_value=Milvus())
RouterMixin.query_conn = mock.MagicMock(return_value=Milvus())
Milvus.describe_table = mock.MagicMock(return_value=(BAD,
table_schema))
status, ret = self.client.search_vectors(**param)
assert status.code == Status.TABLE_NOT_EXISTS
Milvus.describe_table = mock.MagicMock(return_value=(OK, table_schema))
Milvus.search_vectors_in_files = mock.MagicMock(
return_value=mock_results)
status, ret = self.client.search_vectors(**param)
assert status.OK()
assert len(ret) == nq
import datetime
from mishards import exceptions
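# Dates are packed into a single int using the same offsets as C's `struct tm`
# (tm_year = year - 1900, tm_mon = month - 1):
#   (year - 1900) * 10000 + (month - 1) * 100 + day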
def format_date(start, end):
return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day,
(end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day)
def range_to_date(range_obj, metadata=None):
try:
start = datetime.datetime.strptime(range_obj.start_date, '%Y-%m-%d')
end = datetime.datetime.strptime(range_obj.end_date, '%Y-%m-%d')
assert start < end
except (ValueError, AssertionError):
raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format(
range_obj.start_date, range_obj.end_date),
metadata=metadata)
return format_date(start, end)
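# Illustrative call (the namedtuple is a stand-in for the range object, which
# is assumed to carry `start_date`/`end_date` strings):
#   from collections import namedtuple
#   Range = namedtuple('Range', ['start_date', 'end_date'])
#   range_to_date(Range('2019-06-01', '2019-06-05'))  # -> (1190501, 1190505)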
environs==4.2.0
factory-boy==2.12.0
Faker==1.0.7
fire==0.1.3
google-auth==1.6.3
grpcio==1.22.0
grpcio-tools==1.22.0
kubernetes==10.0.1
MarkupSafe==1.1.1
marshmallow==2.19.5
pymysql==0.9.3
protobuf==3.9.1
py==1.8.0
pyasn1==0.4.7
pyasn1-modules==0.2.6
pylint==2.3.1
pymilvus-test==0.2.28
#pymilvus==0.2.0
pyparsing==2.4.0
pytest==4.6.3
pytest-level==0.1.1
pytest-print==0.1.2
pytest-repeat==0.8.0
pytest-timeout==1.3.3
python-dateutil==2.8.0
python-dotenv==0.10.3
pytz==2019.1
requests==2.22.0
requests-oauthlib==1.2.0
rsa==4.0
six==1.12.0
SQLAlchemy==1.3.5
urllib3==1.25.3
jaeger-client>=3.4.0
grpcio-opentracing>=1.0
mock==2.0.0
import logging
import inspect
# from utils import singleton
logger = logging.getLogger(__name__)
class ProviderManager:
PROVIDERS = {}
@classmethod
def register_service_provider(cls, target):
if inspect.isfunction(target):
cls.PROVIDERS[target.__name__] = target
elif inspect.isclass(target):
name = target.__dict__.get('NAME', None)
            name = name if name else target.__name__
cls.PROVIDERS[name] = target
else:
            raise TypeError('Cannot register_service_provider for: {}'.format(target))
return target
@classmethod
def get_provider(cls, name):
return cls.PROVIDERS.get(name, None)
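# Imported for side effect: each provider module below registers itself with
# ProviderManager at import time via register_service_provider.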
from sd import kubernetes_provider, static_provider
import os
import sys
if __name__ == '__main__':
sys.path.append(os.path.dirname(os.path.dirname(
os.path.abspath(__file__))))
import re
import logging
import time
import copy
import threading
import queue
import enum
from kubernetes import client, config, watch
from utils import singleton
from sd import ProviderManager
logger = logging.getLogger(__name__)
INCLUSTER_NAMESPACE_PATH = '/var/run/secrets/kubernetes.io/serviceaccount/namespace'
class EventType(enum.Enum):
PodHeartBeat = 1
Watch = 2
class K8SMixin:
def __init__(self, namespace, in_cluster=False, **kwargs):
self.namespace = namespace
self.in_cluster = in_cluster
self.kwargs = kwargs
self.v1 = kwargs.get('v1', None)
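        # A pre-built CoreV1Api client may be injected via kwargs (the
        # provider shares one client across its threads); otherwise one is
        # created below from the in-cluster or local kubeconfig.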
if not self.namespace:
            with open(INCLUSTER_NAMESPACE_PATH) as f:
                self.namespace = f.read().strip()
if not self.v1:
            if self.in_cluster:
                config.load_incluster_config()
            else:
                config.load_kube_config()
self.v1 = client.CoreV1Api()
class K8SHeartbeatHandler(threading.Thread, K8SMixin):
def __init__(self,
message_queue,
namespace,
label_selector,
in_cluster=False,
**kwargs):
K8SMixin.__init__(self,
namespace=namespace,
in_cluster=in_cluster,
**kwargs)
threading.Thread.__init__(self)
self.queue = message_queue
self.terminate = False
self.label_selector = label_selector
self.poll_interval = kwargs.get('poll_interval', 5)
def run(self):
while not self.terminate:
try:
pods = self.v1.list_namespaced_pod(
namespace=self.namespace,
label_selector=self.label_selector)
event_message = {'eType': EventType.PodHeartBeat, 'events': []}
for item in pods.items:
pod = self.v1.read_namespaced_pod(name=item.metadata.name,
namespace=self.namespace)
name = pod.metadata.name
ip = pod.status.pod_ip
phase = pod.status.phase
reason = pod.status.reason
message = pod.status.message
                    ready = phase == 'Running'
pod_event = dict(pod=name,
ip=ip,
ready=ready,
reason=reason,
message=message)
event_message['events'].append(pod_event)
self.queue.put(event_message)
except Exception as exc:
logger.error(exc)
time.sleep(self.poll_interval)
def stop(self):
self.terminate = True
class K8SEventListener(threading.Thread, K8SMixin):
def __init__(self, message_queue, namespace, in_cluster=False, **kwargs):
K8SMixin.__init__(self,
namespace=namespace,
in_cluster=in_cluster,
**kwargs)
threading.Thread.__init__(self)
self.queue = message_queue
self.terminate = False
self.at_start_up = True
self._stop_event = threading.Event()
def stop(self):
self.terminate = True
self._stop_event.set()
def run(self):
resource_version = ''
w = watch.Watch()
for event in w.stream(self.v1.list_namespaced_event,
namespace=self.namespace,
field_selector='involvedObject.kind=Pod'):
if self.terminate:
break
resource_version = int(event['object'].metadata.resource_version)
info = dict(
eType=EventType.Watch,
pod=event['object'].involved_object.name,
reason=event['object'].reason,
message=event['object'].message,
start_up=self.at_start_up,
)
self.at_start_up = False
# logger.info('Received event: {}'.format(info))
self.queue.put(info)
class EventHandler(threading.Thread):
def __init__(self, mgr, message_queue, namespace, pod_patt, **kwargs):
threading.Thread.__init__(self)
self.mgr = mgr
self.queue = message_queue
self.kwargs = kwargs
self.terminate = False
self.pod_patt = re.compile(pod_patt)
self.namespace = namespace
def stop(self):
self.terminate = True
def on_drop(self, event, **kwargs):
pass
def on_pod_started(self, event, **kwargs):
try_cnt = 3
pod = None
while try_cnt > 0:
try_cnt -= 1
try:
pod = self.mgr.v1.read_namespaced_pod(name=event['pod'],
namespace=self.namespace)
if not pod.status.pod_ip:
time.sleep(0.5)
continue
break
except client.rest.ApiException as exc:
time.sleep(0.5)
if try_cnt <= 0 and not pod:
if not event['start_up']:
                logger.error('Pod {} was started but could not be read'.format(
                    event['pod']))
return
elif try_cnt <= 0 and not pod.status.pod_ip:
            logger.warning('No IP found for pod {}'.format(event['pod']))
return
logger.info('Register POD {} with IP {}'.format(
pod.metadata.name, pod.status.pod_ip))
self.mgr.add_pod(name=pod.metadata.name, ip=pod.status.pod_ip)
def on_pod_killing(self, event, **kwargs):
logger.info('Unregister POD {}'.format(event['pod']))
self.mgr.delete_pod(name=event['pod'])
def on_pod_heartbeat(self, event, **kwargs):
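        # Reconcile connections against the heartbeat snapshot: (re)register
        # every ready pod, then drop any known connection not seen this round.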
names = self.mgr.conn_mgr.conn_names
running_names = set()
for each_event in event['events']:
if each_event['ready']:
self.mgr.add_pod(name=each_event['pod'], ip=each_event['ip'])
running_names.add(each_event['pod'])
else:
self.mgr.delete_pod(name=each_event['pod'])
to_delete = names - running_names
for name in to_delete:
self.mgr.delete_pod(name)
logger.info(self.mgr.conn_mgr.conn_names)
def handle_event(self, event):
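        # Heartbeat events are always handled; watch events are dropped unless
        # the reason is Started/Killing and the pod name matches the pattern.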
if event['eType'] == EventType.PodHeartBeat:
return self.on_pod_heartbeat(event)
if not event or (event['reason'] not in ('Started', 'Killing')):
return self.on_drop(event)
if not re.match(self.pod_patt, event['pod']):
return self.on_drop(event)
logger.info('Handling event: {}'.format(event))
if event['reason'] == 'Started':
return self.on_pod_started(event)
return self.on_pod_killing(event)
def run(self):
while not self.terminate:
try:
event = self.queue.get(timeout=1)
self.handle_event(event)
except queue.Empty:
continue
class KubernetesProviderSettings:
def __init__(self, namespace, pod_patt, label_selector, in_cluster,
poll_interval, port=None, **kwargs):
self.namespace = namespace
self.pod_patt = pod_patt
self.label_selector = label_selector
self.in_cluster = in_cluster
self.poll_interval = poll_interval
self.port = int(port) if port else 19530
@singleton
@ProviderManager.register_service_provider
class KubernetesProvider(object):
NAME = 'Kubernetes'
def __init__(self, settings, conn_mgr, **kwargs):
self.namespace = settings.namespace
self.pod_patt = settings.pod_patt
self.label_selector = settings.label_selector
self.in_cluster = settings.in_cluster
self.poll_interval = settings.poll_interval
self.port = settings.port
self.kwargs = kwargs
self.queue = queue.Queue()
self.conn_mgr = conn_mgr
if not self.namespace:
        with open(INCLUSTER_NAMESPACE_PATH) as f:
            self.namespace = f.read().strip()
        if self.in_cluster:
            config.load_incluster_config()
        else:
            config.load_kube_config()
self.v1 = client.CoreV1Api()
self.listener = K8SEventListener(message_queue=self.queue,
namespace=self.namespace,
in_cluster=self.in_cluster,
v1=self.v1,
**kwargs)
self.pod_heartbeater = K8SHeartbeatHandler(
message_queue=self.queue,
namespace=self.namespace,
label_selector=self.label_selector,
in_cluster=self.in_cluster,
v1=self.v1,
poll_interval=self.poll_interval,
**kwargs)
self.event_handler = EventHandler(mgr=self,
message_queue=self.queue,
namespace=self.namespace,
pod_patt=self.pod_patt,
**kwargs)
def add_pod(self, name, ip):
self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port))
def delete_pod(self, name):
self.conn_mgr.unregister(name)
def start(self):
self.listener.daemon = True
self.listener.start()
self.event_handler.start()
self.pod_heartbeater.start()
def stop(self):
self.listener.stop()
self.pod_heartbeater.stop()
self.event_handler.stop()
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
class Connect:
def register(self, name, value):
logger.error('Register: {} - {}'.format(name, value))
def unregister(self, name):
logger.error('Unregister: {}'.format(name))
@property
def conn_names(self):
return set()
connect_mgr = Connect()
settings = KubernetesProviderSettings(namespace='xp',
pod_patt=".*-ro-servers-.*",
label_selector='tier=ro-servers',
poll_interval=5,
in_cluster=False)
provider_class = ProviderManager.get_provider('Kubernetes')
t = provider_class(conn_mgr=connect_mgr, settings=settings)
t.start()
cnt = 100
while cnt > 0:
time.sleep(2)
cnt -= 1
t.stop()
import os
import sys
if __name__ == '__main__':
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import socket
from utils import singleton
from sd import ProviderManager
class StaticProviderSettings:
def __init__(self, hosts, port=None):
self.hosts = hosts
self.port = int(port) if port else 19530
@singleton
@ProviderManager.register_service_provider
class StaticProvider(object):
NAME = 'Static'
def __init__(self, settings, conn_mgr, **kwargs):
self.conn_mgr = conn_mgr
self.hosts = [socket.gethostbyname(host) for host in settings.hosts]
self.port = settings.port
def start(self):
for host in self.hosts:
self.add_pod(host, host)
def stop(self):
for host in self.hosts:
self.delete_pod(host)
def add_pod(self, name, ip):
self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port))
def delete_pod(self, name):
self.conn_mgr.unregister(name)
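# Illustrative usage (hypothetical host list and connection manager): start()
# resolves each hostname once and registers a tcp://<ip>:<port> endpoint.
#   settings = StaticProviderSettings(hosts=['milvus'], port=19530)
#   provider = ProviderManager.get_provider('Static')(settings=settings,
#                                                     conn_mgr=conn_mgr)
#   provider.start()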
[tool:pytest]
testpaths = mishards
log_cli=true
log_cli_level=info
version: "2.3"
services:
milvus:
runtime: nvidia
restart: always
image: registry.zilliz.com/milvus/engine:branch-0.5.0-release-4316de
# ports:
# - "0.0.0.0:19530:19530"
volumes:
- /tmp/milvus/db:/opt/milvus/db
jaeger:
restart: always
image: jaegertracing/all-in-one:1.14
ports:
- "0.0.0.0:5775:5775/udp"
- "0.0.0.0:16686:16686"
      - "0.0.0.0:9411:9411"
environment:
COLLECTOR_ZIPKIN_HTTP_PORT: 9411
mishards:
restart: always
image: registry.zilliz.com/milvus/mishards:v0.0.4
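    # Host port 19530 is mapped onto the mishards gRPC port (SERVER_PORT is
    # set to 19531 inside this container).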
ports:
- "0.0.0.0:19530:19531"
- "0.0.0.0:19532:19532"
volumes:
- /tmp/milvus/db:/tmp/milvus/db
# - /tmp/mishards_env:/source/mishards/.env
command: ["python", "mishards/main.py"]
environment:
FROM_EXAMPLE: 'true'
DEBUG: 'true'
SERVER_PORT: 19531
WOSERVER: tcp://milvus:19530
SD_STATIC_HOSTS: milvus
TRACING_TYPE: jaeger
TRACING_SERVICE_NAME: mishards-demo
TRACING_REPORTING_HOST: jaeger
TRACING_REPORTING_PORT: 5775
depends_on:
- milvus
- jaeger
from contextlib import contextmanager
def empty_server_interceptor_decorator(target_server, interceptor):
return target_server
@contextmanager
def EmptySpan(*args, **kwargs):
yield None
return
class Tracer:
def __init__(self,
tracer=None,
interceptor=None,
server_decorator=empty_server_interceptor_decorator):
self.tracer = tracer
self.interceptor = interceptor
self.server_decorator = server_decorator
def decorate(self, server):
return self.server_decorator(server, self.interceptor)
@property
def empty(self):
return self.tracer is None
def close(self):
self.tracer and self.tracer.close()
def start_span(self,
operation_name=None,
child_of=None,
references=None,
tags=None,
start_time=None,
ignore_active_span=False):
if self.empty:
return EmptySpan()
return self.tracer.start_span(operation_name, child_of, references,
tags, start_time, ignore_active_span)
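# Illustrative usage: an unconfigured Tracer degrades to no-op spans.
#   t = Tracer()
#   with t.start_span('noop') as span:
#       assert span is None  # EmptySpan yields None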
import logging
from jaeger_client import Config
from grpc_opentracing.grpcext import intercept_server
from grpc_opentracing import open_tracing_server_interceptor
from tracing import (Tracer, empty_server_interceptor_decorator)
logger = logging.getLogger(__name__)
class TracerFactory:
@classmethod
def new_tracer(cls,
tracer_type,
tracer_config,
span_decorator=None,
**kwargs):
if not tracer_type:
return Tracer()
config = tracer_config.TRACING_CONFIG
service_name = tracer_config.TRACING_SERVICE_NAME
validate = tracer_config.TRACING_VALIDATE
# if not tracer_type:
# tracer_type = 'jaeger'
# config = tracer_config.DEFAULT_TRACING_CONFIG
if tracer_type.lower() == 'jaeger':
config = Config(config=config,
service_name=service_name,
validate=validate)
tracer = config.initialize_tracer()
tracer_interceptor = open_tracing_server_interceptor(
tracer,
log_payloads=tracer_config.TRACING_LOG_PAYLOAD,
span_decorator=span_decorator)
return Tracer(tracer, tracer_interceptor, intercept_server)
        raise ValueError('Unsupported tracer type: {}'.format(tracer_type))
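# Illustrative usage (hypothetical `tracing_config` object exposing the
# TRACING_* attributes read above):
#   tracer = TracerFactory.new_tracer('jaeger', tracing_config)
#   server = tracer.decorate(grpc_server)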
from functools import wraps
def singleton(cls):
instances = {}
@wraps(cls)
def getinstance(*args, **kw):
if cls not in instances:
instances[cls] = cls(*args, **kw)
return instances[cls]
return getinstance
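# Illustrative usage: repeated instantiation returns the same object.
#   @singleton
#   class Once:
#       pass
#   assert Once() is Once()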
import os
import datetime
from pytz import timezone
import logging.config
class InfoFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.INFO
class DebugFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.DEBUG
class WarnFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.WARN
class ErrorFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.ERROR
class CriticalFilter(logging.Filter):
def filter(self, rec):
return rec.levelno == logging.CRITICAL
COLORS = {
'HEADER': '\033[95m',
'INFO': '\033[92m',
'DEBUG': '\033[94m',
'WARNING': '\033[93m',
'ERROR': '\033[95m',
'CRITICAL': '\033[91m',
'ENDC': '\033[0m',
}
class ColorFulFormatColMixin:
def format_col(self, message_str, level_name):
if level_name in COLORS.keys():
message_str = COLORS.get(level_name) + message_str + COLORS.get(
'ENDC')
return message_str
class ColorfulFormatter(logging.Formatter, ColorFulFormatColMixin):
def format(self, record):
message_str = super(ColorfulFormatter, self).format(record)
return self.format_col(message_str, level_name=record.levelname)
def config(log_level, log_path, name, tz='UTC'):
def build_log_file(level, log_path, name, tz):
utc_now = datetime.datetime.utcnow()
utc_tz = timezone('UTC')
local_tz = timezone(tz)
tznow = utc_now.replace(tzinfo=utc_tz).astimezone(local_tz)
return '{}-{}-{}.log'.format(os.path.join(log_path, name), tznow.strftime("%m-%d-%Y-%H:%M:%S"),
level)
if not os.path.exists(log_path):
os.makedirs(log_path)
LOGGING = {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'default': {
'format': '%(asctime)s | %(levelname)s | %(name)s | %(threadName)s: %(message)s (%(filename)s:%(lineno)s)',
},
'colorful_console': {
'format': '%(asctime)s | %(levelname)s | %(name)s | %(threadName)s: %(message)s (%(filename)s:%(lineno)s)',
'()': ColorfulFormatter,
},
},
'filters': {
'InfoFilter': {
'()': InfoFilter,
},
'DebugFilter': {
'()': DebugFilter,
},
'WarnFilter': {
'()': WarnFilter,
},
'ErrorFilter': {
'()': ErrorFilter,
},
'CriticalFilter': {
'()': CriticalFilter,
},
},
'handlers': {
'milvus_celery_console': {
'class': 'logging.StreamHandler',
'formatter': 'colorful_console',
},
'milvus_debug_file': {
'level': 'DEBUG',
'filters': ['DebugFilter'],
'class': 'logging.handlers.RotatingFileHandler',
'formatter': 'default',
'filename': build_log_file('debug', log_path, name, tz)
},
'milvus_info_file': {
'level': 'INFO',
'filters': ['InfoFilter'],
'class': 'logging.handlers.RotatingFileHandler',
'formatter': 'default',
'filename': build_log_file('info', log_path, name, tz)
},
'milvus_warn_file': {
'level': 'WARN',
'filters': ['WarnFilter'],
'class': 'logging.handlers.RotatingFileHandler',
'formatter': 'default',
'filename': build_log_file('warn', log_path, name, tz)
},
'milvus_error_file': {
'level': 'ERROR',
'filters': ['ErrorFilter'],
'class': 'logging.handlers.RotatingFileHandler',
'formatter': 'default',
'filename': build_log_file('error', log_path, name, tz)
},
'milvus_critical_file': {
'level': 'CRITICAL',
'filters': ['CriticalFilter'],
'class': 'logging.handlers.RotatingFileHandler',
'formatter': 'default',
'filename': build_log_file('critical', log_path, name, tz)
},
},
'loggers': {
'': {
'handlers': ['milvus_celery_console', 'milvus_info_file', 'milvus_debug_file', 'milvus_warn_file',
'milvus_error_file', 'milvus_critical_file'],
'level': log_level,
'propagate': False
},
},
}
logging.config.dictConfig(LOGGING)
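# Illustrative usage (hypothetical paths): call once at startup before any
# application loggers are created.
#   config(log_level='DEBUG', log_path='/tmp/milvus/logs', name='mishards')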