提交 c1fbc4f8 编写于 作者: C cai.zhang 提交者: yefu.chen

Open the tests for insert

Signed-off-by: Ncai.zhang <cai.zhang@zilliz.com>
上级 6b74d822
......@@ -93,8 +93,8 @@ endif ()
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
if (KNOWHERE_BUILD_TESTS)
#set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_DISABLE_LOGS")
#add_subdirectory(unittest)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_DISABLE_LOGS")
add_subdirectory(unittest)
endif ()
config_summary()
......@@ -143,7 +143,6 @@ endif ()
target_link_libraries(
knowhere
milvus_utils
${depend_libs}
)
......
......@@ -12,4 +12,4 @@ set(MILVUS_QUERY_SRCS
BruteForceSearch.cpp
)
add_library(milvus_query ${MILVUS_QUERY_SRCS})
target_link_libraries(milvus_query milvus_proto milvus_utils)
target_link_libraries(milvus_query milvus_proto)
......@@ -24,22 +24,18 @@ namespace milvus::query {
static std::unique_ptr<VectorPlanNode>
ParseVecNode(Plan* plan, const Json& out_body) {
Assert(out_body.is_object());
// TODO add binary info
auto vec_node = std::make_unique<FloatVectorANNS>();
Assert(out_body.size() == 1);
auto iter = out_body.begin();
std::string field_name = iter.key();
auto& vec_info = iter.value();
Assert(vec_info.is_object());
auto topK = vec_info["topk"];
AssertInfo(topK > 0, "topK must greater than 0");
AssertInfo(topK < 16384, "topK is too large");
vec_node->query_info_.topK_ = topK;
vec_node->query_info_.metric_type_ = vec_info.at("metric_type");
vec_node->query_info_.search_params_ = vec_info.at("params");
vec_node->query_info_.metric_type_ = vec_info["metric_type"];
vec_node->query_info_.search_params_ = vec_info["params"];
vec_node->query_info_.field_id_ = field_name;
vec_node->placeholder_tag_ = vec_info.at("query");
vec_node->placeholder_tag_ = vec_info["query"];
auto tag = vec_node->placeholder_tag_;
AssertInfo(!plan->tag2field_.count(tag), "duplicated placeholder tag");
plan->tag2field_.emplace(tag, field_name);
......@@ -60,8 +56,6 @@ to_lower(const std::string& raw) {
return data;
}
template <class...>
constexpr std::false_type always_false{};
template <typename T>
std::unique_ptr<Expr>
ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) {
......@@ -69,19 +63,11 @@ ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Js
auto data_type = schema[field_name].get_data_type();
expr->data_type_ = data_type;
expr->field_id_ = field_name;
Assert(body.is_object());
for (auto& item : body.items()) {
auto op_name = to_lower(item.key());
AssertInfo(RangeExpr::mapping_.count(op_name), "op(" + op_name + ") not found");
auto op = RangeExpr::mapping_.at(op_name);
if constexpr (std::is_integral_v<T>) {
Assert(item.value().is_number_integer());
} else if constexpr (std::is_floating_point_v<T>) {
Assert(item.value().is_number());
} else {
static_assert(always_false<T>, "unsupported type");
}
T value = item.value();
expr->conditions_.emplace_back(op, value);
}
......@@ -97,10 +83,8 @@ ParseRangeNode(const Schema& schema, const Json& out_body) {
auto data_type = schema[field_name].get_data_type();
Assert(!field_is_vector(data_type));
switch (data_type) {
case DataType::BOOL: {
PanicInfo("bool is not supported in Range node");
// return ParseRangeNodeImpl<bool>(schema, field_name, body);
}
case DataType::BOOL:
return ParseRangeNodeImpl<bool>(schema, field_name, body);
case DataType::INT8:
return ParseRangeNodeImpl<int8_t>(schema, field_name, body);
case DataType::INT16:
......@@ -125,17 +109,16 @@ CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) {
nlohmann::json vec_pack;
std::optional<std::unique_ptr<Expr>> predicate;
auto& bool_dsl = dsl.at("bool");
auto& bool_dsl = dsl["bool"];
if (bool_dsl.contains("must")) {
auto& packs = bool_dsl.at("must");
Assert(packs.is_array());
auto& packs = bool_dsl["must"];
for (auto& pack : packs) {
if (pack.contains("vector")) {
auto& out_body = pack.at("vector");
auto& out_body = pack["vector"];
plan->plan_node_ = ParseVecNode(plan.get(), out_body);
} else if (pack.contains("range")) {
AssertInfo(!predicate, "unsupported complex DSL");
auto& out_body = pack.at("range");
auto& out_body = pack["range"];
predicate = ParseRangeNode(schema, out_body);
} else {
PanicInfo("unsupported node");
......@@ -143,7 +126,7 @@ CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) {
}
AssertInfo(plan->plan_node_, "vector node not found");
} else if (bool_dsl.contains("vector")) {
auto& out_body = bool_dsl.at("vector");
auto& out_body = bool_dsl["vector"];
plan->plan_node_ = ParseVecNode(plan.get(), out_body);
Assert(plan->plan_node_);
} else {
......
......@@ -17,7 +17,7 @@ add_library(milvus_segcore SHARED
)
target_link_libraries(milvus_segcore
tbb milvus_utils pthread knowhere log milvus_proto
tbb utils pthread knowhere log milvus_proto
dl backtrace
milvus_common
milvus_query
......
......@@ -17,8 +17,8 @@ set(UTILS_FILES
EasyAssert.cpp
)
add_library( milvus_utils STATIC ${UTILS_FILES} )
add_library( utils STATIC ${UTILS_FILES} )
target_link_libraries(milvus_utils
target_link_libraries(utils
libboost_filesystem.a
libboost_system.a)
......@@ -11,20 +11,11 @@
#include <iostream>
#include "EasyAssert.h"
// #define BOOST_STACKTRACE_USE_ADDR2LINE
#define BOOST_STACKTRACE_USE_BACKTRACE
#include <boost/stacktrace.hpp>
#include <sstream>
namespace milvus::impl {
std::string
EasyStackTrace() {
auto stack_info = boost::stacktrace::stacktrace();
std::ostringstream ss;
ss << stack_info;
return ss.str();
}
void
EasyAssertInfo(
bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info) {
......@@ -35,15 +26,11 @@ EasyAssertInfo(
if (!extra_info.empty()) {
info += " => " + std::string(extra_info);
}
throw std::runtime_error(info + "\n" + EasyStackTrace());
auto fuck = boost::stacktrace::stacktrace();
std::cout << fuck;
// std::string s = fuck;
// info += ;
throw std::runtime_error(info);
}
}
[[noreturn]] void
ThrowWithTrace(const std::exception& exception) {
auto err_msg = exception.what() + std::string("\n") + EasyStackTrace();
throw std::runtime_error(err_msg);
}
} // namespace milvus::impl
......@@ -11,7 +11,6 @@
#pragma once
#include <string_view>
#include <exception>
#include <stdio.h>
#include <stdlib.h>
......@@ -21,11 +20,7 @@ namespace milvus::impl {
void
EasyAssertInfo(
bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info);
[[noreturn]] void
ThrowWithTrace(const std::exception& exception);
} // namespace milvus::impl
}
#define AssertInfo(expr, info) milvus::impl::EasyAssertInfo(bool(expr), #expr, __FILE__, __LINE__, (info))
#define Assert(expr) AssertInfo((expr), "")
......
......@@ -11,13 +11,6 @@
#pragma once
#include "utils/EasyAssert.h"
#define JSON_ASSERT(expr) Assert((expr))
// TODO: dispatch error by type
#define JSON_THROW_USER(e) milvus::impl::ThrowWithTrace((e))
#include "nlohmann/json.hpp"
namespace milvus {
......
......@@ -24,6 +24,5 @@ target_link_libraries(all_tests
knowhere
log
pthread
milvus_utils
)
install (TARGETS all_tests DESTINATION unittest)
......@@ -107,83 +107,6 @@ TEST(Expr, Range) {
std::cout << out.dump(4);
}
TEST(Expr, InvalidRange) {
SUCCEED();
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
std::string dsl_string = R"(
{
"bool": {
"must": [
{
"range": {
"age": {
"GT": 1,
"LT": "100"
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
]
}
})";
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::INT32);
ASSERT_ANY_THROW(CreatePlan(*schema, dsl_string));
}
TEST(Expr, InvalidDSL) {
SUCCEED();
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
std::string dsl_string = R"(
{
"float": {
"must": [
{
"range": {
"age": {
"GT": 1,
"LT": 100
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
]
}
})";
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::INT32);
ASSERT_ANY_THROW(CreatePlan(*schema, dsl_string));
}
TEST(Expr, ShowExecutor) {
using namespace milvus::query;
using namespace milvus::segcore;
......
......@@ -4,26 +4,30 @@ import (
"context"
"sync"
"github.com/zilliztech/milvus-distributed/internal/allocator"
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
type Cache interface {
Hit(collectionName string) bool
Get(collectionName string) (*servicepb.CollectionDescription, error)
Sync(collectionName string) error
Update(collectionName string, desc *servicepb.CollectionDescription) error
Update(collectionName string) error
Remove(collectionName string) error
}
var globalMetaCache Cache
type SimpleMetaCache struct {
mu sync.RWMutex
metas map[string]*servicepb.CollectionDescription // collection name to schema
ctx context.Context
proxyInstance *Proxy
mu sync.RWMutex
proxyID UniqueID
metas map[string]*servicepb.CollectionDescription // collection name to schema
masterClient masterpb.MasterClient
reqIDAllocator *allocator.IDAllocator
tsoAllocator *allocator.TimestampAllocator
ctx context.Context
}
func (metaCache *SimpleMetaCache) Hit(collectionName string) bool {
......@@ -43,34 +47,58 @@ func (metaCache *SimpleMetaCache) Get(collectionName string) (*servicepb.Collect
return schema, nil
}
func (metaCache *SimpleMetaCache) Sync(collectionName string) error {
dct := &DescribeCollectionTask{
Condition: NewTaskCondition(metaCache.ctx),
DescribeCollectionRequest: internalpb.DescribeCollectionRequest{
MsgType: internalpb.MsgType_kDescribeCollection,
CollectionName: &servicepb.CollectionName{
CollectionName: collectionName,
},
func (metaCache *SimpleMetaCache) Update(collectionName string) error {
reqID, err := metaCache.reqIDAllocator.AllocOne()
if err != nil {
return err
}
ts, err := metaCache.tsoAllocator.AllocOne()
if err != nil {
return err
}
hasCollectionReq := &internalpb.HasCollectionRequest{
MsgType: internalpb.MsgType_kHasCollection,
ReqID: reqID,
Timestamp: ts,
ProxyID: metaCache.proxyID,
CollectionName: &servicepb.CollectionName{
CollectionName: collectionName,
},
masterClient: metaCache.proxyInstance.masterClient,
}
var cancel func()
dct.ctx, cancel = context.WithTimeout(metaCache.ctx, reqTimeoutInterval)
defer cancel()
err := metaCache.proxyInstance.sched.DdQueue.Enqueue(dct)
has, err := metaCache.masterClient.HasCollection(metaCache.ctx, hasCollectionReq)
if err != nil {
return err
}
if !has.Value {
return errors.New("collection " + collectionName + " not exists")
}
return dct.WaitToFinish()
}
reqID, err = metaCache.reqIDAllocator.AllocOne()
if err != nil {
return err
}
ts, err = metaCache.tsoAllocator.AllocOne()
if err != nil {
return err
}
req := &internalpb.DescribeCollectionRequest{
MsgType: internalpb.MsgType_kDescribeCollection,
ReqID: reqID,
Timestamp: ts,
ProxyID: metaCache.proxyID,
CollectionName: &servicepb.CollectionName{
CollectionName: collectionName,
},
}
resp, err := metaCache.masterClient.DescribeCollection(metaCache.ctx, req)
if err != nil {
return err
}
func (metaCache *SimpleMetaCache) Update(collectionName string, desc *servicepb.CollectionDescription) error {
metaCache.mu.Lock()
defer metaCache.mu.Unlock()
metaCache.metas[collectionName] = resp
metaCache.metas[collectionName] = desc
return nil
}
......@@ -87,14 +115,23 @@ func (metaCache *SimpleMetaCache) Remove(collectionName string) error {
return nil
}
func newSimpleMetaCache(ctx context.Context, proxyInstance *Proxy) *SimpleMetaCache {
func newSimpleMetaCache(ctx context.Context,
mCli masterpb.MasterClient,
idAllocator *allocator.IDAllocator,
tsoAllocator *allocator.TimestampAllocator) *SimpleMetaCache {
return &SimpleMetaCache{
metas: make(map[string]*servicepb.CollectionDescription),
proxyInstance: proxyInstance,
ctx: ctx,
metas: make(map[string]*servicepb.CollectionDescription),
masterClient: mCli,
reqIDAllocator: idAllocator,
tsoAllocator: tsoAllocator,
proxyID: Params.ProxyID(),
ctx: ctx,
}
}
func initGlobalMetaCache(ctx context.Context, proxyInstance *Proxy) {
globalMetaCache = newSimpleMetaCache(ctx, proxyInstance)
func initGlobalMetaCache(ctx context.Context,
mCli masterpb.MasterClient,
idAllocator *allocator.IDAllocator,
tsoAllocator *allocator.TimestampAllocator) {
globalMetaCache = newSimpleMetaCache(ctx, mCli, idAllocator, tsoAllocator)
}
......@@ -109,7 +109,7 @@ func (p *Proxy) startProxy() error {
if err != nil {
return err
}
initGlobalMetaCache(p.proxyLoopCtx, p)
initGlobalMetaCache(p.proxyLoopCtx, p.masterClient, p.idAllocator, p.tsoAllocator)
p.manipulationMsgStream.Start()
p.queryMsgStream.Start()
p.sched.Start()
......
......@@ -91,7 +91,7 @@ func (it *InsertTask) PreExecute() error {
func (it *InsertTask) Execute() error {
collectionName := it.BaseInsertTask.CollectionName
if !globalMetaCache.Hit(collectionName) {
err := globalMetaCache.Sync(collectionName)
err := globalMetaCache.Update(collectionName)
if err != nil {
return err
}
......@@ -352,7 +352,7 @@ func (qt *QueryTask) SetTs(ts Timestamp) {
func (qt *QueryTask) PreExecute() error {
collectionName := qt.query.CollectionName
if !globalMetaCache.Hit(collectionName) {
err := globalMetaCache.Sync(collectionName)
err := globalMetaCache.Update(collectionName)
if err != nil {
return err
}
......@@ -605,9 +605,14 @@ func (dct *DescribeCollectionTask) PreExecute() error {
}
func (dct *DescribeCollectionTask) Execute() error {
if !globalMetaCache.Hit(dct.CollectionName.CollectionName) {
err := globalMetaCache.Update(dct.CollectionName.CollectionName)
if err != nil {
return err
}
}
var err error
dct.result, err = dct.masterClient.DescribeCollection(dct.ctx, &dct.DescribeCollectionRequest)
globalMetaCache.Update(dct.CollectionName.CollectionName, dct.result)
dct.result, err = globalMetaCache.Get(dct.CollectionName.CollectionName)
return err
}
......
......@@ -29,7 +29,7 @@ func createPlan(col Collection, dsl string) (*Plan, error) {
if errorCode != 0 {
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return nil, errors.New("Create plan failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
return nil, errors.New("Insert failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
}
var newPlan = &Plan{cPlan: cPlan}
......@@ -60,7 +60,7 @@ func parserPlaceholderGroup(plan *Plan, placeHolderBlob []byte) (*PlaceholderGro
if errorCode != 0 {
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return nil, errors.New("Parser placeholder group failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
return nil, errors.New("Insert failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
}
var newPlaceholderGroup = &PlaceholderGroup{cPlaceholderGroup: cPlaceholderGroup}
......
......@@ -139,7 +139,7 @@ func (ss *searchService) receiveSearchMsg() {
err := ss.search(msg)
if err != nil {
log.Println(err)
err = ss.publishFailedSearchResult(msg, err.Error())
err = ss.publishFailedSearchResult(msg)
if err != nil {
log.Println("publish FailedSearchResult failed, error message: ", err)
}
......@@ -191,7 +191,7 @@ func (ss *searchService) doUnsolvedMsgSearch() {
err := ss.search(msg)
if err != nil {
log.Println(err)
err = ss.publishFailedSearchResult(msg, err.Error())
err = ss.publishFailedSearchResult(msg)
if err != nil {
log.Println("publish FailedSearchResult failed, error message: ", err)
}
......@@ -346,7 +346,7 @@ func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error {
return nil
}
func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg string) error {
func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg) error {
msgPack := msgstream.MsgPack{}
searchMsg, ok := msg.(*msgstream.SearchMsg)
if !ok {
......@@ -354,7 +354,7 @@ func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg s
}
var results = internalpb.SearchResult{
MsgType: internalpb.MsgType_kSearchResult,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, Reason: errMsg},
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR},
ReqID: searchMsg.ReqID,
ProxyID: searchMsg.ProxyID,
QueryNodeID: searchMsg.ProxyID,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册