提交 045fc3d8 编写于 作者: Z zhenshan.cao 提交者: yefu.chen

Fix segment fault

Signed-off-by: Nzhenshan.cao <zhenshan.cao@zilliz.com>
上级 0a16a9a6
......@@ -56,8 +56,8 @@ verifiers: cppcheck fmt lint ruleguard
# Builds various components locally.
build-go:
@echo "Building each component's binary to './'"
@echo "Building reader ..."
@mkdir -p $(INSTALL_PATH) && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/reader $(PWD)/cmd/reader/reader.go 1>/dev/null
@echo "Building query node ..."
@mkdir -p $(INSTALL_PATH) && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/querynode $(PWD)/cmd/querynode/query_node.go 1>/dev/null
@echo "Building master ..."
@mkdir -p $(INSTALL_PATH) && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/master $(PWD)/cmd/master/main.go 1>/dev/null
@echo "Building proxy ..."
......@@ -72,7 +72,7 @@ build-cpp-with-unittest:
# Runs the tests.
unittest: test-cpp test-go
#TODO: proxy master reader writer's unittest
#TODO: proxy master query node writer's unittest
test-go:
@echo "Running go unittests..."
@(env bash $(PWD)/scripts/run_go_unittest.sh)
......@@ -83,14 +83,14 @@ test-cpp: build-cpp-with-unittest
#TODO: build each component to docker
docker: verifiers
@echo "Building reader docker image '$(TAG)'"
@echo "Building query node docker image '$(TAG)'"
@echo "Building proxy docker image '$(TAG)'"
@echo "Building master docker image '$(TAG)'"
# Builds each component and installs it to $GOPATH/bin.
install: all
@echo "Installing binary to './bin'"
@mkdir -p $(GOPATH)/bin && cp -f $(PWD)/bin/reader $(GOPATH)/bin/reader
@mkdir -p $(GOPATH)/bin && cp -f $(PWD)/bin/querynode $(GOPATH)/bin/querynode
@mkdir -p $(GOPATH)/bin && cp -f $(PWD)/bin/master $(GOPATH)/bin/master
@mkdir -p $(GOPATH)/bin && cp -f $(PWD)/bin/proxy $(GOPATH)/bin/proxy
@mkdir -p $(LIBRARY_PATH) && cp -f $(PWD)/internal/core/output/lib/* $(LIBRARY_PATH)
......@@ -100,6 +100,6 @@ clean:
@echo "Cleaning up all the generated files"
@find . -name '*.test' | xargs rm -fv
@find . -name '*~' | xargs rm -fv
@rm -rvf reader
@rm -rvf querynode
@rm -rvf master
@rm -rvf proxy
......@@ -6,14 +6,14 @@ import (
"os/signal"
"syscall"
"github.com/zilliztech/milvus-distributed/internal/reader"
"github.com/zilliztech/milvus-distributed/internal/querynode"
)
func main() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
reader.Init()
querynode.Init()
sc := make(chan os.Signal, 1)
signal.Notify(sc,
......@@ -28,7 +28,7 @@ func main() {
cancel()
}()
reader.StartQueryNode(ctx)
querynode.StartQueryNode(ctx)
switch sig {
case syscall.SIGTERM:
......
......@@ -30,10 +30,10 @@ msgChannel:
queryNodeSubNamePrefix: "queryNode"
writeNodeSubNamePrefix: "writeNode"
# default channel range [0, 0]
# default channel range [0, 1)
channelRange:
insert: [0, 15]
delete: [0, 15]
k2s: [0, 15]
search: [0, 0]
insert: [0, 1]
delete: [0, 1]
k2s: [0, 1]
search: [0, 1]
searchResult: [0, 1]
\ No newline at end of file
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
common:
defaultPartitionTag: _default
......@@ -20,4 +20,6 @@ master:
minIDAssignCnt: 1024
maxIDAssignCnt: 16384
# old name: segmentExpireDuration: 2000
IDAssignExpiration: 2000 # ms
\ No newline at end of file
IDAssignExpiration: 2000 # ms
maxPartitionNum: 4096
\ No newline at end of file
......@@ -25,4 +25,7 @@ proxy:
pulsarBufSize: 1024 # pulsar chan buffer size
timeTick:
bufSize: 512
\ No newline at end of file
bufSize: 512
maxNameLength: 255
maxFieldNum: 64
\ No newline at end of file
......@@ -9,7 +9,7 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
reader:
queryNode:
stats:
publishInterval: 1000 # milliseconds
......@@ -19,10 +19,6 @@ reader:
maxParallelism: 1024
msgStream:
dm: # TODO: rm dm
#streamBufSize: 1024 # msgPack chan buffer size
recvBufSize: 1024 # msgPack chan buffer size
pulsarBufSize: 1024 # pulsar chan buffer size
insert:
#streamBufSize: 1024 # msgPack chan buffer size
recvBufSize: 1024 # msgPack chan buffer size
......
......@@ -11,9 +11,9 @@
nodeID: # will be deprecated after v0.2
proxyIDList: [1, 2]
queryNodeIDList: [3, 4]
writeNodeIDList: [5, 6]
proxyIDList: [1]
queryNodeIDList: [2]
writeNodeIDList: [3]
etcd:
address: localhost
......
......@@ -380,18 +380,13 @@ func (segMgr *SegmentManager) AssignSegmentID(segIDReq []*internalpb.SegIDReques
// "/msg_stream/insert"
message SysConfigRequest {
MsgType msg_type = 1;
int64 reqID = 2;
int64 proxyID = 3;
uint64 timestamp = 4;
repeated string keys = 5;
repeated string key_prefixes = 6;
repeated string keys = 1;
repeated string key_prefixes = 2;
}
message SysConfigResponse {
common.Status status = 1;
repeated string keys = 2;
repeated string values = 3;
repeated string keys = 1;
repeated string values = 2;
}
```
......@@ -399,11 +394,12 @@ message SysConfigResponse {
```go
type SysConfig struct {
kv *kv.EtcdKV
etcdKV *etcd
etcdPathPrefix string
}
func (conf *SysConfig) InitFromFile(filePath string) (error)
func (conf *SysConfig) GetByPrefix(keyPrefix string) (keys []string, values []string, err error)
func (conf *SysConfig) GetByPrefix(keyPrefix string) ([]string, error)
func (conf *SysConfig) Get(keys []string) ([]string, error)
```
......
......@@ -7,6 +7,7 @@ require (
github.com/apache/pulsar-client-go v0.1.1
github.com/aws/aws-sdk-go v1.30.8
github.com/coreos/etcd v3.3.25+incompatible // indirect
github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548
github.com/frankban/quicktest v1.10.2 // indirect
github.com/fsnotify/fsnotify v1.4.9 // indirect
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect
......
......@@ -65,6 +65,7 @@ github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f h1:lBNOc5arjvs8E5mO2tbp
github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 h1:iwZdTE0PVqJCos1vaoKsclOGD3ADKpshg3SRtYBbwso=
github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548/go.mod h1:e6NPNENfs9mPDVNRekM7lKScauxd5kXTr1Mfyig6TDM=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
......@@ -329,6 +330,7 @@ github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+Gx
github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFBS8=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446 h1:/NRJ5vAYoqz+7sG51ubIDHXeWO8DlTSrToPu6q11ziA=
github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446/go.mod h1:uYEyJGbgTkfkS4+E/PavXkNJcbFIpEtjt2B0KDQ5+9M=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
......
......@@ -57,7 +57,7 @@ type segRequest struct {
count uint32
colName string
partition string
segID UniqueID
segInfo map[UniqueID]uint32
channelID int32
}
......
package allocator
import (
"container/list"
"context"
"fmt"
"log"
"sort"
"time"
"github.com/cznic/mathutil"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/errors"
......@@ -18,7 +22,10 @@ const (
)
type assignInfo struct {
internalpb.SegIDAssignment
collName string
partitionTag string
channelID int32
segInfo map[UniqueID]uint32 // segmentID->count map
expireTime time.Time
lastInsertTime time.Time
}
......@@ -32,12 +39,16 @@ func (info *assignInfo) IsActive(now time.Time) bool {
}
func (info *assignInfo) IsEnough(count uint32) bool {
return info.Count >= count
total := uint32(0)
for _, count := range info.segInfo {
total += count
}
return total >= count
}
type SegIDAssigner struct {
Allocator
assignInfos map[string][]*assignInfo // collectionName -> [] *assignInfo
assignInfos map[string]*list.List // collectionName -> *list.List
segReqs []*internalpb.SegIDRequest
canDoReqs []request
}
......@@ -50,11 +61,8 @@ func NewSegIDAssigner(ctx context.Context, masterAddr string) (*SegIDAssigner, e
cancel: cancel,
masterAddress: masterAddr,
countPerRPC: SegCountPerRPC,
//toDoReqs: []request,
},
assignInfos: make(map[string][]*assignInfo),
//segReqs: make([]*internalpb.SegIDRequest, maxConcurrentRequests),
//canDoReqs: make([]request, maxConcurrentRequests),
assignInfos: make(map[string]*list.List),
}
sa.tChan = &ticker{
updateInterval: time.Second,
......@@ -67,16 +75,17 @@ func NewSegIDAssigner(ctx context.Context, masterAddr string) (*SegIDAssigner, e
func (sa *SegIDAssigner) collectExpired() {
now := time.Now()
for _, colInfos := range sa.assignInfos {
for _, assign := range colInfos {
for _, info := range sa.assignInfos {
for e := info.Front(); e != nil; e = e.Next() {
assign := e.Value.(*assignInfo)
if !assign.IsActive(now) || !assign.IsExpired(now) {
continue
}
sa.segReqs = append(sa.segReqs, &internalpb.SegIDRequest{
ChannelID: assign.ChannelID,
ChannelID: assign.channelID,
Count: sa.countPerRPC,
CollName: assign.CollName,
PartitionTag: assign.PartitionTag,
CollName: assign.collName,
PartitionTag: assign.partitionTag,
})
}
}
......@@ -88,7 +97,6 @@ func (sa *SegIDAssigner) checkToDoReqs() {
}
now := time.Now()
for _, req := range sa.toDoReqs {
fmt.Println("DDDDD????", req)
segRequest := req.(*segRequest)
assign := sa.getAssign(segRequest.colName, segRequest.partition, segRequest.channelID)
if assign == nil || assign.IsExpired(now) || !assign.IsEnough(segRequest.count) {
......@@ -102,13 +110,36 @@ func (sa *SegIDAssigner) checkToDoReqs() {
}
}
// removeSegInfo drops every cached assignment of collection colName whose
// partition tag and channel ID match the given values. It is a no-op when the
// collection has no cached assignments.
func (sa *SegIDAssigner) removeSegInfo(colName, partition string, channelID int32) {
	assignInfos, ok := sa.assignInfos[colName]
	if !ok {
		return
	}

	// list.Remove clears the removed element's links, so calling e.Next()
	// after a removal returns nil and would end the traversal early. Capture
	// the successor before possibly removing the current element.
	for e := assignInfos.Front(); e != nil; {
		next := e.Next()
		assign := e.Value.(*assignInfo)
		if assign.partitionTag == partition && assign.channelID == channelID {
			assignInfos.Remove(e)
		}
		e = next
	}
}
func (sa *SegIDAssigner) getAssign(colName, partition string, channelID int32) *assignInfo {
colInfos, ok := sa.assignInfos[colName]
assignInfos, ok := sa.assignInfos[colName]
if !ok {
return nil
}
for _, info := range colInfos {
if info.PartitionTag != partition || info.ChannelID != channelID {
for e := assignInfos.Front(); e != nil; e = e.Next() {
info := e.Value.(*assignInfo)
if info.partitionTag != partition || info.channelID != channelID {
continue
}
return info
......@@ -151,19 +182,29 @@ func (sa *SegIDAssigner) syncSegments() {
now := time.Now()
expiredTime := now.Add(time.Millisecond * time.Duration(resp.ExpireDuration))
for _, info := range resp.PerChannelAssignment {
sa.removeSegInfo(info.CollName, info.PartitionTag, info.ChannelID)
}
for _, info := range resp.PerChannelAssignment {
assign := sa.getAssign(info.CollName, info.PartitionTag, info.ChannelID)
if assign == nil {
colInfos := sa.assignInfos[info.CollName]
colInfos, ok := sa.assignInfos[info.CollName]
if !ok {
colInfos = list.New()
}
segInfo := make(map[UniqueID]uint32)
segInfo[info.SegID] = info.Count
newAssign := &assignInfo{
SegIDAssignment: *info,
expireTime: expiredTime,
lastInsertTime: now,
collName: info.CollName,
partitionTag: info.PartitionTag,
channelID: info.ChannelID,
segInfo: segInfo,
}
colInfos = append(colInfos, newAssign)
colInfos.PushBack(newAssign)
sa.assignInfos[info.CollName] = colInfos
} else {
assign.SegIDAssignment = *info
assign.segInfo[info.SegID] = info.Count
assign.expireTime = expiredTime
assign.lastInsertTime = now
}
......@@ -181,13 +222,38 @@ func (sa *SegIDAssigner) processFunc(req request) error {
if assign == nil {
return errors.New("Failed to GetSegmentID")
}
segRequest.segID = assign.SegID
assign.Count -= segRequest.count
keys := make([]UniqueID, len(assign.segInfo))
i := 0
for key := range assign.segInfo {
keys[i] = key
i++
}
reqCount := segRequest.count
resultSegInfo := make(map[UniqueID]uint32)
sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
for _, key := range keys {
if reqCount <= 0 {
break
}
cur := assign.segInfo[key]
minCnt := mathutil.MinUint32(cur, reqCount)
resultSegInfo[key] = minCnt
cur -= minCnt
reqCount -= minCnt
if cur <= 0 {
delete(assign.segInfo, key)
} else {
assign.segInfo[key] = cur
}
}
segRequest.segInfo = resultSegInfo
fmt.Println("process segmentID")
return nil
}
func (sa *SegIDAssigner) GetSegmentID(colName, partition string, channelID int32, count uint32) (UniqueID, error) {
func (sa *SegIDAssigner) GetSegmentID(colName, partition string, channelID int32, count uint32) (map[UniqueID]uint32, error) {
req := &segRequest{
baseRequest: baseRequest{done: make(chan error), valid: false},
colName: colName,
......@@ -199,7 +265,7 @@ func (sa *SegIDAssigner) GetSegmentID(colName, partition string, channelID int32
req.Wait()
if !req.IsValid() {
return 0, errors.New("GetSegmentID Failed")
return nil, errors.New("GetSegmentID Failed")
}
return req.segID, nil
return req.segInfo, nil
}
......@@ -13,7 +13,7 @@ import (
type Timestamp = typeutil.Timestamp
const (
tsCountPerRPC = 2 << 18 * 10
tsCountPerRPC = 2 << 15
)
type TimestampAllocator struct {
......@@ -37,6 +37,7 @@ func NewTimestampAllocator(ctx context.Context, masterAddr string) (*TimestampAl
}
a.Allocator.syncFunc = a.syncTs
a.Allocator.processFunc = a.processFunc
a.Allocator.checkFunc = a.checkFunc
return a, nil
}
......
......@@ -65,7 +65,8 @@ StructuredIndexFlat<T>::NotIn(const size_t n, const T* values) {
if (!is_built_) {
build();
}
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size(), true);
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
bitset->set();
for (size_t i = 0; i < n; ++i) {
for (const auto& index : data_) {
if (index->a_ == *(values + i)) {
......
......@@ -120,7 +120,8 @@ StructuredIndexSort<T>::NotIn(const size_t n, const T* values) {
if (!is_built_) {
build();
}
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size(), true);
TargetBitmapPtr bitset = std::make_unique<TargetBitmap>(data_.size());
bitset->set();
for (size_t i = 0; i < n; ++i) {
auto lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
auto ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(*(values + i)));
......
......@@ -130,13 +130,7 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType {
}
case OpType::NotEqual: {
auto index_func = [val](Index* index) {
// Note: index->NotIn() is buggy, investigating
// this is a workaround
auto res = index->In(1, &val);
*res = ~std::move(*res);
return res;
};
auto index_func = [val](Index* index) { return index->NotIn(1, &val); };
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x != val); });
}
......
......@@ -13,9 +13,9 @@
extern "C" {
#endif
#include "segcore/collection_c.h"
#include <stdbool.h>
#include <stdint.h>
#include "segcore/collection_c.h"
typedef void* CPlan;
typedef void* CPlaceholderGroup;
......
......@@ -9,11 +9,183 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <algorithm>
#include <cstring>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <utils/EasyAssert.h>
#include "segcore/reduce_c.h"
#include "segcore/Reduce.h"
#include "utils/Types.h"
#include "pb/service_msg.pb.h"
using SearchResult = milvus::engine::QueryResult;
// C entry point that forwards its arguments unchanged to
// milvus::segcore::merge_into and returns the numeric code of the resulting
// status.
int
MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, float* new_distances, int64_t* new_uids) {
    return milvus::segcore::merge_into(num_queries, topk, distances, uids, new_distances, new_uids).code();
}
// Serialized hits of one placeholder group: hits_[i] is the blob of the i-th
// query and blob_length_[i] is the byte length recorded for that blob.
struct MarshaledHitsPeerGroup {
    std::vector<std::string> hits_;
    std::vector<int64_t> blob_length_;
};

// Container with one MarshaledHitsPeerGroup per placeholder group.
struct MarshaledHits {
    // Creates `num_group` empty groups.
    explicit MarshaledHits(int64_t num_group) : marshaled_hits_(num_group) {
    }

    // Number of groups held by this container.
    int
    get_num_group() {
        return static_cast<int>(marshaled_hits_.size());
    }

    std::vector<MarshaledHitsPeerGroup> marshaled_hits_;
};
// Destroys the MarshaledHits object behind the opaque handle.
void
DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
    delete static_cast<MarshaledHits*>(c_marshaled_hits);
}
// One candidate hit used while merging per-segment results: the entity id,
// its distance, and the index of the segment it was taken from.
struct SearchResultPair {
    uint64_t id_;
    float distance_;
    int64_t segment_id_;

    SearchResultPair(uint64_t id, float distance, int64_t segment_id) {
        id_ = id;
        distance_ = distance;
        segment_id_ = segment_id;
    }

    // Orders candidates by distance only; smaller distance sorts first.
    bool
    operator<(const SearchResultPair& pair) const {
        return pair.distance_ > distance_;
    }
};
// Merges the per-segment result lists of a single query into `final_result`.
//
// search_results: one result set per segment; for this query each set is read
//                 starting at index `query_offset`.
// final_result:   receives `topk` (id, distance) entries appended in ascending
//                 distance order.
// query_offset:   index of the query's first entry inside every segment's
//                 result_ids_ / result_distances_.
// topk:           number of entries to emit; at most `topk` entries are read
//                 from any one segment.
//
// Nothing is appended when there are no segments or `topk` is not positive.
void
GetResultData(std::vector<SearchResult*>& search_results,
              SearchResult& final_result,
              int64_t query_offset,
              int64_t topk) {
    if (search_results.empty() || topk <= 0) {
        return;
    }

    const int64_t num_segments = static_cast<int64_t>(search_results.size());

    // cursors[s] is the index of the entry of segment s currently in `heads`.
    std::vector<int64_t> cursors(num_segments, query_offset);
    std::vector<SearchResultPair> heads;
    heads.reserve(num_segments);
    for (int64_t seg = 0; seg < num_segments; ++seg) {
        heads.emplace_back(search_results[seg]->result_ids_[query_offset],
                           search_results[seg]->result_distances_[query_offset], seg);
    }

    for (int64_t emitted = 0; emitted < topk; ++emitted) {
        // Only the smallest head is needed, so a linear scan replaces a full sort.
        auto best = std::min_element(heads.begin(), heads.end());
        final_result.result_ids_.push_back(best->id_);
        final_result.result_distances_.push_back(best->distance_);

        // Do not advance past the last emitted entry: that would read beyond
        // this query's slice of the segment result.
        if (emitted + 1 == topk) {
            break;
        }
        const auto seg = best->segment_id_;
        const auto next = ++cursors[seg];
        *best = SearchResultPair(search_results[seg]->result_ids_[next],
                                 search_results[seg]->result_distances_[next], seg);
    }
}
// Merges the query results of `num_segments` segments into one newly
// allocated result and returns it as an opaque handle.
//
// topK_ and num_queries_ are taken from the first segment's result; every
// segment result is read at the same offsets, one `topk`-sized slice per query.
//
// Returns nullptr when `query_results` is null or `num_segments` is not
// positive (the first element would otherwise be dereferenced blindly).
// Ownership of the returned object passes to the caller.
CQueryResult
ReduceQueryResults(CQueryResult* query_results, int64_t num_segments) {
    if (query_results == nullptr || num_segments <= 0) {
        return nullptr;
    }

    std::vector<SearchResult*> search_results;
    search_results.reserve(num_segments);
    for (int64_t i = 0; i < num_segments; ++i) {
        search_results.push_back((SearchResult*)query_results[i]);
    }

    auto topk = search_results[0]->topK_;
    auto num_queries = search_results[0]->num_queries_;

    // NOTE(review): topK_ / num_queries_ are not copied onto the merged
    // result here - confirm whether downstream readers expect them to be set.
    auto final_result = std::make_unique<SearchResult>();

    int64_t query_offset = 0;
    for (decltype(num_queries) j = 0; j < num_queries; ++j) {
        GetResultData(search_results, *final_result, query_offset, topk);
        query_offset += topk;
    }

    return (CQueryResult)final_result.release();
}
// Splits a flat query result into per-group, per-query serialized Hits blobs.
//
// c_query_result:       result whose result_ids_ / result_distances_ hold
//                       `topk` consecutive entries per query.
// c_plan:               plan from which `topk` is obtained via GetTopK.
// c_placeholder_groups: `num_groups` placeholder groups; GetNumOfQueries gives
//                       how many consecutive queries belong to each group.
//
// Returns a newly allocated MarshaledHits handle; release it with
// DeleteMarshaledHits.
CMarshaledHits
ReorganizeQueryResults(CQueryResult c_query_result,
                       CPlan c_plan,
                       CPlaceholderGroup* c_placeholder_groups,
                       int64_t num_groups) {
    auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
    auto search_result = (milvus::engine::QueryResult*)c_query_result;
    auto& result_ids = search_result->result_ids_;
    auto& result_distances = search_result->result_distances_;

    // 64-bit throughout: offsets are products of topk and the running query
    // count and must not be narrowed to int.
    const int64_t topk = GetTopK(c_plan);
    int64_t queries_offset = 0;
    for (int64_t i = 0; i < num_groups; i++) {
        const int64_t num_queries = GetNumOfQueries(c_placeholder_groups[i]);
        MarshaledHitsPeerGroup& hits_peer_group = marshaledHits->marshaled_hits_[i];
        for (int64_t j = 0; j < num_queries; j++) {
            const int64_t begin = topk * queries_offset++;
            milvus::proto::service::Hits hits;
            for (int64_t k = begin; k < begin + topk; k++) {
                hits.add_ids(result_ids[k]);
                hits.add_scores(result_distances[k]);
            }
            auto blob = hits.SerializeAsString();
            hits_peer_group.hits_.push_back(blob);
            hits_peer_group.blob_length_.push_back(blob.size());
        }
    }

    return (CMarshaledHits)marshaledHits.release();
}
// Returns the sum of the recorded blob lengths over every group of the
// marshaled hits, i.e. the number of bytes GetHitsBlob copies.
int64_t
GetHitsBlobSize(CMarshaledHits c_marshaled_hits) {
    auto marshaled_hits = static_cast<MarshaledHits*>(c_marshaled_hits);
    int64_t total_size = 0;
    for (auto& group : marshaled_hits->marshaled_hits_) {
        for (auto blob_len : group.blob_length_) {
            total_size += blob_len;
        }
    }
    return total_size;
}
// Copies every serialized blob, group by group and query by query, back to
// back into the buffer `hits`.
//
// The buffer must be able to hold at least GetHitsBlobSize(c_marshaled_hits)
// bytes. The parameter is declared `const void*` in the public header, but it
// is the destination of the copy, so constness is cast away here.
void
GetHitsBlob(CMarshaledHits c_marshaled_hits, const void* hits) {
    auto dest = const_cast<char*>(static_cast<const char*>(hits));
    auto marshaled_hits = (MarshaledHits*)c_marshaled_hits;

    // 64-bit running offset so that payloads larger than 2 GiB do not overflow.
    int64_t offset = 0;
    for (auto& group : marshaled_hits->marshaled_hits_) {
        const size_t num_queries = group.hits_.size();
        for (size_t j = 0; j < num_queries; j++) {
            auto blob_size = group.blob_length_[j];
            memcpy(dest + offset, group.hits_[j].data(), blob_size);
            offset += blob_size;
        }
    }
}
// Returns how many serialized blobs (one per query) are stored in the group
// at `group_index`.
int64_t
GetNumQueriesPeerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index) {
    auto marshaled_hits = static_cast<MarshaledHits*>(c_marshaled_hits);
    return marshaled_hits->marshaled_hits_[group_index].hits_.size();
}
// Writes the recorded blob length of every query of the group at
// `group_index` into hit_size_peer_query[0..n), where n is the number of
// lengths stored for that group.
void
GetHitSizePeerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query) {
    auto marshaled_hits = static_cast<MarshaledHits*>(c_marshaled_hits);
    int64_t* out = hit_size_peer_query;
    for (auto blob_len : marshaled_hits->marshaled_hits_[group_index].blob_length_) {
        *out++ = blob_len;
    }
}
......@@ -15,10 +15,37 @@ extern "C" {
#include <stdbool.h>
#include <stdint.h>
#include "segcore/segment_c.h"
// Opaque handle to the marshaled (serialized) hits produced by
// ReorganizeQueryResults.
typedef void* CMarshaledHits;
// Destroys the object behind a CMarshaledHits handle.
void
DeleteMarshaledHits(CMarshaledHits c_marshaled_hits);
// Forwards to milvus::segcore::merge_into and returns the code of its status.
int
MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, float* new_distances, int64_t* new_uids);
// Merges the results of `num_segments` segments into one newly allocated
// result. topk and the number of queries are read from query_results[0].
// NOTE(review): presumably the returned handle is released with
// DeleteQueryResult - confirm against the caller.
CQueryResult
ReduceQueryResults(CQueryResult* query_results, int64_t num_segments);
// Serializes the hits of every query into one blob per query, grouped by
// placeholder group, and returns them as a new CMarshaledHits handle.
CMarshaledHits
ReorganizeQueryResults(CQueryResult query_result,
CPlan c_plan,
CPlaceholderGroup* c_placeholder_groups,
int64_t num_groups);
// Total number of bytes of all blobs held by the handle.
int64_t
GetHitsBlobSize(CMarshaledHits c_marshaled_hits);
// Copies all blobs back to back into `hits`, which is written to despite the
// const qualifier and must hold at least GetHitsBlobSize() bytes.
void
GetHitsBlob(CMarshaledHits c_marshaled_hits, const void* hits);
// Number of per-query blobs stored for the group at `group_index`.
int64_t
GetNumQueriesPeerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index);
// Writes the byte length of each blob of the group at `group_index` into
// consecutive elements of `hit_size_peer_query`.
void
GetHitSizePeerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query);
#ifdef __cplusplus
}
#endif
......@@ -18,6 +18,7 @@
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
#include <knowhere/index/vector_index/VecIndexFactory.h>
#include <cstdint>
#include <boost/concept_check.hpp>
CSegmentBase
NewSegment(CCollection collection, uint64_t segment_id) {
......@@ -39,9 +40,15 @@ DeleteSegment(CSegmentBase segment) {
delete s;
}
// Destroys the milvus::segcore::QueryResult behind the opaque handle.
void
DeleteQueryResult(CQueryResult query_result) {
    delete static_cast<milvus::segcore::QueryResult*>(query_result);
}
//////////////////////////////////////////////////////////////////
int
CStatus
Insert(CSegmentBase c_segment,
int64_t reserved_offset,
int64_t size,
......@@ -57,11 +64,22 @@ Insert(CSegmentBase c_segment,
dataChunk.sizeof_per_row = sizeof_per_row;
dataChunk.count = count;
auto res = segment->Insert(reserved_offset, size, row_ids, timestamps, dataChunk);
try {
auto res = segment->Insert(reserved_offset, size, row_ids, timestamps, dataChunk);
auto status = CStatus();
status.error_code = Success;
status.error_msg = "";
return status;
} catch (std::runtime_error& e) {
auto status = CStatus();
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
return status;
}
// TODO: delete print
// std::cout << "do segment insert, sizeof_per_row = " << sizeof_per_row << std::endl;
return res.code();
}
int64_t
......@@ -73,13 +91,24 @@ PreInsert(CSegmentBase c_segment, int64_t size) {
return segment->PreInsert(size);
}
int
CStatus
Delete(
CSegmentBase c_segment, int64_t reserved_offset, int64_t size, const int64_t* row_ids, const uint64_t* timestamps) {
auto segment = (milvus::segcore::SegmentBase*)c_segment;
auto res = segment->Delete(reserved_offset, size, row_ids, timestamps);
return res.code();
try {
auto res = segment->Delete(reserved_offset, size, row_ids, timestamps);
auto status = CStatus();
status.error_code = Success;
status.error_msg = "";
return status;
} catch (std::runtime_error& e) {
auto status = CStatus();
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
return status;
}
}
int64_t
......@@ -91,30 +120,39 @@ PreDelete(CSegmentBase c_segment, int64_t size) {
return segment->PreDelete(size);
}
int
CStatus
Search(CSegmentBase c_segment,
CPlan c_plan,
CPlaceholderGroup* c_placeholder_groups,
uint64_t* timestamps,
int num_groups,
int64_t* result_ids,
float* result_distances) {
CQueryResult* result) {
auto segment = (milvus::segcore::SegmentBase*)c_segment;
auto plan = (milvus::query::Plan*)c_plan;
std::vector<const milvus::query::PlaceholderGroup*> placeholder_groups;
for (int i = 0; i < num_groups; ++i) {
placeholder_groups.push_back((const milvus::query::PlaceholderGroup*)c_placeholder_groups[i]);
}
milvus::segcore::QueryResult query_result;
auto res = segment->Search(plan, placeholder_groups.data(), timestamps, num_groups, query_result);
auto query_result = std::make_unique<milvus::segcore::QueryResult>();
auto status = CStatus();
try {
auto res = segment->Search(plan, placeholder_groups.data(), timestamps, num_groups, *query_result);
status.error_code = Success;
status.error_msg = "";
} catch (std::runtime_error& e) {
status.error_code = UnexpectedException;
status.error_msg = strdup(e.what());
}
*result = query_result.release();
// result_ids and result_distances have been allocated memory in goLang,
// so we don't need to malloc here.
memcpy(result_ids, query_result.result_ids_.data(), query_result.get_row_count() * sizeof(int64_t));
memcpy(result_distances, query_result.result_distances_.data(), query_result.get_row_count() * sizeof(float));
// memcpy(result_ids, query_result.result_ids_.data(), query_result.get_row_count() * sizeof(long int));
// memcpy(result_distances, query_result.result_distances_.data(), query_result.get_row_count() * sizeof(float));
return res.code();
return status;
}
//////////////////////////////////////////////////////////////////
......
......@@ -14,11 +14,24 @@ extern "C" {
#endif
#include <stdbool.h>
#include <stdlib.h>
#include <stdint.h>
#include "segcore/collection_c.h"
#include "segcore/plan_c.h"
#include <stdint.h>
typedef void* CSegmentBase;
typedef void* CQueryResult;
// Result codes stored in CStatus.error_code.
enum ErrorCode {
Success = 0,
UnexpectedException = 1,
};
// Status returned by the segment C API (Insert, Delete, Search).
// error_code holds an ErrorCode value; error_msg holds the accompanying text.
// NOTE(review): the visible implementations set error_msg to a string literal
// on success and to a strdup()'ed copy on failure, so only the failure message
// is heap-allocated - confirm who is expected to free it.
typedef struct CStatus {
int error_code;
const char* error_msg;
} CStatus;
CSegmentBase
NewSegment(CCollection collection, uint64_t segment_id);
......@@ -26,9 +39,12 @@ NewSegment(CCollection collection, uint64_t segment_id);
void
DeleteSegment(CSegmentBase segment);
void
DeleteQueryResult(CQueryResult query_result);
//////////////////////////////////////////////////////////////////
int
CStatus
Insert(CSegmentBase c_segment,
int64_t reserved_offset,
int64_t size,
......@@ -41,21 +57,20 @@ Insert(CSegmentBase c_segment,
int64_t
PreInsert(CSegmentBase c_segment, int64_t size);
int
CStatus
Delete(
CSegmentBase c_segment, int64_t reserved_offset, int64_t size, const int64_t* row_ids, const uint64_t* timestamps);
int64_t
PreDelete(CSegmentBase c_segment, int64_t size);
int
CStatus
Search(CSegmentBase c_segment,
CPlan plan,
CPlaceholderGroup* placeholder_groups,
uint64_t* timestamps,
int num_groups,
int64_t* result_ids,
float* result_distances);
CQueryResult* result);
//////////////////////////////////////////////////////////////////
......
......@@ -15,7 +15,6 @@
#include <gtest/gtest.h>
#include "segcore/collection_c.h"
#include "segcore/segment_c.h"
#include "pb/service_msg.pb.h"
#include "segcore/reduce_c.h"
......@@ -65,7 +64,7 @@ TEST(CApiTest, InsertTest) {
auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
assert(res == 0);
assert(res.error_code == Success);
DeleteCollection(collection);
DeleteSegment(segment);
......@@ -82,7 +81,7 @@ TEST(CApiTest, DeleteTest) {
auto offset = PreDelete(segment, 3);
auto del_res = Delete(segment, offset, 3, delete_row_ids, delete_timestamps);
assert(del_res == 0);
assert(del_res.error_code == Success);
DeleteCollection(collection);
DeleteSegment(segment);
......@@ -116,7 +115,7 @@ TEST(CApiTest, SearchTest) {
auto offset = PreInsert(segment, N);
auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
assert(ins_res == 0);
assert(ins_res.error_code == Success);
const char* dsl_string = R"(
{
......@@ -159,105 +158,100 @@ TEST(CApiTest, SearchTest) {
timestamps.clear();
timestamps.push_back(1);
long result_ids[100];
float result_distances[100];
auto sea_res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, result_ids, result_distances);
assert(sea_res == 0);
CQueryResult search_result;
auto res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, &search_result);
assert(res.error_code == Success);
DeletePlan(plan);
DeletePlaceholderGroup(placeholderGroup);
DeleteQueryResult(search_result);
DeleteCollection(collection);
DeleteSegment(segment);
}
TEST(CApiTest, BuildIndexTest) {
auto schema_tmp_conf = "";
auto collection = NewCollection(schema_tmp_conf);
auto segment = NewSegment(collection, 0);
std::vector<char> raw_data;
std::vector<uint64_t> timestamps;
std::vector<int64_t> uids;
int N = 10000;
std::default_random_engine e(67);
for (int i = 0; i < N; ++i) {
uids.push_back(100000 + i);
timestamps.push_back(0);
// append vec
float vec[16];
for (auto& x : vec) {
x = e() % 2000 * 0.001 - 1.0;
}
raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec));
int age = e() % 100;
raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age));
}
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
auto offset = PreInsert(segment, N);
auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
assert(ins_res == 0);
// TODO: add index ptr
Close(segment);
// BuildIndex(collection, segment);
const char* dsl_string = R"(
{
"bool": {
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
})";
namespace ser = milvus::proto::service;
int num_queries = 10;
int dim = 16;
std::normal_distribution<double> dis(0, 1);
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::VECTOR_FLOAT);
for (int i = 0; i < num_queries; ++i) {
std::vector<float> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(dis(e));
}
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
value->add_values(vec.data(), vec.size() * sizeof(float));
}
auto blob = raw_group.SerializeAsString();
auto plan = CreatePlan(collection, dsl_string);
auto placeholderGroup = ParsePlaceholderGroup(plan, blob.data(), blob.length());
std::vector<CPlaceholderGroup> placeholderGroups;
placeholderGroups.push_back(placeholderGroup);
timestamps.clear();
timestamps.push_back(1);
long result_ids[100];
float result_distances[100];
auto sea_res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, result_ids, result_distances);
assert(sea_res == 0);
DeletePlan(plan);
DeletePlaceholderGroup(placeholderGroup);
DeleteCollection(collection);
DeleteSegment(segment);
}
// TEST(CApiTest, BuildIndexTest) {
// auto schema_tmp_conf = "";
// auto collection = NewCollection(schema_tmp_conf);
// auto segment = NewSegment(collection, 0);
//
// std::vector<char> raw_data;
// std::vector<uint64_t> timestamps;
// std::vector<int64_t> uids;
// int N = 10000;
// std::default_random_engine e(67);
// for (int i = 0; i < N; ++i) {
// uids.push_back(100000 + i);
// timestamps.push_back(0);
// // append vec
// float vec[16];
// for (auto& x : vec) {
// x = e() % 2000 * 0.001 - 1.0;
// }
// raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec));
// int age = e() % 100;
// raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age));
// }
//
// auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
//
// auto offset = PreInsert(segment, N);
//
// auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
// assert(ins_res == 0);
//
// // TODO: add index ptr
// Close(segment);
// BuildIndex(collection, segment);
//
// const char* dsl_string = R"(
// {
// "bool": {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 10
// }
// }
// }
// })";
//
// namespace ser = milvus::proto::service;
// int num_queries = 10;
// int dim = 16;
// std::normal_distribution<double> dis(0, 1);
// ser::PlaceholderGroup raw_group;
// auto value = raw_group.add_placeholders();
// value->set_tag("$0");
// value->set_type(ser::PlaceholderType::VECTOR_FLOAT);
// for (int i = 0; i < num_queries; ++i) {
// std::vector<float> vec;
// for (int d = 0; d < dim; ++d) {
// vec.push_back(dis(e));
// }
// // std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
// value->add_values(vec.data(), vec.size() * sizeof(float));
// }
// auto blob = raw_group.SerializeAsString();
//
// auto plan = CreatePlan(collection, dsl_string);
// auto placeholderGroup = ParsePlaceholderGroup(plan, blob.data(), blob.length());
// std::vector<CPlaceholderGroup> placeholderGroups;
// placeholderGroups.push_back(placeholderGroup);
// timestamps.clear();
// timestamps.push_back(1);
//
// auto search_res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1);
//
// DeletePlan(plan);
// DeletePlaceholderGroup(placeholderGroup);
// DeleteQueryResult(search_res);
// DeleteCollection(collection);
// DeleteSegment(segment);
//}
TEST(CApiTest, IsOpenedTest) {
auto schema_tmp_conf = "";
......@@ -315,7 +309,7 @@ TEST(CApiTest, GetMemoryUsageInBytesTest) {
auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
assert(res == 0);
assert(res.error_code == Success);
auto memory_usage_size = GetMemoryUsageInBytes(segment);
......@@ -482,7 +476,7 @@ TEST(CApiTest, GetDeletedCountTest) {
auto offset = PreDelete(segment, 3);
auto del_res = Delete(segment, offset, 3, delete_row_ids, delete_timestamps);
assert(del_res == 0);
assert(del_res.error_code == Success);
// TODO: assert(deleted_count == len(delete_row_ids))
auto deleted_count = GetDeletedCount(segment);
......@@ -502,7 +496,7 @@ TEST(CApiTest, GetRowCountTest) {
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
auto offset = PreInsert(segment, N);
auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
assert(res == 0);
assert(res.error_code == Success);
auto row_count = GetRowCount(segment);
assert(row_count == N);
......@@ -552,3 +546,109 @@ TEST(CApiTest, MergeInto) {
ASSERT_EQ(uids[1], 1);
ASSERT_EQ(distance[1], 5);
}
// End-to-end check of the reduce path of the segment C API: insert random rows,
// run the same search twice, reduce the two results into one, reorganize the
// reduced result into marshaled hits, and verify the hits can be read back.
// NOTE(review): the checks use plain assert(), which is compiled out when
// NDEBUG is defined -- confirm the test build keeps assertions enabled.
TEST(CApiTest, Reduce) {
// An empty schema config is passed to NewCollection.
auto schema_tmp_conf = "";
auto collection = NewCollection(schema_tmp_conf);
auto segment = NewSegment(collection, 0);
// Row-major payload: every row is 16 floats followed by one int ("age").
std::vector<char> raw_data;
std::vector<uint64_t> timestamps;
std::vector<int64_t> uids;
int N = 10000;
// Fixed seed so the generated data is reproducible between runs.
std::default_random_engine e(67);
for (int i = 0; i < N; ++i) {
uids.push_back(100000 + i);
timestamps.push_back(0);
// append vec
float vec[16];
for (auto& x : vec) {
// e() % 2000 is in [0, 1999], so x lands in [-1.0, 0.999].
x = e() % 2000 * 0.001 - 1.0;
}
raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec));
int age = e() % 100;
raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age));
}
// Size in bytes of one serialized row (one int plus 16 floats).
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
auto offset = PreInsert(segment, N);
auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
assert(ins_res.error_code == Success);
// Query DSL: vector search on field "fakevec" with metric L2, topk 10, and the
// query vectors supplied through the placeholder tagged "$0".
const char* dsl_string = R"(
{
"bool": {
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
})";
namespace ser = milvus::proto::service;
int num_queries = 10;
int dim = 16;
std::normal_distribution<double> dis(0, 1);
// Build a placeholder group holding num_queries float vectors of length dim
// under the tag "$0" referenced by the DSL above.
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::VECTOR_FLOAT);
for (int i = 0; i < num_queries; ++i) {
std::vector<float> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(dis(e));
}
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
// Each query vector is stored as its raw float bytes.
value->add_values(vec.data(), vec.size() * sizeof(float));
}
auto blob = raw_group.SerializeAsString();
auto plan = CreatePlan(collection, dsl_string);
auto placeholderGroup = ParsePlaceholderGroup(plan, blob.data(), blob.length());
std::vector<CPlaceholderGroup> placeholderGroups;
placeholderGroups.push_back(placeholderGroup);
// Reuse the timestamps vector: a single search timestamp (1) for the one
// placeholder group.
timestamps.clear();
timestamps.push_back(1);
std::vector<CQueryResult> results;
CQueryResult res1;
CQueryResult res2;
// Run the identical search twice to obtain two results to reduce.
auto res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, &res1);
assert(res.error_code == Success);
res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, &res2);
assert(res.error_code == Success);
results.push_back(res1);
results.push_back(res2);
// Reduce the two results into one, then reorganize it into marshaled hits.
auto reduced_search_result = ReduceQueryResults(results.data(), 2);
auto reorganize_search_result = ReorganizeQueryResults(reduced_search_result, plan, placeholderGroups.data(), 1);
// Copy the hits blob out into a buffer sized by GetHitsBlobSize.
auto hits_blob_size = GetHitsBlobSize(reorganize_search_result);
assert(hits_blob_size > 0);
std::vector<char> hits_blob;
hits_blob.resize(hits_blob_size);
GetHitsBlob(reorganize_search_result, hits_blob.data());
assert(hits_blob.data() != nullptr);
// Group 0 must report the num_queries (10) queries built above.
auto num_queries_group = GetNumQueriesPeerGroup(reorganize_search_result, 0);
assert(num_queries_group == 10);
std::vector<int64_t> hit_size_peer_query;
hit_size_peer_query.resize(num_queries_group);
GetHitSizePeerQueries(reorganize_search_result, 0, hit_size_peer_query.data());
assert(hit_size_peer_query[0] > 0);
// Release every handle created by this test.
DeletePlan(plan);
DeletePlaceholderGroup(placeholderGroup);
DeleteQueryResult(res1);
DeleteQueryResult(res2);
DeleteQueryResult(reduced_search_result);
DeleteMarshaledHits(reorganize_search_result);
DeleteCollection(collection);
DeleteSegment(segment);
}
package master
import (
"log"
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/kv"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
)
// getSysConfigsTask is the master task that answers a SysConfigRequest.
// Execute fills keys and values with the loaded configuration pairs.
type getSysConfigsTask struct {
baseTask
// configkv is the etcd-backed store handed to SysConfig when the task runs.
configkv *kv.EtcdKV
// req is the incoming request; the task methods treat a nil req as an error.
req *internalpb.SysConfigRequest
// keys and values are the results appended by Execute.
keys []string
values []string
}
// Type reports the message type of the attached request. When no request is
// attached it logs "null request" and returns the zero message type.
func (t *getSysConfigsTask) Type() internalpb.MsgType {
	if req := t.req; req != nil {
		return req.MsgType
	}
	log.Printf("null request")
	return 0
}
// Ts returns the timestamp carried by the attached request, or an error when
// the task has no request.
func (t *getSysConfigsTask) Ts() (Timestamp, error) {
	if req := t.req; req != nil {
		return req.Timestamp, nil
	}
	return 0, errors.New("null request")
}
// Execute loads the configs named by the request and appends them to t.keys
// and t.values, keeping both slices index-aligned.
//
// Configs are collected in two passes: first everything matching each entry of
// req.KeyPrefixes, then each entry of req.Keys. A key is reported only once,
// no matter how many prefixes or explicit keys select it; explicit keys that
// were already collected are skipped with a warning.
//
// It returns an error when the task has no request or when loading from the
// config store fails.
func (t *getSysConfigsTask) Execute() error {
	if t.req == nil {
		return errors.New("null request")
	}

	sc := &SysConfig{kv: t.configkv}

	// Track every key already present in the result so nothing is emitted twice.
	seen := make(map[string]struct{}, len(t.keys))
	for _, key := range t.keys {
		seen[key] = struct{}{}
	}

	// Load configs with prefix. Overlapping prefixes (for example "/master"
	// and "/master/port") select the same key, so de-duplicate while appending.
	for _, prefix := range t.req.KeyPrefixes {
		prefixKeys, prefixVals, err := sc.GetByPrefix(prefix)
		if err != nil {
			return errors.Errorf("Load configs by prefix wrong: %s", err.Error())
		}
		if len(prefixKeys) != len(prefixVals) {
			return errors.Errorf("Load configs by prefix wrong: got %d keys but %d values", len(prefixKeys), len(prefixVals))
		}
		for i, key := range prefixKeys {
			if _, dup := seen[key]; dup {
				continue
			}
			seen[key] = struct{}{}
			t.keys = append(t.keys, key)
			t.values = append(t.values, prefixVals[i])
		}
	}

	// Load specific configs, dropping keys that were already collected.
	var cleanKeys []string
	for _, key := range t.req.Keys {
		if _, dup := seen[key]; dup {
			log.Println("[GetSysConfigs] Warning: duplicate key:", key)
			continue
		}
		seen[key] = struct{}{}
		cleanKeys = append(cleanKeys, key)
	}
	if len(cleanKeys) == 0 {
		// Nothing left to fetch; avoid a pointless round trip to the store.
		return nil
	}

	vals, err := sc.Get(cleanKeys)
	if err != nil {
		return errors.Errorf("Load configs wrong: %s", err.Error())
	}
	t.keys = append(t.keys, cleanKeys...)
	t.values = append(t.values, vals...)
	return nil
}
package master
import (
"context"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
"go.etcd.io/etcd/clientv3"
"google.golang.org/grpc"
)
// TestMaster_ConfigTask starts a master server on port 10002, seeds etcd with
// the YAML config files found in the current directory, and exercises the
// GetSysConfigs RPC with explicit keys, with a key prefix, and with a request
// whose explicit key is already covered by the prefix.
func TestMaster_ConfigTask(t *testing.T) {
	Init()

	ctx, cancel := context.WithCancel(context.TODO())
	defer cancel()

	etcdCli, err := clientv3.New(clientv3.Config{Endpoints: []string{Params.EtcdAddress}})
	require.Nil(t, err)
	// Start from a clean etcd subtree so earlier runs cannot leak values in.
	_, err = etcdCli.Delete(ctx, "/test/root", clientv3.WithPrefix())
	require.Nil(t, err)

	Params = ParamTable{
		Address:                     Params.Address,
		Port:                        Params.Port,
		EtcdAddress:                 Params.EtcdAddress,
		EtcdRootPath:                "/test/root",
		PulsarAddress:               Params.PulsarAddress,
		ProxyIDList:                 []typeutil.UniqueID{1, 2},
		WriteNodeIDList:             []typeutil.UniqueID{3, 4},
		TopicNum:                    5,
		QueryNodeNum:                3,
		SoftTimeTickBarrierInterval: 300,
		// segment
		SegmentSize:           536870912 / 1024 / 1024,
		SegmentSizeFactor:     0.75,
		DefaultRecordSize:     1024,
		MinSegIDAssignCnt:     1048576 / 1024,
		MaxSegIDAssignCnt:     Params.MaxSegIDAssignCnt,
		SegIDAssignExpiration: 2000,
		// msgChannel
		ProxyTimeTickChannelNames:     []string{"proxy1", "proxy2"},
		WriteNodeTimeTickChannelNames: []string{"write3", "write4"},
		InsertChannelNames:            []string{"dm0", "dm1"},
		K2SChannelNames:               []string{"k2s0", "k2s1"},
		QueryNodeStatsChannelName:     "statistic",
		MsgChannelSubName:             Params.MsgChannelSubName,
	}

	svr, err := CreateServer(ctx)
	require.Nil(t, err)
	err = svr.Run(10002)
	defer svr.Close()
	require.Nil(t, err)

	conn, err := grpc.DialContext(ctx, "127.0.0.1:10002", grpc.WithInsecure(), grpc.WithBlock())
	require.Nil(t, err)
	defer conn.Close()

	cli := masterpb.NewMasterClient(conn)
	testKeys := []string{
		"/etcd/address",
		"/master/port",
		"/master/proxyidlist",
		"/master/segmentthresholdfactor",
		"/pulsar/token",
		"/reader/stopflag",
		"/proxy/timezone",
		"/proxy/network/address",
		"/proxy/storage/path",
		"/storage/accesskey",
	}

	testVals := []string{
		"localhost",
		"53100",
		"[1 2]",
		"0.75",
		"eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJKb2UifQ.ipevRNuRP6HflG8cFKnmUPtypruRC4fb1DWtoLL62SY",
		"-1",
		"UTC+8",
		"0.0.0.0",
		"/var/lib/milvus",
		"",
	}

	// Seeding must succeed, otherwise every assertion below is meaningless.
	sc := SysConfig{kv: svr.kvBase}
	require.Nil(t, sc.InitFromFile("."))

	configRequest := &internalpb.SysConfigRequest{
		MsgType:     internalpb.MsgType_kGetSysConfigs,
		ReqID:       1,
		Timestamp:   11,
		ProxyID:     1,
		Keys:        testKeys,
		KeyPrefixes: []string{},
	}

	// require (not assert): response is dereferenced right after the call.
	response, err := cli.GetSysConfigs(ctx, configRequest)
	require.Nil(t, err)
	assert.ElementsMatch(t, testKeys, response.Keys)
	assert.ElementsMatch(t, testVals, response.Values)
	assert.Equal(t, len(response.GetKeys()), len(response.GetValues()))

	configRequest = &internalpb.SysConfigRequest{
		MsgType:     internalpb.MsgType_kGetSysConfigs,
		ReqID:       1,
		Timestamp:   11,
		ProxyID:     1,
		Keys:        []string{},
		KeyPrefixes: []string{"/master"},
	}

	response, err = cli.GetSysConfigs(ctx, configRequest)
	require.Nil(t, err)
	for i := range response.GetKeys() {
		assert.True(t, strings.HasPrefix(response.GetKeys()[i], "/master"))
	}
	assert.Equal(t, len(response.GetKeys()), len(response.GetValues()))

	t.Run("Test duplicate keys and key prefix", func(t *testing.T) {
		configRequest.Keys = []string{}
		configRequest.KeyPrefixes = []string{"/master"}

		resp, err := cli.GetSysConfigs(ctx, configRequest)
		require.Nil(t, err)
		assert.Equal(t, len(resp.GetKeys()), len(resp.GetValues()))
		assert.NotEqual(t, 0, len(resp.GetKeys()))

		// Asking for a key already covered by the prefix must not add an entry.
		configRequest.Keys = []string{"/master/port"}
		configRequest.KeyPrefixes = []string{"/master"}

		respDup, err := cli.GetSysConfigs(ctx, configRequest)
		require.Nil(t, err)
		assert.Equal(t, len(respDup.GetKeys()), len(respDup.GetValues()))
		assert.NotEqual(t, 0, len(respDup.GetKeys()))
		assert.Equal(t, len(respDup.GetKeys()), len(resp.GetKeys()))
	})
}
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
master: # 21
address: localhost
port: 53100
pulsarmoniterinterval: 1
pulsartopic: "monitor-topic"
proxyidlist: [1, 2]
proxyTimeSyncChannels: ["proxy1", "proxy2"]
proxyTimeSyncSubName: "proxy-topic"
softTimeTickBarrierInterval: 500
writeidlist: [3, 4]
writeTimeSyncChannels: ["write3", "write4"]
writeTimeSyncSubName: "write-topic"
dmTimeSyncChannels: ["dm5", "dm6"]
k2sTimeSyncChannels: ["k2s7", "k2s8"]
defaultSizePerRecord: 1024
minimumAssignSize: 1048576
segmentThreshold: 536870912
segmentExpireDuration: 2000
segmentThresholdFactor: 0.75
querynodenum: 1
writenodenum: 1
statsChannels: "statistic"
etcd: # 4
address: localhost
port: 2379
rootpath: by-dev
segthreshold: 10000
timesync: # 1
interval: 400
storage: # 5
driver: TIKV
address: localhost
port: 2379
accesskey:
secretkey:
pulsar: # 6
authentication: false
user: user-default
token: eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJKb2UifQ.ipevRNuRP6HflG8cFKnmUPtypruRC4fb1DWtoLL62SY
address: localhost
port: 6650
topicnum: 128
reader: # 7
clientid: 0
stopflag: -1
readerqueuesize: 10000
searchchansize: 10000
key2segchansize: 10000
topicstart: 0
topicend: 128
writer: # 8
clientid: 0
stopflag: -2
readerqueuesize: 10000
searchbyidchansize: 10000
parallelism: 100
topicstart: 0
topicend: 128
bucket: "zilliz-hz"
proxy: # 21
timezone: UTC+8
proxy_id: 1
numReaderNodes: 2
tsoSaveInterval: 200
timeTickInterval: 200
pulsarTopics:
readerTopicPrefix: "milvusReader"
numReaderTopics: 2
deleteTopic: "milvusDeleter"
queryTopic: "milvusQuery"
resultTopic: "milvusResult"
resultGroup: "milvusResultGroup"
timeTickTopic: "milvusTimeTick"
network:
address: 0.0.0.0
port: 19530
logs:
level: debug
trace.enable: true
path: /tmp/logs
max_log_file_size: 1024MB
log_rotate_num: 0
storage:
path: /var/lib/milvus
auto_flush_interval: 1
......@@ -359,43 +359,6 @@ func (s *Master) ShowPartitions(ctx context.Context, in *internalpb.ShowPartitio
return t.(*showPartitionTask).stringListResponse, nil
}
// GetSysConfigs serves the GetSysConfigs RPC. It enqueues a getSysConfigsTask
// on the master scheduler, waits for it to finish, and copies the task's keys
// and values into the response. Failures are reported through the response
// status (UNEXPECTED_ERROR plus a reason); the returned Go error is always nil.
func (s *Master) GetSysConfigs(ctx context.Context, in *internalpb.SysConfigRequest) (*servicepb.SysConfigResponse, error) {
	cfgTask := &getSysConfigsTask{
		baseTask: baseTask{
			sch: s.scheduler,
			mt:  s.metaTable,
			cv:  make(chan error),
		},
		configkv: s.kvBase,
		req:      in,
		keys:     []string{},
		values:   []string{},
	}

	// Assume failure until the task has completed.
	status := &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR}
	response := &servicepb.SysConfigResponse{Status: status}

	if err := s.scheduler.Enqueue(cfgTask); err != nil {
		status.Reason = "Enqueue failed: " + err.Error()
		return response, nil
	}

	if err := cfgTask.WaitToFinish(ctx); err != nil {
		status.Reason = "Get System Config failed: " + err.Error()
		return response, nil
	}

	response.Keys = cfgTask.keys
	response.Values = cfgTask.values
	status.ErrorCode = commonpb.ErrorCode_SUCCESS
	return response, nil
}
//----------------------------------------Internal GRPC Service--------------------------------
func (s *Master) AllocTimestamp(ctx context.Context, request *internalpb.TsoRequest) (*internalpb.TsoResponse, error) {
......
......@@ -52,7 +52,7 @@ func (mt *metaTable) reloadFromKV() error {
for _, value := range values {
tenantMeta := pb.TenantMeta{}
err := proto.Unmarshal([]byte(value), &tenantMeta)
err := proto.UnmarshalText(value, &tenantMeta)
if err != nil {
return err
}
......@@ -66,7 +66,7 @@ func (mt *metaTable) reloadFromKV() error {
for _, value := range values {
proxyMeta := pb.ProxyMeta{}
err = proto.Unmarshal([]byte(value), &proxyMeta)
err = proto.UnmarshalText(value, &proxyMeta)
if err != nil {
return err
}
......@@ -80,7 +80,7 @@ func (mt *metaTable) reloadFromKV() error {
for _, value := range values {
collectionMeta := pb.CollectionMeta{}
err = proto.Unmarshal([]byte(value), &collectionMeta)
err = proto.UnmarshalText(value, &collectionMeta)
if err != nil {
return err
}
......@@ -95,7 +95,7 @@ func (mt *metaTable) reloadFromKV() error {
for _, value := range values {
segmentMeta := pb.SegmentMeta{}
err = proto.Unmarshal([]byte(value), &segmentMeta)
err = proto.UnmarshalText(value, &segmentMeta)
if err != nil {
return err
}
......@@ -107,10 +107,7 @@ func (mt *metaTable) reloadFromKV() error {
// metaTable.ddLock.Lock() before call this function
func (mt *metaTable) saveCollectionMeta(coll *pb.CollectionMeta) error {
collBytes, err := proto.Marshal(coll)
if err != nil {
return err
}
collBytes := proto.MarshalTextString(coll)
mt.collID2Meta[coll.ID] = *coll
mt.collName2ID[coll.Schema.Name] = coll.ID
return mt.client.Save("/collection/"+strconv.FormatInt(coll.ID, 10), string(collBytes))
......@@ -118,10 +115,7 @@ func (mt *metaTable) saveCollectionMeta(coll *pb.CollectionMeta) error {
// metaTable.ddLock.Lock() before call this function
func (mt *metaTable) saveSegmentMeta(seg *pb.SegmentMeta) error {
segBytes, err := proto.Marshal(seg)
if err != nil {
return err
}
segBytes := proto.MarshalTextString(seg)
mt.segID2Meta[seg.SegmentID] = *seg
......@@ -136,10 +130,7 @@ func (mt *metaTable) saveCollectionAndDeleteSegmentsMeta(coll *pb.CollectionMeta
}
kvs := make(map[string]string)
collStrs, err := proto.Marshal(coll)
if err != nil {
return err
}
collStrs := proto.MarshalTextString(coll)
kvs["/collection/"+strconv.FormatInt(coll.ID, 10)] = string(collStrs)
......@@ -159,19 +150,15 @@ func (mt *metaTable) saveCollectionAndDeleteSegmentsMeta(coll *pb.CollectionMeta
// metaTable.ddLock.Lock() before call this function
func (mt *metaTable) saveCollectionsAndSegmentsMeta(coll *pb.CollectionMeta, seg *pb.SegmentMeta) error {
kvs := make(map[string]string)
collBytes, err := proto.Marshal(coll)
if err != nil {
return err
}
collBytes := proto.MarshalTextString(coll)
kvs["/collection/"+strconv.FormatInt(coll.ID, 10)] = string(collBytes)
mt.collID2Meta[coll.ID] = *coll
mt.collName2ID[coll.Schema.Name] = coll.ID
segBytes, err := proto.Marshal(seg)
if err != nil {
return err
}
segBytes := proto.MarshalTextString(seg)
kvs["/segment/"+strconv.FormatInt(seg.SegmentID, 10)] = string(segBytes)
mt.segID2Meta[seg.SegmentID] = *seg
......@@ -220,7 +207,7 @@ func (mt *metaTable) AddCollection(coll *pb.CollectionMeta) error {
}
if len(coll.PartitionTags) == 0 {
coll.PartitionTags = append(coll.PartitionTags, "default")
coll.PartitionTags = append(coll.PartitionTags, Params.DefaultPartitionTag)
}
_, ok := mt.collName2ID[coll.Schema.Name]
if ok {
......@@ -292,6 +279,10 @@ func (mt *metaTable) AddPartition(collID UniqueID, tag string) error {
return errors.Errorf("can't find collection. id = " + strconv.FormatInt(collID, 10))
}
// number of partition tags (except _default) should be limited to 4096 by default
if int64(len(coll.PartitionTags)) > Params.MaxPartitionNum {
return errors.New("maximum partition's number should be limit to " + strconv.FormatInt(Params.MaxPartitionNum, 10))
}
for _, t := range coll.PartitionTags {
if t == tag {
return errors.Errorf("partition already exists.")
......@@ -326,17 +317,29 @@ func (mt *metaTable) DeletePartition(collID UniqueID, tag string) error {
mt.ddLock.Lock()
defer mt.ddLock.Unlock()
if tag == Params.DefaultPartitionTag {
return errors.New("default partition cannot be deleted")
}
collMeta, ok := mt.collID2Meta[collID]
if !ok {
return errors.Errorf("can't find collection. id = " + strconv.FormatInt(collID, 10))
}
// check tag exists
exist := false
pt := make([]string, 0, len(collMeta.PartitionTags))
for _, t := range collMeta.PartitionTags {
if t != tag {
pt = append(pt, t)
} else {
exist = true
}
}
if !exist {
return errors.New("partition " + tag + " does not exist")
}
if len(pt) == len(collMeta.PartitionTags) {
return nil
}
......
......@@ -3,6 +3,7 @@ package master
import (
"context"
"reflect"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
......@@ -238,6 +239,10 @@ func TestMetaTable_DeletePartition(t *testing.T) {
assert.Equal(t, 1, len(meta.collName2ID))
assert.Equal(t, 1, len(meta.collID2Meta))
assert.Equal(t, 1, len(meta.segID2Meta))
// delete not exist
err = meta.DeletePartition(100, "not_exist")
assert.NotNil(t, err)
}
func TestMetaTable_Segment(t *testing.T) {
......@@ -366,3 +371,39 @@ func TestMetaTable_UpdateSegment(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, seg.NumRows, int64(210))
}
// TestMetaTable_AddPartition_Limit verifies that a collection accepts
// Params.MaxPartitionNum added partitions and rejects the next one.
func TestMetaTable_AddPartition_Limit(t *testing.T) {
	const (
		rootPath = "/etcd/test/root"
		collID   = 100
	)

	Init()
	Params.MaxPartitionNum = 256 // adding 4096 partitions is too slow

	client, err := clientv3.New(clientv3.Config{Endpoints: []string{Params.EtcdAddress}})
	assert.Nil(t, err)
	store := kv.NewEtcdKV(client, rootPath)

	// Wipe whatever previous tests left under the test root.
	_, err = client.Delete(context.TODO(), rootPath, clientv3.WithPrefix())
	assert.Nil(t, err)

	meta, err := NewMetaTable(store)
	assert.Nil(t, err)
	defer meta.client.Close()

	collection := pb.CollectionMeta{
		ID: collID,
		Schema: &schemapb.CollectionSchema{
			Name: "coll1",
		},
		CreateTime:    0,
		SegmentIDs:    []UniqueID{},
		PartitionTags: []string{},
	}
	assert.Nil(t, meta.AddCollection(&collection))

	// Fill the collection right up to the configured limit.
	limit := int(Params.MaxPartitionNum)
	for n := 0; n < limit; n++ {
		assert.Nil(t, meta.AddPartition(collID, "partition_"+strconv.Itoa(n)))
	}

	// One more partition must be refused.
	assert.NotNil(t, meta.AddPartition(collID, "partition_limit"))
}
......@@ -43,6 +43,9 @@ type ParamTable struct {
K2SChannelNames []string
QueryNodeStatsChannelName string
MsgChannelSubName string
MaxPartitionNum int64
DefaultPartitionTag string
}
var Params ParamTable
......@@ -62,6 +65,10 @@ func (p *ParamTable) Init() {
if err != nil {
panic(err)
}
err = p.LoadYaml("advanced/common.yaml")
if err != nil {
panic(err)
}
// set members
p.initAddress()
......@@ -91,6 +98,8 @@ func (p *ParamTable) Init() {
p.initK2SChannelNames()
p.initQueryNodeStatsChannelName()
p.initMsgChannelSubName()
p.initMaxPartitionNum()
p.initDefaultPartitionTag()
}
func (p *ParamTable) initAddress() {
......@@ -360,18 +369,33 @@ func (p *ParamTable) initInsertChannelNames() {
if err != nil {
log.Fatal(err)
}
id, err := p.Load("nodeID.queryNodeIDList")
channelRange, err := p.Load("msgChannel.channelRange.insert")
if err != nil {
log.Panicf("load query node id list error, %s", err.Error())
panic(err)
}
ids := strings.Split(id, ",")
channels := make([]string, 0, len(ids))
for _, i := range ids {
_, err := strconv.ParseInt(i, 10, 64)
if err != nil {
log.Panicf("load query node id list error, %s", err.Error())
}
channels = append(channels, ch+"-"+i)
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
p.InsertChannelNames = channels
}
......@@ -396,3 +420,24 @@ func (p *ParamTable) initK2SChannelNames() {
}
p.K2SChannelNames = channels
}
// initMaxPartitionNum reads "master.maxPartitionNum" from the loaded
// parameters, parses it as a base-10 int64, and stores it in MaxPartitionNum.
// It panics if the key cannot be loaded or the value is not an integer.
func (p *ParamTable) initMaxPartitionNum() {
	raw, loadErr := p.Load("master.maxPartitionNum")
	if loadErr != nil {
		panic(loadErr)
	}

	limit, parseErr := strconv.ParseInt(raw, 10, 64)
	if parseErr != nil {
		panic(parseErr)
	}

	p.MaxPartitionNum = limit
}
// initDefaultPartitionTag reads "common.defaultPartitionTag" from the loaded
// parameters into DefaultPartitionTag, panicking if the key cannot be loaded.
func (p *ParamTable) initDefaultPartitionTag() {
	tag, loadErr := p.Load("common.defaultPartitionTag")
	if loadErr != nil {
		panic(loadErr)
	}
	p.DefaultPartitionTag = tag
}
......@@ -31,7 +31,7 @@ func TestParamTable_EtcdRootPath(t *testing.T) {
func TestParamTable_TopicNum(t *testing.T) {
Params.Init()
num := Params.TopicNum
assert.Equal(t, num, 15)
assert.Equal(t, num, 1)
}
func TestParamTable_SegmentSize(t *testing.T) {
......@@ -73,7 +73,7 @@ func TestParamTable_SegIDAssignExpiration(t *testing.T) {
func TestParamTable_QueryNodeNum(t *testing.T) {
Params.Init()
num := Params.QueryNodeNum
assert.Equal(t, num, 2)
assert.Equal(t, num, 1)
}
func TestParamTable_QueryNodeStatsChannelName(t *testing.T) {
......@@ -85,17 +85,15 @@ func TestParamTable_QueryNodeStatsChannelName(t *testing.T) {
func TestParamTable_ProxyIDList(t *testing.T) {
Params.Init()
ids := Params.ProxyIDList
assert.Equal(t, len(ids), 2)
assert.Equal(t, len(ids), 1)
assert.Equal(t, ids[0], int64(1))
assert.Equal(t, ids[1], int64(2))
}
func TestParamTable_ProxyTimeTickChannelNames(t *testing.T) {
Params.Init()
names := Params.ProxyTimeTickChannelNames
assert.Equal(t, len(names), 2)
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "proxyTimeTick-1")
assert.Equal(t, names[1], "proxyTimeTick-2")
}
func TestParamTable_MsgChannelSubName(t *testing.T) {
......@@ -113,31 +111,27 @@ func TestParamTable_SoftTimeTickBarrierInterval(t *testing.T) {
func TestParamTable_WriteNodeIDList(t *testing.T) {
Params.Init()
ids := Params.WriteNodeIDList
assert.Equal(t, len(ids), 2)
assert.Equal(t, ids[0], int64(5))
assert.Equal(t, ids[1], int64(6))
assert.Equal(t, len(ids), 1)
assert.Equal(t, ids[0], int64(3))
}
func TestParamTable_WriteNodeTimeTickChannelNames(t *testing.T) {
Params.Init()
names := Params.WriteNodeTimeTickChannelNames
assert.Equal(t, len(names), 2)
assert.Equal(t, names[0], "writeNodeTimeTick-5")
assert.Equal(t, names[1], "writeNodeTimeTick-6")
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "writeNodeTimeTick-3")
}
func TestParamTable_InsertChannelNames(t *testing.T) {
Params.Init()
names := Params.InsertChannelNames
assert.Equal(t, len(names), 2)
assert.Equal(t, names[0], "insert-3")
assert.Equal(t, names[1], "insert-4")
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "insert-0")
}
func TestParamTable_K2SChannelNames(t *testing.T) {
Params.Init()
names := Params.K2SChannelNames
assert.Equal(t, len(names), 2)
assert.Equal(t, names[0], "k2s-5")
assert.Equal(t, names[1], "k2s-6")
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "k2s-3")
}
......@@ -191,10 +191,12 @@ func (t *showPartitionTask) Execute() error {
return errors.New("null request")
}
partitions := make([]string, 0)
for _, collection := range t.mt.collID2Meta {
partitions = append(partitions, collection.PartitionTags...)
collMeta, err := t.mt.GetCollectionByName(t.req.CollectionName.CollectionName)
if err != nil {
return err
}
partitions := make([]string, 0)
partitions = append(partitions, collMeta.PartitionTags...)
stringListResponse := servicepb.StringListResponse{
Status: &commonpb.Status{
......
......@@ -60,6 +60,9 @@ func TestMaster_Partition(t *testing.T) {
K2SChannelNames: []string{"k2s0", "k2s1"},
QueryNodeStatsChannelName: "statistic",
MsgChannelSubName: Params.MsgChannelSubName,
MaxPartitionNum: int64(4096),
DefaultPartitionTag: "_default",
}
port := 10000 + rand.Intn(1000)
......@@ -212,7 +215,7 @@ func TestMaster_Partition(t *testing.T) {
//assert.Equal(t, collMeta.PartitionTags[0], "partition1")
//assert.Equal(t, collMeta.PartitionTags[1], "partition2")
assert.ElementsMatch(t, []string{"default", "partition1", "partition2"}, collMeta.PartitionTags)
assert.ElementsMatch(t, []string{"_default", "partition1", "partition2"}, collMeta.PartitionTags)
showPartitionReq := internalpb.ShowPartitionRequest{
MsgType: internalpb.MsgType_kShowPartitions,
......@@ -224,7 +227,7 @@ func TestMaster_Partition(t *testing.T) {
stringList, err := cli.ShowPartitions(ctx, &showPartitionReq)
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"default", "partition1", "partition2"}, stringList.Values)
assert.ElementsMatch(t, []string{"_default", "partition1", "partition2"}, stringList.Values)
showPartitionReq = internalpb.ShowPartitionRequest{
MsgType: internalpb.MsgType_kShowPartitions,
......
......@@ -261,6 +261,9 @@ func startupMaster() {
K2SChannelNames: []string{"k2s0", "k2s1"},
QueryNodeStatsChannelName: "statistic",
MsgChannelSubName: Params.MsgChannelSubName,
MaxPartitionNum: int64(4096),
DefaultPartitionTag: "_default",
}
master, err = CreateServer(ctx)
......
package master
import (
"fmt"
"log"
"os"
"path"
"path/filepath"
"strings"
"github.com/spf13/viper"
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/kv"
)
// SysConfig reads and writes system configuration entries in etcd. Its methods
// place every entry under the "config" sub-path of the store.
type SysConfig struct {
// kv is the etcd-backed store used for all loads and saves.
kv *kv.EtcdKV
}
// InitFromFile reads every YAML config file found under filePath and stores
// its settings in etcd beneath the "config" sub-path. It stops at the first
// failure and returns it prefixed with "[Init SysConfig]".
func (conf *SysConfig) InitFromFile(filePath string) error {
	const errFormat = "[Init SysConfig] %s\n"

	loaded, err := conf.getConfigFiles(filePath)
	if err != nil {
		return errors.Errorf(errFormat, err.Error())
	}

	for i := range loaded {
		err = conf.saveToEtcd(loaded[i], "config")
		if err != nil {
			return errors.Errorf(errFormat, err.Error())
		}
	}
	return nil
}
// GetByPrefix returns every stored config whose key starts with keyPrefix.
//
// keyPrefix is lower-cased and joined onto the "config" sub-path before the
// lookup. The first occurrence of the store's "config" path is removed from
// each returned key, so callers get keys of the form "/master/port". keys and
// values are index-aligned.
//
// On a store error it returns nil slices and the error, without touching the
// partial results.
func (conf *SysConfig) GetByPrefix(keyPrefix string) (keys []string, values []string, err error) {
	realPrefix := path.Join("config", strings.ToLower(keyPrefix))

	keys, values, err = conf.kv.LoadWithPrefix(realPrefix)
	if err != nil {
		return nil, nil, err
	}

	// The store path is the same for every key; resolve it once.
	configRoot := conf.kv.GetPath("config")
	for i := range keys {
		keys[i] = strings.Replace(keys[i], configRoot, "", 1)
	}

	log.Println("Loaded", len(keys), "pairs of configs with prefix", keyPrefix)
	return keys, values, nil
}
// Get loads the values stored for the given config keys. Each key is
// lower-cased and joined onto the "config" sub-path before being passed to the
// store's MultiLoad. On failure it returns a nil slice and the store's error.
func (conf *SysConfig) Get(keys []string) ([]string, error) {
	var etcdKeys []string
	for _, key := range keys {
		etcdKeys = append(etcdKeys, path.Join("config", strings.ToLower(key)))
	}

	loaded, err := conf.kv.MultiLoad(etcdKeys)
	if err == nil {
		return loaded, nil
	}
	return nil, err
}
// getConfigFiles walks filePath recursively and parses every regular file with
// a ".yaml" extension into its own viper instance.
//
// It returns an error when the walk fails, when a config file cannot be read
// or parsed, or when no ".yaml" file exists under filePath. A bad config file
// is reported as an error rather than a panic, so callers such as
// InitFromFile can decide how to react.
func (conf *SysConfig) getConfigFiles(filePath string) ([]*viper.Viper, error) {
	var vipers []*viper.Viper

	// The callback parameter is named filename (not path) to avoid shadowing
	// the imported path package.
	err := filepath.Walk(filePath,
		func(filename string, info os.FileInfo, err error) error {
			if err != nil {
				return err
			}

			// Only regular *.yaml files are config files.
			if info.IsDir() || filepath.Ext(filename) != ".yaml" {
				return nil
			}

			log.Println("Config files ", info.Name())
			currentConf := viper.New()
			currentConf.SetConfigFile(filename)
			if err := currentConf.ReadInConfig(); err != nil {
				return errors.Errorf("Config file error: %s", err.Error())
			}
			vipers = append(vipers, currentConf)
			return nil
		})
	if err != nil {
		return nil, err
	}

	if len(vipers) == 0 {
		return nil, errors.Errorf("There are no config files in the path `%s`.\n", filePath)
	}
	return vipers, nil
}
// saveToEtcd writes every setting held by memConfig to the store in a single
// MultiSave call. A viper key such as "a.b.c" is stored as
// "<secondRootPath>/a/b/c". Values are rendered with "%v"; a nil value is
// stored as the empty string.
func (conf *SysConfig) saveToEtcd(memConfig *viper.Viper, secondRootPath string) error {
	settingKeys := memConfig.AllKeys()
	pairs := make(map[string]string, len(settingKeys))

	for _, settingKey := range settingKeys {
		// Dots separate nesting levels in viper; the store uses slashes.
		target := path.Join(secondRootPath, strings.ReplaceAll(settingKey, ".", "/"))

		if raw := memConfig.Get(settingKey); raw != nil {
			pairs[target] = fmt.Sprintf("%v", raw)
		} else {
			pairs[target] = ""
		}
	}

	return conf.kv.MultiSave(pairs)
}
package master
import (
"context"
"fmt"
"log"
"path"
"strings"
"testing"
"time"
"github.com/zilliztech/milvus-distributed/internal/kv"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/clientv3"
)
// Test_SysConfig exercises SysConfig against an etcd-backed KV store rooted at
// "/test/root". It is an integration test: it dials the etcd endpoint taken
// from Params.EtcdAddress and wipes everything under the root prefix first.
// Setup steps use require so the test stops immediately if etcd is unusable;
// most later checks use assert so a subtest reports every mismatch it finds.
func Test_SysConfig(t *testing.T) {
Init()
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
cli, err := clientv3.New(clientv3.Config{
Endpoints: []string{Params.EtcdAddress},
DialTimeout: 5 * time.Second,
})
require.Nil(t, err)
// Start from a clean slate: remove every key left under the test root.
_, err = cli.Delete(ctx, "/test/root", clientv3.WithPrefix())
require.Nil(t, err)
rootPath := "/test/root"
configKV := kv.NewEtcdKV(cli, rootPath)
defer configKV.Close()
sc := SysConfig{kv: configKV}
// GetPath(".") is expected to resolve to the bare root path.
require.Equal(t, rootPath, sc.kv.GetPath("."))
// Loads configuration via InitFromFile(".") and then reads it back through
// Get and GetByPrefix.
// NOTE(review): the subtest name says "contig" -- presumably "config" was
// intended; it is a runtime string, so it is left untouched here.
t.Run("tests on contig_test.yaml", func(t *testing.T) {
// Plain assignment: this writes to the enclosing function's err.
err = sc.InitFromFile(".")
require.Nil(t, err)
// testKeys and testVals are parallel slices: testVals[i] is the value
// expected for testKeys[i].
testKeys := []string{
"/etcd/address",
"/master/port",
"/master/proxyidlist",
"/master/segmentthresholdfactor",
"/pulsar/token",
"/reader/stopflag",
"/proxy/timezone",
"/proxy/network/address",
"/proxy/storage/path",
"/storage/accesskey",
}
testVals := []string{
"localhost",
"53100",
"[1 2]",
"0.75",
"eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJKb2UifQ.ipevRNuRP6HflG8cFKnmUPtypruRC4fb1DWtoLL62SY",
"-1",
"UTC+8",
"0.0.0.0",
"/var/lib/milvus",
"",
}
// From here on err is a new variable local to this closure.
vals, err := sc.Get(testKeys)
assert.Nil(t, err)
// NOTE(review): indexes vals by the length of testVals; if Get returned
// fewer values this would panic rather than fail an assertion.
for i := range testVals {
assert.Equal(t, testVals[i], vals[i])
}
// Every key returned for the "/master" prefix must sit under "/master/";
// the expected count (21) is tied to the loaded config contents.
keys, vals, err := sc.GetByPrefix("/master")
assert.Nil(t, err)
for i := range keys {
assert.True(t, strings.HasPrefix(keys[i], "/master/"))
}
assert.Equal(t, len(keys), len(vals))
assert.Equal(t, 21, len(keys))
// Test get all configs
keys, vals, err = sc.GetByPrefix("/")
assert.Nil(t, err)
assert.Equal(t, len(keys), len(vals))
assert.Equal(t, 73, len(vals))
// Test get configs with prefix not exist
keys, vals, err = sc.GetByPrefix("/config")
assert.Nil(t, err)
assert.Equal(t, len(keys), len(vals))
assert.Equal(t, 0, len(keys))
assert.Equal(t, 0, len(vals))
// Unusual prefixes (parent-directory traversal, embedded ".", a lone ".",
// a backslash) are only required not to produce an error; the returned
// keys and values are deliberately ignored.
_, _, err = sc.GetByPrefix("//././../../../../../..//../")
assert.Nil(t, err)
_, _, err = sc.GetByPrefix("/master/./address")
assert.Nil(t, err)
_, _, err = sc.GetByPrefix(".")
assert.Nil(t, err)
_, _, err = sc.GetByPrefix("\\")
assert.Nil(t, err)
})
// getConfigFiles must succeed with a non-nil first element for the
// "../../configs" directory and must return an error for a missing path.
t.Run("getConfigFiles", func(t *testing.T) {
filePath := "../../configs"
vipers, err := sc.getConfigFiles(filePath)
assert.Nil(t, err)
assert.NotNil(t, vipers[0])
filePath = "/path/not/exists"
_, err = sc.getConfigFiles(filePath)
assert.NotNil(t, err)
// Logged only so the error text is visible in the test output.
log.Println(err)
})
// Saves a small nested viper config under the "config" second-level root and
// checks it through the raw KV (Load, LoadWithPrefix) and through the
// SysConfig accessors (Get, GetByPrefix).
t.Run("Test saveToEtcd Normal", func(t *testing.T) {
// Clear whatever the earlier subtests stored under the "config" prefix so
// the counts asserted below only see the keys written here.
_, err = cli.Delete(ctx, "/test/root/config", clientv3.WithPrefix())
require.Nil(t, err)
v := viper.New()
v.Set("a.suba1", "v1")
v.Set("a.suba2", "v2")
v.Set("a.suba3.subsuba1", "v3")
v.Set("a.suba3.subsuba2", "v4")
secondRootPath := "config"
err := sc.saveToEtcd(v, secondRootPath)
assert.Nil(t, err)
// Dotted viper keys are expected to be stored as "/"-separated paths
// beneath secondRootPath.
value, err := sc.kv.Load(path.Join(secondRootPath, "a/suba1"))
assert.Nil(t, err)
assert.Equal(t, "v1", value)
value, err = sc.kv.Load(path.Join(secondRootPath, "a/suba2"))
assert.Nil(t, err)
assert.Equal(t, "v2", value)
value, err = sc.kv.Load(path.Join(secondRootPath, "a/suba3/subsuba1"))
assert.Nil(t, err)
assert.Equal(t, "v3", value)
value, err = sc.kv.Load(path.Join(secondRootPath, "a/suba3/subsuba2"))
assert.Nil(t, err)
assert.Equal(t, "v4", value)
// LoadWithPrefix is expected to return full keys, i.e. including the path
// produced by GetPath(secondRootPath).
keys, values, err := sc.kv.LoadWithPrefix(path.Join(secondRootPath, "a"))
assert.Nil(t, err)
assert.Equal(t, 4, len(keys))
assert.Equal(t, 4, len(values))
assert.ElementsMatch(t, []string{
path.Join(sc.kv.GetPath(secondRootPath), "/a/suba1"),
path.Join(sc.kv.GetPath(secondRootPath), "/a/suba2"),
path.Join(sc.kv.GetPath(secondRootPath), "/a/suba3/subsuba1"),
path.Join(sc.kv.GetPath(secondRootPath), "/a/suba3/subsuba2"),
}, keys)
assert.ElementsMatch(t, []string{"v1", "v2", "v3", "v4"}, values)
// The SysConfig accessors are queried with keys that omit the "config"
// segment; GetByPrefix is expected to hand back keys in that same form.
keys = []string{
"/a/suba1",
"/a/suba2",
"/a/suba3/subsuba1",
"/a/suba3/subsuba2",
}
values, err = sc.Get(keys)
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"v1", "v2", "v3", "v4"}, values)
keysAfter, values, err := sc.GetByPrefix("/a")
// Diagnostic output only; the assertions below do the checking.
fmt.Println(keysAfter)
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"v1", "v2", "v3", "v4"}, values)
assert.ElementsMatch(t, keys, keysAfter)
})
// Checks the string form each value type takes once saved: nil is expected
// as the empty string and the rest as their default "%v"-style rendering.
t.Run("Test saveToEtcd Different value types", func(t *testing.T) {
v := viper.New()
v.Set("string", "string")
v.Set("number", 1)
v.Set("nil", nil)
v.Set("float", 1.2)
v.Set("intslice", []int{100, 200})
v.Set("stringslice", []string{"a", "b"})
v.Set("stringmapstring", map[string]string{"k1": "1", "k2": "2"})
secondRootPath := "test_save_to_etcd_different_value_types"
err := sc.saveToEtcd(v, secondRootPath)
require.Nil(t, err)
keys, values, err := sc.kv.LoadWithPrefix(path.Join("/", secondRootPath))
assert.Nil(t, err)
// Seven keys were set above, so seven entries are expected; the map set
// under "stringmapstring" is expected as a single entry.
assert.Equal(t, 7, len(keys))
assert.Equal(t, 7, len(values))
assert.ElementsMatch(t, []string{
path.Join(sc.kv.GetPath(secondRootPath), "nil"),
path.Join(sc.kv.GetPath(secondRootPath), "string"),
path.Join(sc.kv.GetPath(secondRootPath), "number"),
path.Join(sc.kv.GetPath(secondRootPath), "float"),
path.Join(sc.kv.GetPath(secondRootPath), "intslice"),
path.Join(sc.kv.GetPath(secondRootPath), "stringslice"),
path.Join(sc.kv.GetPath(secondRootPath), "stringmapstring"),
}, keys)
assert.ElementsMatch(t, []string{"", "string", "1", "1.2", "[100 200]", "[a b]", "map[k1:1 k2:2]"}, values)
})
}
......@@ -70,7 +70,7 @@ func (ms *PulsarMsgStream) CreatePulsarProducers(channels []string) {
for i := 0; i < len(channels); i++ {
pp, err := (*ms.client).CreateProducer(pulsar.ProducerOptions{Topic: channels[i]})
if err != nil {
log.Printf("Failed to create reader producer %s, error = %v", channels[i], err)
log.Printf("Failed to create querynode producer %s, error = %v", channels[i], err)
}
ms.producers = append(ms.producers, &pp)
}
......@@ -141,6 +141,15 @@ func (ms *PulsarMsgStream) Produce(msgPack *MsgPack) error {
hashValues := tsMsg.HashKeys()
bucketValues := make([]int32, len(hashValues))
for index, hashValue := range hashValues {
if tsMsg.Type() == internalPb.MsgType_kSearchResult {
searchResult := tsMsg.(*SearchResultMsg)
channelID := int32(searchResult.ResultChannelID)
if channelID >= int32(len(ms.producers)) {
return errors.New("Failed to produce pulsar msg to unKnow channel")
}
bucketValues[index] = channelID
continue
}
bucketValues[index] = hashValue % int32(len(ms.producers))
}
reBucketValues[channelID] = bucketValues
......
......@@ -50,7 +50,7 @@ func getTsMsg(msgType MsgType, reqID UniqueID, hashValue int32) TsMsg {
CollectionName: "Collection",
PartitionTag: "Partition",
SegmentID: 1,
ChannelID: 1,
ChannelID: 0,
ProxyID: 1,
Timestamps: []Timestamp{1},
RowIDs: []int64{1},
......@@ -82,7 +82,7 @@ func getTsMsg(msgType MsgType, reqID UniqueID, hashValue int32) TsMsg {
ReqID: reqID,
ProxyID: 1,
Timestamp: 1,
ResultChannelID: 1,
ResultChannelID: 0,
}
searchMsg := &SearchMsg{
BaseMsg: baseMsg,
......@@ -97,7 +97,7 @@ func getTsMsg(msgType MsgType, reqID UniqueID, hashValue int32) TsMsg {
ProxyID: 1,
QueryNodeID: 1,
Timestamp: 1,
ResultChannelID: 1,
ResultChannelID: 0,
}
searchResultMsg := &SearchResultMsg{
BaseMsg: baseMsg,
......
......@@ -14,7 +14,6 @@ enum MsgType {
kHasCollection = 102;
kDescribeCollection = 103;
kShowCollections = 104;
kGetSysConfigs = 105;
/* Definition Requests: partition */
kCreatePartition = 200;
......@@ -34,7 +33,6 @@ enum MsgType {
/* System Control */
kTimeTick = 1200;
kQueryNodeSegStats = 1201;
}
enum PeerRole {
......@@ -225,19 +223,6 @@ message SearchRequest {
}
/**
* @brief Request of GetSysConfigs
*/
message SysConfigRequest {
MsgType msg_type = 1;
int64 reqID = 2;
int64 proxyID = 3;
uint64 timestamp = 4;
repeated string keys = 5;
repeated string key_prefixes = 6;
}
message SearchResult {
MsgType msg_type = 1;
common.Status status = 2;
......@@ -246,7 +231,7 @@ message SearchResult {
int64 query_nodeID = 5;
uint64 timestamp = 6;
int64 result_channelID = 7;
repeated service.Hits hits = 8;
repeated bytes hits = 8;
}
message TimeTickMsg {
......@@ -281,4 +266,4 @@ message QueryNodeSegStats {
MsgType msg_type = 1;
int64 peerID = 2;
repeated SegmentStats seg_stats = 3;
}
}
\ No newline at end of file
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册