提交 dd20ab44 编写于 作者: D dragondriver 提交者: yefu.chen

Add interface of check param util

Signed-off-by: Ndragondriver <jiquan.long@zilliz.com>
上级 48522d72
package proxynode
import (
"strconv"
)
const (
L2 = "L2"
IP = "IP"
HAMMING = "HAMMING"
JACCARD = "JACCARD"
TANIMOTO = "TANIMOTO"
SUBSTRUCTURE = "SUBSTRUCTURE"
SUPERSTRUCTURE = "SUPERSTRUCTURE"
MinNBits = 1
MaxNBits = 16
DefaultNBits = 8
MinNList = 1
MaxNList = 65536
DefaultMinDim = 1
DefaultMaxDim = 32768
NgtMinEdgeSize = 1
NgtMaxEdgeSize = 200
HNSWMinEfConstruction = 8
HNSWMaxEfConstruction = 512
HNSWMinM = 4
HNSWMaxM = 64
MinKNNG = 5
MaxKNNG = 300
MinSearchLength = 10
MaxSearchLength = 300
MinOutDegree = 5
MaxOutDegree = 300
MinCandidatePoolSize = 50
MaxCandidatePoolSize = 1000
MinNTrees = 1
// too large of n_trees takes much time, if there is real requirement, change this threshold.
MaxNTrees = 1024
DIM = "dim"
Metric = "metric_type"
NLIST = "nlist"
NBITS = "nbits"
IVFM = "m"
KNNG = "knng"
SearchLength = "search_length"
OutDegree = "out_degree"
CANDIDATE = "candidate_pool_size"
EFConstruction = "efConstruction"
HNSWM = "M"
PQM = "PQM"
NTREES = "n_trees"
EdgeSize = "edge_size"
ForcedlyPrunedEdgeSize = "forcedly_pruned_edge_size"
SelectivelyPrunedEdgeSize = "selectively_pruned_edge_size"
OutgoingEdgeSize = "outgoing_edge_size"
IncomingEdgeSize = "incoming_edge_size"
IndexMode = "index_mode"
CPUMode = "CPU"
GPUMode = "GPU"
)
var METRICS = []string{L2, IP} // const
var BinIDMapMetrics = []string{HAMMING, JACCARD, TANIMOTO, SUBSTRUCTURE, SUBSTRUCTURE} // const
var BinIvfMetrics = []string{HAMMING, JACCARD, TANIMOTO} // const
var supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
var supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
type ConfAdapter interface {
CheckTrain(map[string]string) bool
}
func CheckIntByRange(params map[string]string, key string, min, max int) bool {
valueStr, ok := params[key]
if !ok {
return false
}
value, err := strconv.Atoi(valueStr)
if err != nil {
return false
}
return value >= min && value <= max
}
func CheckStrByValues(params map[string]string, key string, container []string) bool {
value, ok := params[key]
if !ok {
return false
}
return SliceContain(container, value)
}
type BaseConfAdapter struct {
}
func (adapter *BaseConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
return false
}
return CheckStrByValues(params, Metric, METRICS)
}
func newBaseConfAdapter() *BaseConfAdapter {
return &BaseConfAdapter{}
}
type IVFConfAdapter struct {
BaseConfAdapter
}
func (adapter *IVFConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
return false
}
// skip check number of rows
return adapter.BaseConfAdapter.CheckTrain(params)
}
func newIVFConfAdapter() *IVFConfAdapter {
return &IVFConfAdapter{}
}
type IVFPQConfAdapter struct {
IVFConfAdapter
}
func (adapter *IVFPQConfAdapter) CheckTrain(params map[string]string) bool {
if !adapter.IVFConfAdapter.CheckTrain(params) {
return false
}
if !CheckIntByRange(params, NBITS, MinNBits, MaxNBits) {
return false
}
return adapter.checkPQParams(params)
}
func (adapter *IVFPQConfAdapter) checkPQParams(params map[string]string) bool {
dimension, _ := strconv.Atoi(params[DIM])
nbits, _ := strconv.Atoi(params[NBITS])
mStr, ok := params[IVFM]
if !ok {
return false
}
m, err := strconv.Atoi(mStr)
if err != nil {
return false
}
mode, ok := params[IndexMode]
if !ok {
mode = CPUMode
}
if mode == GPUMode && !adapter.checkGPUPQParams(dimension, m, nbits) {
return false
}
return adapter.checkCPUPQParams(dimension, m)
}
func (adapter *IVFPQConfAdapter) checkGPUPQParams(dimension, m, nbits int) bool {
/*
* Faiss 1.6
* Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-quantizer are currently supported with
* no precomputed codes. Precomputed codes supports any number of dimensions, but will involve memory overheads.
*/
subDim := dimension / m
return SliceContain(supportSubQuantizer, m) && SliceContain(supportDimPerSubQuantizer, subDim) && nbits == 8
}
func (adapter *IVFPQConfAdapter) checkCPUPQParams(dimension, m int) bool {
return (dimension % m) == 0
}
func newIVFPQConfAdapter() *IVFPQConfAdapter {
return &IVFPQConfAdapter{}
}
type IVFSQConfAdapter struct {
IVFConfAdapter
}
func (adapter *IVFSQConfAdapter) CheckTrain(params map[string]string) bool {
params[NBITS] = strconv.Itoa(DefaultNBits)
return adapter.IVFConfAdapter.CheckTrain(params)
}
func newIVFSQConfAdapter() *IVFSQConfAdapter {
return &IVFSQConfAdapter{}
}
type BinIDMAPConfAdapter struct {
}
func (adapter *BinIDMAPConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
return false
}
return CheckStrByValues(params, Metric, BinIDMapMetrics)
}
func newBinIDMAPConfAdapter() *BinIDMAPConfAdapter {
return &BinIDMAPConfAdapter{}
}
type BinIVFConfAdapter struct {
}
func (adapter *BinIVFConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
return false
}
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
return false
}
if !CheckStrByValues(params, Metric, BinIvfMetrics) {
return false
}
// skip checking the number of rows
return true
}
func newBinIVFConfAdapter() *BinIVFConfAdapter {
return &BinIVFConfAdapter{}
}
type NSGConfAdapter struct {
}
func (adapter *NSGConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckStrByValues(params, Metric, METRICS) {
return false
}
if !CheckIntByRange(params, KNNG, MinKNNG, MaxKNNG) {
return false
}
if !CheckIntByRange(params, SearchLength, MinSearchLength, MaxSearchLength) {
return false
}
if !CheckIntByRange(params, OutDegree, MinOutDegree, MaxOutDegree) {
return false
}
if !CheckIntByRange(params, CANDIDATE, MinCandidatePoolSize, MaxCandidatePoolSize) {
return false
}
// skip checking the number of rows
return true
}
func newNSGConfAdapter() *NSGConfAdapter {
return &NSGConfAdapter{}
}
type HNSWConfAdapter struct {
BaseConfAdapter
}
func (adapter *HNSWConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) {
return false
}
if !CheckIntByRange(params, HNSWM, HNSWMinM, HNSWMaxM) {
return false
}
return adapter.BaseConfAdapter.CheckTrain(params)
}
func newHNSWConfAdapter() *HNSWConfAdapter {
return &HNSWConfAdapter{}
}
type ANNOYConfAdapter struct {
BaseConfAdapter
}
func (adapter *ANNOYConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, NTREES, MinNTrees, MaxNTrees) {
return false
}
return adapter.BaseConfAdapter.CheckTrain(params)
}
func newANNOYConfAdapter() *ANNOYConfAdapter {
return &ANNOYConfAdapter{}
}
type RHNSWFlatConfAdapter struct {
BaseConfAdapter
}
func (adapter *RHNSWFlatConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) {
return false
}
if !CheckIntByRange(params, HNSWM, HNSWMinM, HNSWMaxM) {
return false
}
return adapter.BaseConfAdapter.CheckTrain(params)
}
func newRHNSWFlatConfAdapter() *RHNSWFlatConfAdapter {
return &RHNSWFlatConfAdapter{}
}
type RHNSWPQConfAdapter struct {
BaseConfAdapter
IVFPQConfAdapter
}
func (adapter *RHNSWPQConfAdapter) CheckTrain(params map[string]string) bool {
if !adapter.BaseConfAdapter.CheckTrain(params) {
return false
}
if !CheckIntByRange(params, EFConstruction, HNSWMinM, HNSWMaxM) {
return false
}
dimension, _ := strconv.Atoi(params[DIM])
pqmStr, ok := params[PQM]
if !ok {
return false
}
pqm, err := strconv.Atoi(pqmStr)
if err != nil {
return false
}
return adapter.IVFPQConfAdapter.checkCPUPQParams(dimension, pqm)
}
func newRHNSWPQConfAdapter() *RHNSWPQConfAdapter {
return &RHNSWPQConfAdapter{}
}
type RHNSWSQConfAdapter struct {
BaseConfAdapter
}
func (adapter *RHNSWSQConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) {
return false
}
if !CheckIntByRange(params, HNSWM, HNSWMinM, HNSWMaxM) {
return false
}
return adapter.BaseConfAdapter.CheckTrain(params)
}
func newRHNSWSQConfAdapter() *RHNSWSQConfAdapter {
return &RHNSWSQConfAdapter{}
}
type NGTPANNGConfAdapter struct {
BaseConfAdapter
}
func (adapter *NGTPANNGConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, EdgeSize, NgtMinEdgeSize, NgtMaxEdgeSize) {
return false
}
if !CheckIntByRange(params, ForcedlyPrunedEdgeSize, NgtMinEdgeSize, NgtMaxEdgeSize) {
return false
}
if !CheckIntByRange(params, SelectivelyPrunedEdgeSize, NgtMinEdgeSize, NgtMaxEdgeSize) {
return false
}
selectivelyPrunedEdgeSize, _ := strconv.Atoi(params[SelectivelyPrunedEdgeSize])
forcedlyPrunedEdgeSize, _ := strconv.Atoi(params[ForcedlyPrunedEdgeSize])
if selectivelyPrunedEdgeSize >= forcedlyPrunedEdgeSize {
return false
}
return adapter.BaseConfAdapter.CheckTrain(params)
}
func newNGTPANNGConfAdapter() *NGTPANNGConfAdapter {
return &NGTPANNGConfAdapter{}
}
type NGTONNGConfAdapter struct {
BaseConfAdapter
}
func (adapter *NGTONNGConfAdapter) CheckTrain(params map[string]string) bool {
if !CheckIntByRange(params, EdgeSize, NgtMinEdgeSize, NgtMaxEdgeSize) {
return false
}
if !CheckIntByRange(params, OutgoingEdgeSize, NgtMinEdgeSize, NgtMaxEdgeSize) {
return false
}
if !CheckIntByRange(params, IncomingEdgeSize, NgtMinEdgeSize, NgtMaxEdgeSize) {
return false
}
return adapter.BaseConfAdapter.CheckTrain(params)
}
func newNGTONNGConfAdapter() *NGTONNGConfAdapter {
return &NGTONNGConfAdapter{}
}
package proxynode
import (
"errors"
"sync"
)
type ConfAdapterMgr interface {
GetAdapter(indexType string) (ConfAdapter, error)
}
type ConfAdapterMgrImpl struct {
init bool
adapters map[string]ConfAdapter
}
func (mgr *ConfAdapterMgrImpl) GetAdapter(indexType string) (ConfAdapter, error) {
if !mgr.init {
mgr.registerConfAdapter()
}
adapter, ok := mgr.adapters[indexType]
if ok {
return adapter, nil
}
return nil, errors.New("Can not find conf adapter: " + indexType)
}
func (mgr *ConfAdapterMgrImpl) registerConfAdapter() {
mgr.init = true
mgr.adapters[IndexFaissIdmap] = newBaseConfAdapter()
mgr.adapters[IndexFaissIvfflat] = newIVFConfAdapter()
mgr.adapters[IndexFaissIvfpq] = newIVFPQConfAdapter()
mgr.adapters[IndexFaissIvfsq8] = newIVFSQConfAdapter()
mgr.adapters[IndexFaissIvfsq8h] = newIVFSQConfAdapter()
mgr.adapters[IndexFaissBinIdmap] = newBinIDMAPConfAdapter()
mgr.adapters[IndexFaissBinIvfflat] = newBinIVFConfAdapter()
mgr.adapters[IndexNsg] = newNSGConfAdapter()
mgr.adapters[IndexHnsw] = newHNSWConfAdapter()
mgr.adapters[IndexAnnoy] = newANNOYConfAdapter()
mgr.adapters[IndexRhnswflat] = newRHNSWFlatConfAdapter()
mgr.adapters[IndexRhnswpq] = newRHNSWPQConfAdapter()
mgr.adapters[IndexRhnswsq] = newRHNSWSQConfAdapter()
mgr.adapters[IndexNgtpanng] = newNGTPANNGConfAdapter()
mgr.adapters[IndexNgtonng] = newNGTONNGConfAdapter()
}
func newConfAdapterMgrImpl() *ConfAdapterMgrImpl {
return &ConfAdapterMgrImpl{}
}
var confAdapterMgr ConfAdapterMgr
var getConfAdapterMgrOnce sync.Once
func GetConfAdapterMgrInstance() ConfAdapterMgr {
getConfAdapterMgrOnce.Do(func() {
confAdapterMgr = newConfAdapterMgrImpl()
})
return confAdapterMgr
}
package proxynode
const (
IndexFaissIdmap = "FLAT"
IndexFaissIvfflat = "IVF_FLAT"
IndexFaissIvfpq = "IVF_PQ"
IndexFaissIvfsq8 = "IVF_SQ8"
IndexFaissIvfsq8h = "IVF_SQ8_HYBRID"
IndexFaissBinIdmap = "BIN_FLAT"
IndexFaissBinIvfflat = "BIN_IVF_FLAT"
IndexNsg = "NSG"
IndexHnsw = "HNSW"
IndexRhnswflat = "RHNSW_FLAT"
IndexRhnswpq = "RHNSW_PQ"
IndexRhnswsq = "RHNSW_SQ"
IndexAnnoy = "ANNOY"
IndexNgtpanng = "NGT_PANNG"
IndexNgtonng = "NGT_ONNG"
)
......@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"reflect"
"sort"
"sync"
......@@ -13,61 +12,6 @@ import (
"go.uber.org/zap"
)
func SliceContain(s interface{}, item interface{}) bool {
ss := reflect.ValueOf(s)
if ss.Kind() != reflect.Slice {
panic("SliceContain expect a slice")
}
for i := 0; i < ss.Len(); i++ {
if ss.Index(i).Interface() == item {
return true
}
}
return false
}
func SliceSetEqual(s1 interface{}, s2 interface{}) bool {
ss1 := reflect.ValueOf(s1)
ss2 := reflect.ValueOf(s2)
if ss1.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss2.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss1.Len() != ss2.Len() {
return false
}
for i := 0; i < ss1.Len(); i++ {
if !SliceContain(s2, ss1.Index(i).Interface()) {
return false
}
}
return true
}
func SortedSliceEqual(s1 interface{}, s2 interface{}) bool {
ss1 := reflect.ValueOf(s1)
ss2 := reflect.ValueOf(s2)
if ss1.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss2.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss1.Len() != ss2.Len() {
return false
}
for i := 0; i < ss1.Len(); i++ {
if ss2.Index(i).Interface() != ss1.Index(i).Interface() {
return false
}
}
return true
}
type InsertChannelsMap struct {
collectionID2InsertChannels map[UniqueID]int // the value of map is the location of insertChannels & insertMsgStreams
insertChannels [][]string // it's a little confusing to use []string as the key of map
......
......@@ -4,6 +4,7 @@ import (
"encoding/json"
"io/ioutil"
"net/http"
"reflect"
"time"
"go.uber.org/zap"
......@@ -41,3 +42,58 @@ func GetPulsarConfig(protocol, ip, port, url string) (map[string]interface{}, er
return ret, nil
}
func SliceContain(s interface{}, item interface{}) bool {
ss := reflect.ValueOf(s)
if ss.Kind() != reflect.Slice {
panic("SliceContain expect a slice")
}
for i := 0; i < ss.Len(); i++ {
if ss.Index(i).Interface() == item {
return true
}
}
return false
}
func SliceSetEqual(s1 interface{}, s2 interface{}) bool {
ss1 := reflect.ValueOf(s1)
ss2 := reflect.ValueOf(s2)
if ss1.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss2.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss1.Len() != ss2.Len() {
return false
}
for i := 0; i < ss1.Len(); i++ {
if !SliceContain(s2, ss1.Index(i).Interface()) {
return false
}
}
return true
}
func SortedSliceEqual(s1 interface{}, s2 interface{}) bool {
ss1 := reflect.ValueOf(s1)
ss2 := reflect.ValueOf(s2)
if ss1.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss2.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss1.Len() != ss2.Len() {
return false
}
for i := 0; i < ss1.Len(); i++ {
if ss2.Index(i).Interface() != ss1.Index(i).Interface() {
return false
}
}
return true
}
......@@ -3,6 +3,7 @@ package proxynode
import (
"fmt"
"net/http"
"sort"
"strconv"
"testing"
......@@ -36,3 +37,99 @@ func TestGetPulsarConfig(t *testing.T) {
assert.Equal(t, fmt.Sprintf("%v", value), fmt.Sprintf("%v", runtimeConfig[key]))
}
}
func TestSliceContain(t *testing.T) {
strSlice := []string{"test", "for", "SliceContain"}
intSlice := []int{1, 2, 3}
cases := []struct {
s interface{}
item interface{}
want bool
}{
{strSlice, "test", true},
{strSlice, "for", true},
{strSlice, "SliceContain", true},
{strSlice, "tests", false},
{intSlice, 1, true},
{intSlice, 2, true},
{intSlice, 3, true},
{intSlice, 4, false},
}
for _, test := range cases {
if got := SliceContain(test.s, test.item); got != test.want {
t.Errorf("SliceContain(%v, %v) = %v", test.s, test.item, test.want)
}
}
}
func TestSliceSetEqual(t *testing.T) {
cases := []struct {
s1 interface{}
s2 interface{}
want bool
}{
{[]int{}, []int{}, true},
{[]string{}, []string{}, true},
{[]int{1, 2, 3}, []int{3, 2, 1}, true},
{[]int{1, 2, 3}, []int{1, 2, 3}, true},
{[]int{1, 2, 3}, []int{}, false},
{[]int{1, 2, 3}, []int{1, 2}, false},
{[]int{1, 2, 3}, []int{4, 5, 6}, false},
{[]string{"test", "for", "SliceSetEqual"}, []string{"SliceSetEqual", "test", "for"}, true},
{[]string{"test", "for", "SliceSetEqual"}, []string{"test", "for", "SliceSetEqual"}, true},
{[]string{"test", "for", "SliceSetEqual"}, []string{"test", "for"}, false},
{[]string{"test", "for", "SliceSetEqual"}, []string{}, false},
{[]string{"test", "for", "SliceSetEqual"}, []string{"test", "for", "SliceContain"}, false},
}
for _, test := range cases {
if got := SliceSetEqual(test.s1, test.s2); got != test.want {
t.Errorf("SliceSetEqual(%v, %v) = %v", test.s1, test.s2, test.want)
}
}
}
func TestSortedSliceEqual(t *testing.T) {
sortSlice := func(slice interface{}, less func(i, j int) bool) {
sort.Slice(slice, less)
}
intSliceAfterSort := func(slice []int) []int {
sortSlice(slice, func(i, j int) bool {
return slice[i] <= slice[j]
})
return slice
}
stringSliceAfterSort := func(slice []string) []string {
sortSlice(slice, func(i, j int) bool {
return slice[i] <= slice[j]
})
return slice
}
cases := []struct {
s1 interface{}
s2 interface{}
want bool
}{
{intSliceAfterSort([]int{}), intSliceAfterSort([]int{}), true},
{stringSliceAfterSort([]string{}), stringSliceAfterSort([]string{}), true},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{3, 2, 1}), true},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{1, 2, 3}), true},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{}), false},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{1, 2}), false},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{4, 5, 6}), false},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{"SliceSetEqual", "test", "for"}), true},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), true},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{"test", "for"}), false},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{}), false},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{"test", "for", "SliceContain"}), false},
}
for _, test := range cases {
if got := SortedSliceEqual(test.s1, test.s2); got != test.want {
t.Errorf("SliceSetEqual(%v, %v) = %v", test.s1, test.s2, test.want)
}
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册