提交 dfa35b89 编写于 作者: Y Yu Kun

Merge remote-tracking branch 'upstream/branch-0.4.0' into branch-0.4.0


Former-commit-id: 927a963076107b8e63e59e88b583c5e599979036
......@@ -85,6 +85,8 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-460 - Put transport speed as weight when choosing neighbour to execute task
- MS-459 - Add cache for pick function in tasktable
- MS-482 - Change search stream transport to unary in grpc
- MS-487 - Define metric type in CreateTable
- MS-488 - Improve code format in scheduler
## New Feature
- MS-343 - Implement ResourceMgr
......
......@@ -261,7 +261,8 @@ else()
message(STATUS ${FAISS_SOURCE_URL})
endif()
# set(FAISS_MD5 "a589663865a8558205533c8ac414278c")
set(FAISS_MD5 "57da9c4f599cc8fa4260488b1c96e1cc")
# set(FAISS_MD5 "57da9c4f599cc8fa4260488b1c96e1cc") # commit-id 6dbdf75987c34a2c853bd172ea0d384feea8358c
set(FAISS_MD5 "21deb1c708490ca40ecb899122c01403") # commit-id 643e48f479637fd947e7b93fa4ca72b38ecc9a39
if(DEFINED ENV{KNOWHERE_ARROW_URL})
set(ARROW_SOURCE_URL "$ENV{KNOWHERE_ARROW_URL}")
......
......@@ -250,7 +250,9 @@ Status DBImpl::CreateIndex(const std::string& table_id, const TableIndex& index)
}
//step 2: update index info
if(!utils::IsSameIndex(old_index, index)) {
TableIndex new_index = index;
new_index.metric_type_ = old_index.metric_type_;//dont change metric type, it was defined by CreateTable
if(!utils::IsSameIndex(old_index, new_index)) {
DropIndex(table_id);
status = meta_ptr_->UpdateTableIndexParam(table_id, index);
......
......@@ -301,7 +301,8 @@ Status ExecutionEngineImpl::Search(long n,
}
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;
auto ec = index_->Search(n, data, distances, labels, Config::object{{"k", k}, {"nprobe", nprobe}});
auto cfg = Config::object{{"k", k}, {"nprobe", nprobe}};
auto ec = index_->Search(n, data, distances, labels, cfg);
if (ec != server::KNOWHERE_SUCCESS) {
ENGINE_LOG_ERROR << "Search error";
return Status::Error("Search: Search Error");
......
......@@ -707,7 +707,7 @@ Status SqliteMetaImpl::FilesToSearch(const std::string &table_id,
files[table_file.date_].push_back(table_file);
}
if(files.empty()) {
std::cout << "ERROR" << std::endl;
ENGINE_LOG_ERROR << "No file to search for table: " << table_id;
}
} catch (std::exception &e) {
return HandleException("Encounter exception when iterate index files", e);
......
......@@ -387,6 +387,7 @@ const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_milvus_2eproto::offsets[] PROT
PROTOBUF_FIELD_OFFSET(::milvus::grpc::TableSchema, table_name_),
PROTOBUF_FIELD_OFFSET(::milvus::grpc::TableSchema, dimension_),
PROTOBUF_FIELD_OFFSET(::milvus::grpc::TableSchema, index_file_size_),
PROTOBUF_FIELD_OFFSET(::milvus::grpc::TableSchema, metric_type_),
~0u, // no _has_bits_
PROTOBUF_FIELD_OFFSET(::milvus::grpc::Range, _internal_metadata_),
~0u, // no _extensions_
......@@ -486,7 +487,6 @@ const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_milvus_2eproto::offsets[] PROT
~0u, // no _weak_field_map_
PROTOBUF_FIELD_OFFSET(::milvus::grpc::Index, index_type_),
PROTOBUF_FIELD_OFFSET(::milvus::grpc::Index, nlist_),
PROTOBUF_FIELD_OFFSET(::milvus::grpc::Index, metric_type_),
~0u, // no _has_bits_
PROTOBUF_FIELD_OFFSET(::milvus::grpc::IndexParam, _internal_metadata_),
~0u, // no _extensions_
......@@ -505,20 +505,20 @@ const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_milvus_2eproto::offsets[] PROT
static const ::PROTOBUF_NAMESPACE_ID::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = {
{ 0, -1, sizeof(::milvus::grpc::TableName)},
{ 7, -1, sizeof(::milvus::grpc::TableSchema)},
{ 15, -1, sizeof(::milvus::grpc::Range)},
{ 22, -1, sizeof(::milvus::grpc::RowRecord)},
{ 28, -1, sizeof(::milvus::grpc::InsertParam)},
{ 36, -1, sizeof(::milvus::grpc::VectorIds)},
{ 43, -1, sizeof(::milvus::grpc::SearchParam)},
{ 53, -1, sizeof(::milvus::grpc::SearchInFilesParam)},
{ 60, -1, sizeof(::milvus::grpc::QueryResult)},
{ 67, -1, sizeof(::milvus::grpc::TopKQueryResult)},
{ 73, -1, sizeof(::milvus::grpc::TopKQueryResultList)},
{ 80, -1, sizeof(::milvus::grpc::StringReply)},
{ 87, -1, sizeof(::milvus::grpc::BoolReply)},
{ 94, -1, sizeof(::milvus::grpc::TableRowCount)},
{ 101, -1, sizeof(::milvus::grpc::Command)},
{ 107, -1, sizeof(::milvus::grpc::Index)},
{ 16, -1, sizeof(::milvus::grpc::Range)},
{ 23, -1, sizeof(::milvus::grpc::RowRecord)},
{ 29, -1, sizeof(::milvus::grpc::InsertParam)},
{ 37, -1, sizeof(::milvus::grpc::VectorIds)},
{ 44, -1, sizeof(::milvus::grpc::SearchParam)},
{ 54, -1, sizeof(::milvus::grpc::SearchInFilesParam)},
{ 61, -1, sizeof(::milvus::grpc::QueryResult)},
{ 68, -1, sizeof(::milvus::grpc::TopKQueryResult)},
{ 74, -1, sizeof(::milvus::grpc::TopKQueryResultList)},
{ 81, -1, sizeof(::milvus::grpc::StringReply)},
{ 88, -1, sizeof(::milvus::grpc::BoolReply)},
{ 95, -1, sizeof(::milvus::grpc::TableRowCount)},
{ 102, -1, sizeof(::milvus::grpc::Command)},
{ 108, -1, sizeof(::milvus::grpc::Index)},
{ 115, -1, sizeof(::milvus::grpc::IndexParam)},
{ 122, -1, sizeof(::milvus::grpc::DeleteByRangeParam)},
};
......@@ -547,37 +547,37 @@ static ::PROTOBUF_NAMESPACE_ID::Message const * const file_default_instances[] =
const char descriptor_table_protodef_milvus_2eproto[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) =
"\n\014milvus.proto\022\013milvus.grpc\032\014status.prot"
"o\"D\n\tTableName\022#\n\006status\030\001 \001(\0132\023.milvus."
"grpc.Status\022\022\n\ntable_name\030\002 \001(\t\"e\n\013Table"
"grpc.Status\022\022\n\ntable_name\030\002 \001(\t\"z\n\013Table"
"Schema\022*\n\ntable_name\030\001 \001(\0132\026.milvus.grpc"
".TableName\022\021\n\tdimension\030\002 \001(\003\022\027\n\017index_f"
"ile_size\030\003 \001(\003\"/\n\005Range\022\023\n\013start_value\030\001"
" \001(\t\022\021\n\tend_value\030\002 \001(\t\" \n\tRowRecord\022\023\n\013"
"vector_data\030\001 \003(\002\"i\n\013InsertParam\022\022\n\ntabl"
"e_name\030\001 \001(\t\0220\n\020row_record_array\030\002 \003(\0132\026"
".milvus.grpc.RowRecord\022\024\n\014row_id_array\030\003"
" \003(\003\"I\n\tVectorIds\022#\n\006status\030\001 \001(\0132\023.milv"
"us.grpc.Status\022\027\n\017vector_id_array\030\002 \003(\003\""
"\242\001\n\013SearchParam\022\022\n\ntable_name\030\001 \001(\t\0222\n\022q"
"uery_record_array\030\002 \003(\0132\026.milvus.grpc.Ro"
"wRecord\022-\n\021query_range_array\030\003 \003(\0132\022.mil"
"vus.grpc.Range\022\014\n\004topk\030\004 \001(\003\022\016\n\006nprobe\030\005"
" \001(\003\"[\n\022SearchInFilesParam\022\025\n\rfile_id_ar"
"ray\030\001 \003(\t\022.\n\014search_param\030\002 \001(\0132\030.milvus"
".grpc.SearchParam\"+\n\013QueryResult\022\n\n\002id\030\001"
" \001(\003\022\020\n\010distance\030\002 \001(\001\"H\n\017TopKQueryResul"
"t\0225\n\023query_result_arrays\030\001 \003(\0132\030.milvus."
"grpc.QueryResult\"s\n\023TopKQueryResultList\022"
"#\n\006status\030\001 \001(\0132\023.milvus.grpc.Status\0227\n\021"
"topk_query_result\030\002 \003(\0132\034.milvus.grpc.To"
"pKQueryResult\"H\n\013StringReply\022#\n\006status\030\001"
" \001(\0132\023.milvus.grpc.Status\022\024\n\014string_repl"
"y\030\002 \001(\t\"D\n\tBoolReply\022#\n\006status\030\001 \001(\0132\023.m"
"ilvus.grpc.Status\022\022\n\nbool_reply\030\002 \001(\010\"M\n"
"\rTableRowCount\022#\n\006status\030\001 \001(\0132\023.milvus."
"grpc.Status\022\027\n\017table_row_count\030\002 \001(\003\"\026\n\007"
"Command\022\013\n\003cmd\030\001 \001(\t\"\?\n\005Index\022\022\n\nindex_t"
"ype\030\001 \001(\005\022\r\n\005nlist\030\002 \001(\005\022\023\n\013metric_type\030"
"\003 \001(\005\"[\n\nIndexParam\022*\n\ntable_name\030\001 \001(\0132"
"ile_size\030\003 \001(\003\022\023\n\013metric_type\030\004 \001(\005\"/\n\005R"
"ange\022\023\n\013start_value\030\001 \001(\t\022\021\n\tend_value\030\002"
" \001(\t\" \n\tRowRecord\022\023\n\013vector_data\030\001 \003(\002\"i"
"\n\013InsertParam\022\022\n\ntable_name\030\001 \001(\t\0220\n\020row"
"_record_array\030\002 \003(\0132\026.milvus.grpc.RowRec"
"ord\022\024\n\014row_id_array\030\003 \003(\003\"I\n\tVectorIds\022#"
"\n\006status\030\001 \001(\0132\023.milvus.grpc.Status\022\027\n\017v"
"ector_id_array\030\002 \003(\003\"\242\001\n\013SearchParam\022\022\n\n"
"table_name\030\001 \001(\t\0222\n\022query_record_array\030\002"
" \003(\0132\026.milvus.grpc.RowRecord\022-\n\021query_ra"
"nge_array\030\003 \003(\0132\022.milvus.grpc.Range\022\014\n\004t"
"opk\030\004 \001(\003\022\016\n\006nprobe\030\005 \001(\003\"[\n\022SearchInFil"
"esParam\022\025\n\rfile_id_array\030\001 \003(\t\022.\n\014search"
"_param\030\002 \001(\0132\030.milvus.grpc.SearchParam\"+"
"\n\013QueryResult\022\n\n\002id\030\001 \001(\003\022\020\n\010distance\030\002 "
"\001(\001\"H\n\017TopKQueryResult\0225\n\023query_result_a"
"rrays\030\001 \003(\0132\030.milvus.grpc.QueryResult\"s\n"
"\023TopKQueryResultList\022#\n\006status\030\001 \001(\0132\023.m"
"ilvus.grpc.Status\0227\n\021topk_query_result\030\002"
" \003(\0132\034.milvus.grpc.TopKQueryResult\"H\n\013St"
"ringReply\022#\n\006status\030\001 \001(\0132\023.milvus.grpc."
"Status\022\024\n\014string_reply\030\002 \001(\t\"D\n\tBoolRepl"
"y\022#\n\006status\030\001 \001(\0132\023.milvus.grpc.Status\022\022"
"\n\nbool_reply\030\002 \001(\010\"M\n\rTableRowCount\022#\n\006s"
"tatus\030\001 \001(\0132\023.milvus.grpc.Status\022\027\n\017tabl"
"e_row_count\030\002 \001(\003\"\026\n\007Command\022\013\n\003cmd\030\001 \001("
"\t\"*\n\005Index\022\022\n\nindex_type\030\001 \001(\005\022\r\n\005nlist\030"
"\002 \001(\005\"[\n\nIndexParam\022*\n\ntable_name\030\001 \001(\0132"
"\026.milvus.grpc.TableName\022!\n\005index\030\002 \001(\0132\022"
".milvus.grpc.Index\"K\n\022DeleteByRangeParam"
"\022!\n\005range\030\001 \001(\0132\022.milvus.grpc.Range\022\022\n\nt"
......@@ -1010,16 +1010,16 @@ TableSchema::TableSchema(const TableSchema& from)
table_name_ = nullptr;
}
::memcpy(&dimension_, &from.dimension_,
static_cast<size_t>(reinterpret_cast<char*>(&index_file_size_) -
reinterpret_cast<char*>(&dimension_)) + sizeof(index_file_size_));
static_cast<size_t>(reinterpret_cast<char*>(&metric_type_) -
reinterpret_cast<char*>(&dimension_)) + sizeof(metric_type_));
// @@protoc_insertion_point(copy_constructor:milvus.grpc.TableSchema)
}
void TableSchema::SharedCtor() {
::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TableSchema_milvus_2eproto.base);
::memset(&table_name_, 0, static_cast<size_t>(
reinterpret_cast<char*>(&index_file_size_) -
reinterpret_cast<char*>(&table_name_)) + sizeof(index_file_size_));
reinterpret_cast<char*>(&metric_type_) -
reinterpret_cast<char*>(&table_name_)) + sizeof(metric_type_));
}
TableSchema::~TableSchema() {
......@@ -1051,8 +1051,8 @@ void TableSchema::Clear() {
}
table_name_ = nullptr;
::memset(&dimension_, 0, static_cast<size_t>(
reinterpret_cast<char*>(&index_file_size_) -
reinterpret_cast<char*>(&dimension_)) + sizeof(index_file_size_));
reinterpret_cast<char*>(&metric_type_) -
reinterpret_cast<char*>(&dimension_)) + sizeof(metric_type_));
_internal_metadata_.Clear();
}
......@@ -1085,6 +1085,13 @@ const char* TableSchema::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID
CHK_(ptr);
} else goto handle_unusual;
continue;
// int32 metric_type = 4;
case 4:
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 32)) {
metric_type_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint(&ptr);
CHK_(ptr);
} else goto handle_unusual;
continue;
default: {
handle_unusual:
if ((tag & 7) == 4 || tag == 0) {
......@@ -1152,6 +1159,19 @@ bool TableSchema::MergePartialFromCodedStream(
break;
}
// int32 metric_type = 4;
case 4: {
if (static_cast< ::PROTOBUF_NAMESPACE_ID::uint8>(tag) == (32 & 0xFF)) {
DO_((::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::ReadPrimitive<
::PROTOBUF_NAMESPACE_ID::int32, ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_INT32>(
input, &metric_type_)));
} else {
goto handle_unusual;
}
break;
}
default: {
handle_unusual:
if (tag == 0) {
......@@ -1195,6 +1215,11 @@ void TableSchema::SerializeWithCachedSizes(
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64(3, this->index_file_size(), output);
}
// int32 metric_type = 4;
if (this->metric_type() != 0) {
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32(4, this->metric_type(), output);
}
if (_internal_metadata_.have_unknown_fields()) {
::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SerializeUnknownFields(
_internal_metadata_.unknown_fields(), output);
......@@ -1225,6 +1250,11 @@ void TableSchema::SerializeWithCachedSizes(
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(3, this->index_file_size(), target);
}
// int32 metric_type = 4;
if (this->metric_type() != 0) {
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(4, this->metric_type(), target);
}
if (_internal_metadata_.have_unknown_fields()) {
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SerializeUnknownFieldsToArray(
_internal_metadata_.unknown_fields(), target);
......@@ -1267,6 +1297,13 @@ size_t TableSchema::ByteSizeLong() const {
this->index_file_size());
}
// int32 metric_type = 4;
if (this->metric_type() != 0) {
total_size += 1 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->metric_type());
}
int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size);
SetCachedSize(cached_size);
return total_size;
......@@ -1303,6 +1340,9 @@ void TableSchema::MergeFrom(const TableSchema& from) {
if (from.index_file_size() != 0) {
set_index_file_size(from.index_file_size());
}
if (from.metric_type() != 0) {
set_metric_type(from.metric_type());
}
}
void TableSchema::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) {
......@@ -1329,6 +1369,7 @@ void TableSchema::InternalSwap(TableSchema* other) {
swap(table_name_, other->table_name_);
swap(dimension_, other->dimension_);
swap(index_file_size_, other->index_file_size_);
swap(metric_type_, other->metric_type_);
}
::PROTOBUF_NAMESPACE_ID::Metadata TableSchema::GetMetadata() const {
......@@ -5606,15 +5647,15 @@ Index::Index(const Index& from)
_internal_metadata_(nullptr) {
_internal_metadata_.MergeFrom(from._internal_metadata_);
::memcpy(&index_type_, &from.index_type_,
static_cast<size_t>(reinterpret_cast<char*>(&metric_type_) -
reinterpret_cast<char*>(&index_type_)) + sizeof(metric_type_));
static_cast<size_t>(reinterpret_cast<char*>(&nlist_) -
reinterpret_cast<char*>(&index_type_)) + sizeof(nlist_));
// @@protoc_insertion_point(copy_constructor:milvus.grpc.Index)
}
void Index::SharedCtor() {
::memset(&index_type_, 0, static_cast<size_t>(
reinterpret_cast<char*>(&metric_type_) -
reinterpret_cast<char*>(&index_type_)) + sizeof(metric_type_));
reinterpret_cast<char*>(&nlist_) -
reinterpret_cast<char*>(&index_type_)) + sizeof(nlist_));
}
Index::~Index() {
......@@ -5641,8 +5682,8 @@ void Index::Clear() {
(void) cached_has_bits;
::memset(&index_type_, 0, static_cast<size_t>(
reinterpret_cast<char*>(&metric_type_) -
reinterpret_cast<char*>(&index_type_)) + sizeof(metric_type_));
reinterpret_cast<char*>(&nlist_) -
reinterpret_cast<char*>(&index_type_)) + sizeof(nlist_));
_internal_metadata_.Clear();
}
......@@ -5668,13 +5709,6 @@ const char* Index::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::inte
CHK_(ptr);
} else goto handle_unusual;
continue;
// int32 metric_type = 3;
case 3:
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 24)) {
metric_type_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint(&ptr);
CHK_(ptr);
} else goto handle_unusual;
continue;
default: {
handle_unusual:
if ((tag & 7) == 4 || tag == 0) {
......@@ -5731,19 +5765,6 @@ bool Index::MergePartialFromCodedStream(
break;
}
// int32 metric_type = 3;
case 3: {
if (static_cast< ::PROTOBUF_NAMESPACE_ID::uint8>(tag) == (24 & 0xFF)) {
DO_((::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::ReadPrimitive<
::PROTOBUF_NAMESPACE_ID::int32, ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_INT32>(
input, &metric_type_)));
} else {
goto handle_unusual;
}
break;
}
default: {
handle_unusual:
if (tag == 0) {
......@@ -5781,11 +5802,6 @@ void Index::SerializeWithCachedSizes(
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32(2, this->nlist(), output);
}
// int32 metric_type = 3;
if (this->metric_type() != 0) {
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32(3, this->metric_type(), output);
}
if (_internal_metadata_.have_unknown_fields()) {
::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SerializeUnknownFields(
_internal_metadata_.unknown_fields(), output);
......@@ -5809,11 +5825,6 @@ void Index::SerializeWithCachedSizes(
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(2, this->nlist(), target);
}
// int32 metric_type = 3;
if (this->metric_type() != 0) {
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(3, this->metric_type(), target);
}
if (_internal_metadata_.have_unknown_fields()) {
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SerializeUnknownFieldsToArray(
_internal_metadata_.unknown_fields(), target);
......@@ -5849,13 +5860,6 @@ size_t Index::ByteSizeLong() const {
this->nlist());
}
// int32 metric_type = 3;
if (this->metric_type() != 0) {
total_size += 1 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->metric_type());
}
int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size);
SetCachedSize(cached_size);
return total_size;
......@@ -5889,9 +5893,6 @@ void Index::MergeFrom(const Index& from) {
if (from.nlist() != 0) {
set_nlist(from.nlist());
}
if (from.metric_type() != 0) {
set_metric_type(from.metric_type());
}
}
void Index::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) {
......@@ -5917,7 +5918,6 @@ void Index::InternalSwap(Index* other) {
_internal_metadata_.Swap(&other->_internal_metadata_);
swap(index_type_, other->index_type_);
swap(nlist_, other->nlist_);
swap(metric_type_, other->metric_type_);
}
::PROTOBUF_NAMESPACE_ID::Metadata Index::GetMetadata() const {
......
......@@ -401,6 +401,7 @@ class TableSchema :
kTableNameFieldNumber = 1,
kDimensionFieldNumber = 2,
kIndexFileSizeFieldNumber = 3,
kMetricTypeFieldNumber = 4,
};
// .milvus.grpc.TableName table_name = 1;
bool has_table_name() const;
......@@ -420,6 +421,11 @@ class TableSchema :
::PROTOBUF_NAMESPACE_ID::int64 index_file_size() const;
void set_index_file_size(::PROTOBUF_NAMESPACE_ID::int64 value);
// int32 metric_type = 4;
void clear_metric_type();
::PROTOBUF_NAMESPACE_ID::int32 metric_type() const;
void set_metric_type(::PROTOBUF_NAMESPACE_ID::int32 value);
// @@protoc_insertion_point(class_scope:milvus.grpc.TableSchema)
private:
class _Internal;
......@@ -428,6 +434,7 @@ class TableSchema :
::milvus::grpc::TableName* table_name_;
::PROTOBUF_NAMESPACE_ID::int64 dimension_;
::PROTOBUF_NAMESPACE_ID::int64 index_file_size_;
::PROTOBUF_NAMESPACE_ID::int32 metric_type_;
mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_;
friend struct ::TableStruct_milvus_2eproto;
};
......@@ -2466,7 +2473,6 @@ class Index :
enum : int {
kIndexTypeFieldNumber = 1,
kNlistFieldNumber = 2,
kMetricTypeFieldNumber = 3,
};
// int32 index_type = 1;
void clear_index_type();
......@@ -2478,11 +2484,6 @@ class Index :
::PROTOBUF_NAMESPACE_ID::int32 nlist() const;
void set_nlist(::PROTOBUF_NAMESPACE_ID::int32 value);
// int32 metric_type = 3;
void clear_metric_type();
::PROTOBUF_NAMESPACE_ID::int32 metric_type() const;
void set_metric_type(::PROTOBUF_NAMESPACE_ID::int32 value);
// @@protoc_insertion_point(class_scope:milvus.grpc.Index)
private:
class _Internal;
......@@ -2490,7 +2491,6 @@ class Index :
::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_;
::PROTOBUF_NAMESPACE_ID::int32 index_type_;
::PROTOBUF_NAMESPACE_ID::int32 nlist_;
::PROTOBUF_NAMESPACE_ID::int32 metric_type_;
mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_;
friend struct ::TableStruct_milvus_2eproto;
};
......@@ -2975,6 +2975,20 @@ inline void TableSchema::set_index_file_size(::PROTOBUF_NAMESPACE_ID::int64 valu
// @@protoc_insertion_point(field_set:milvus.grpc.TableSchema.index_file_size)
}
// int32 metric_type = 4;
inline void TableSchema::clear_metric_type() {
metric_type_ = 0;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TableSchema::metric_type() const {
// @@protoc_insertion_point(field_get:milvus.grpc.TableSchema.metric_type)
return metric_type_;
}
inline void TableSchema::set_metric_type(::PROTOBUF_NAMESPACE_ID::int32 value) {
metric_type_ = value;
// @@protoc_insertion_point(field_set:milvus.grpc.TableSchema.metric_type)
}
// -------------------------------------------------------------------
// Range
......@@ -4030,20 +4044,6 @@ inline void Index::set_nlist(::PROTOBUF_NAMESPACE_ID::int32 value) {
// @@protoc_insertion_point(field_set:milvus.grpc.Index.nlist)
}
// int32 metric_type = 3;
inline void Index::clear_metric_type() {
metric_type_ = 0;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 Index::metric_type() const {
// @@protoc_insertion_point(field_get:milvus.grpc.Index.metric_type)
return metric_type_;
}
inline void Index::set_metric_type(::PROTOBUF_NAMESPACE_ID::int32 value) {
metric_type_ = value;
// @@protoc_insertion_point(field_set:milvus.grpc.Index.metric_type)
}
// -------------------------------------------------------------------
// IndexParam
......
......@@ -19,6 +19,7 @@ message TableSchema {
TableName table_name = 1;
int64 dimension = 2;
int64 index_file_size = 3;
int32 metric_type = 4;
}
/**
......@@ -134,7 +135,6 @@ message Command {
message Index {
int32 index_type = 1;
int32 nlist = 2;
int32 metric_type = 3;
}
/**
......
......@@ -20,12 +20,12 @@ ShortestPath(const ResourcePtr &src,
std::vector<std::vector<std::string>> paths;
uint64_t num_of_resources = res_mgr->GetAllResouces().size();
uint64_t num_of_resources = res_mgr->GetAllResources().size();
std::unordered_map<uint64_t, std::string> id_name_map;
std::unordered_map<std::string, uint64_t> name_id_map;
for (uint64_t i = 0; i < num_of_resources; ++i) {
id_name_map.insert(std::make_pair(i, res_mgr->GetAllResouces().at(i)->Name()));
name_id_map.insert(std::make_pair(res_mgr->GetAllResouces().at(i)->Name(), i));
id_name_map.insert(std::make_pair(i, res_mgr->GetAllResources().at(i)->name()));
name_id_map.insert(std::make_pair(res_mgr->GetAllResources().at(i)->name(), i));
}
std::vector<std::vector<uint64_t> > dis_matrix;
......@@ -40,23 +40,23 @@ ShortestPath(const ResourcePtr &src,
std::vector<bool> vis(num_of_resources, false);
std::vector<uint64_t> dis(num_of_resources, MAXINT);
for (auto &res : res_mgr->GetAllResouces()) {
for (auto &res : res_mgr->GetAllResources()) {
auto cur_node = std::static_pointer_cast<Node>(res);
auto cur_neighbours = cur_node->GetNeighbours();
for (auto &neighbour : cur_neighbours) {
auto neighbour_res = std::static_pointer_cast<Resource>(neighbour.neighbour_node.lock());
dis_matrix[name_id_map.at(res->Name())][name_id_map.at(neighbour_res->Name())] =
dis_matrix[name_id_map.at(res->name())][name_id_map.at(neighbour_res->name())] =
neighbour.connection.transport_cost();
}
}
for (uint64_t i = 0; i < num_of_resources; ++i) {
dis[i] = dis_matrix[name_id_map.at(src->Name())][i];
dis[i] = dis_matrix[name_id_map.at(src->name())][i];
}
vis[name_id_map.at(src->Name())] = true;
vis[name_id_map.at(src->name())] = true;
std::vector<int64_t> parent(num_of_resources, -1);
for (uint64_t i = 0; i < num_of_resources; ++i) {
......@@ -71,7 +71,7 @@ ShortestPath(const ResourcePtr &src,
vis[temp] = true;
if (i == 0) {
parent[temp] = name_id_map.at(src->Name());
parent[temp] = name_id_map.at(src->name());
}
for (uint64_t j = 0; j < num_of_resources; ++j) {
......@@ -82,15 +82,15 @@ ShortestPath(const ResourcePtr &src,
}
}
int64_t parent_idx = parent[name_id_map.at(dest->Name())];
int64_t parent_idx = parent[name_id_map.at(dest->name())];
if (parent_idx != -1) {
path.push_back(dest->Name());
path.push_back(dest->name());
}
while (parent_idx != -1) {
path.push_back(id_name_map.at(parent_idx));
parent_idx = parent[parent_idx];
}
return dis[name_id_map.at(dest->Name())];
return dis[name_id_map.at(dest->name())];
}
}
......
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include <memory>
namespace zilliz {
namespace milvus {
namespace engine {
// dummy cache_mgr
class CacheMgr {
};
using CacheMgrPtr = std::shared_ptr<CacheMgr>;
}
}
}
......@@ -12,67 +12,31 @@ namespace zilliz {
namespace milvus {
namespace engine {
ResourceMgr::ResourceMgr()
: running_(false) {
}
uint64_t
ResourceMgr::GetNumOfComputeResource() {
uint64_t count = 0;
for (auto &res : resources_) {
if (res->HasExecutor()) {
++count;
}
}
return count;
}
std::vector<ResourcePtr>
ResourceMgr::GetComputeResource() {
std::vector<ResourcePtr > result;
void
ResourceMgr::Start() {
std::lock_guard<std::mutex> lck(resources_mutex_);
for (auto &resource : resources_) {
if (resource->HasExecutor()) {
result.emplace_back(resource);
}
}
return result;
}
uint64_t
ResourceMgr::GetNumGpuResource() const {
uint64_t num = 0;
for (auto &res : resources_) {
if (res->Type() == ResourceType::GPU) {
num++;
}
resource->Start();
}
return num;
running_ = true;
worker_thread_ = std::thread(&ResourceMgr::event_process, this);
}
ResourcePtr
ResourceMgr::GetResource(ResourceType type, uint64_t device_id) {
for (auto &resource : resources_) {
if (resource->Type() == type && resource->DeviceId() == device_id) {
return resource;
}
void
ResourceMgr::Stop() {
{
std::lock_guard<std::mutex> lock(event_mutex_);
running_ = false;
queue_.push(nullptr);
event_cv_.notify_one();
}
return nullptr;
}
worker_thread_.join();
ResourcePtr
ResourceMgr::GetResourceByName(std::string name) {
std::lock_guard<std::mutex> lck(resources_mutex_);
for (auto &resource : resources_) {
if (resource->Name() == name) {
return resource;
}
resource->Stop();
}
return nullptr;
}
std::vector<ResourcePtr>
ResourceMgr::GetAllResouces() {
return resources_;
}
ResourceWPtr
......@@ -85,75 +49,85 @@ ResourceMgr::Add(ResourcePtr &&resource) {
return ret;
}
if (resource->Type() == ResourceType::DISK) {
resource->RegisterSubscriber(std::bind(&ResourceMgr::post_event, this, std::placeholders::_1));
if (resource->type() == ResourceType::DISK) {
disk_resources_.emplace_back(ResourceWPtr(resource));
}
resources_.emplace_back(resource);
size_t index = resources_.size() - 1;
resource->RegisterSubscriber(std::bind(&ResourceMgr::PostEvent, this, std::placeholders::_1));
return ret;
}
void
ResourceMgr::Connect(const std::string &name1, const std::string &name2, Connection &connection) {
auto res1 = get_resource_by_name(name1);
auto res2 = get_resource_by_name(name2);
auto res1 = GetResource(name1);
auto res2 = GetResource(name2);
if (res1 && res2) {
res1->AddNeighbour(std::static_pointer_cast<Node>(res2), connection);
// TODO: enable when task balance supported
// res2->AddNeighbour(std::static_pointer_cast<Node>(res1), connection);
}
}
void
ResourceMgr::Connect(ResourceWPtr &res1, ResourceWPtr &res2, Connection &connection) {
if (auto observe_a = res1.lock()) {
if (auto observe_b = res2.lock()) {
observe_a->AddNeighbour(std::static_pointer_cast<Node>(observe_b), connection);
observe_b->AddNeighbour(std::static_pointer_cast<Node>(observe_a), connection);
}
}
ResourceMgr::Clear() {
std::lock_guard<std::mutex> lck(resources_mutex_);
disk_resources_.clear();
resources_.clear();
}
void
ResourceMgr::Start() {
std::lock_guard<std::mutex> lck(resources_mutex_);
std::vector<ResourcePtr>
ResourceMgr::GetComputeResource() {
std::vector<ResourcePtr> result;
for (auto &resource : resources_) {
resource->Start();
if (resource->HasExecutor()) {
result.emplace_back(resource);
}
}
running_ = true;
worker_thread_ = std::thread(&ResourceMgr::event_process, this);
return result;
}
void
ResourceMgr::Stop() {
{
std::lock_guard<std::mutex> lock(event_mutex_);
running_ = false;
queue_.push(nullptr);
event_cv_.notify_one();
ResourcePtr
ResourceMgr::GetResource(ResourceType type, uint64_t device_id) {
for (auto &resource : resources_) {
if (resource->type() == type && resource->device_id() == device_id) {
return resource;
}
}
worker_thread_.join();
return nullptr;
}
std::lock_guard<std::mutex> lck(resources_mutex_);
ResourcePtr
ResourceMgr::GetResource(const std::string &name) {
for (auto &resource : resources_) {
resource->Stop();
if (resource->name() == name) {
return resource;
}
}
return nullptr;
}
void
ResourceMgr::Clear() {
std::lock_guard<std::mutex> lck(resources_mutex_);
disk_resources_.clear();
resources_.clear();
uint64_t
ResourceMgr::GetNumOfComputeResource() {
uint64_t count = 0;
for (auto &res : resources_) {
if (res->HasExecutor()) {
++count;
}
}
return count;
}
void
ResourceMgr::PostEvent(const EventPtr &event) {
std::lock_guard<std::mutex> lock(event_mutex_);
queue_.emplace(event);
event_cv_.notify_one();
uint64_t
ResourceMgr::GetNumGpuResource() const {
uint64_t num = 0;
for (auto &res : resources_) {
if (res->type() == ResourceType::GPU) {
num++;
}
}
return num;
}
std::string
......@@ -180,14 +154,13 @@ ResourceMgr::DumpTaskTables() {
return ss.str();
}
ResourcePtr
ResourceMgr::get_resource_by_name(const std::string &name) {
for (auto &res : resources_) {
if (res->Name() == name) {
return res;
}
void
ResourceMgr::post_event(const EventPtr &event) {
{
std::lock_guard<std::mutex> lock(event_mutex_);
queue_.emplace(event);
}
return nullptr;
event_cv_.notify_one();
}
void
......@@ -203,8 +176,6 @@ ResourceMgr::event_process() {
break;
}
// ENGINE_LOG_DEBUG << "ResourceMgr process " << *event;
if (subscriber_) {
subscriber_(event);
}
......
......@@ -22,78 +22,63 @@ namespace engine {
class ResourceMgr {
public:
ResourceMgr();
ResourceMgr() = default;
public:
/******** Management Interface ********/
void
Start();
void
Stop();
ResourceWPtr
Add(ResourcePtr &&resource);
void
Connect(const std::string &res1, const std::string &res2, Connection &connection);
void
Clear();
inline void
RegisterSubscriber(std::function<void(EventPtr)> subscriber) {
subscriber_ = std::move(subscriber);
}
std::vector<ResourceWPtr> &
public:
/******** Management Interface ********/
inline std::vector<ResourceWPtr> &
GetDiskResources() {
return disk_resources_;
}
uint64_t
GetNumGpuResource() const;
// TODO: why return shared pointer
inline std::vector<ResourcePtr>
GetAllResources() {
return resources_;
}
std::vector<ResourcePtr>
GetComputeResource();
ResourcePtr
GetResource(ResourceType type, uint64_t device_id);
ResourcePtr
GetResourceByName(std::string name);
GetResource(const std::string &name);
std::vector<ResourcePtr>
GetAllResouces();
/*
* Return account of resource which enable executor;
*/
uint64_t
GetNumOfComputeResource();
std::vector<ResourcePtr>
GetComputeResource();
/*
* Add resource into Resource Management;
* Generate functions on events;
* Functions only modify bool variable, like event trigger;
*/
ResourceWPtr
Add(ResourcePtr &&resource);
void
Connect(const std::string &res1, const std::string &res2, Connection &connection);
/*
* Create connection between A and B;
*/
void
Connect(ResourceWPtr &res1, ResourceWPtr &res2, Connection &connection);
/*
* Synchronous start all resource;
* Last, start event process thread;
*/
void
Start();
void
Stop();
void
Clear();
void
PostEvent(const EventPtr &event);
uint64_t
GetNumGpuResource() const;
public:
// TODO: add stats interface(low)
public:
/******** Utlitity Functions ********/
/******** Utility Functions ********/
std::string
Dump();
......@@ -101,26 +86,26 @@ public:
DumpTaskTables();
private:
ResourcePtr
get_resource_by_name(const std::string &name);
void
post_event(const EventPtr &event);
void
event_process();
private:
std::queue<EventPtr> queue_;
std::function<void(EventPtr)> subscriber_ = nullptr;
bool running_;
bool running_ = false;
std::vector<ResourceWPtr> disk_resources_;
std::vector<ResourcePtr> resources_;
mutable std::mutex resources_mutex_;
std::thread worker_thread_;
std::queue<EventPtr> queue_;
std::function<void(EventPtr)> subscriber_ = nullptr;
std::mutex event_mutex_;
std::condition_variable event_cv_;
std::thread worker_thread_;
};
using ResourceMgrPtr = std::shared_ptr<ResourceMgr>;
......
......@@ -49,7 +49,7 @@ StartSchedulerService() {
enable_loader,
enable_executor));
if (res.lock()->Type() == ResourceType::GPU) {
if (res.lock()->type() == ResourceType::GPU) {
auto pinned_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_PIN_MEMORY, 300);
auto temp_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_TEMP_MEMORY, 300);
auto resource_num = resconf.GetInt64Value(server::CONFIG_RESOURCE_NUM, 2);
......
......@@ -143,7 +143,7 @@ Scheduler::OnLoadCompleted(const EventPtr &event) {
auto task = load_completed_event->task_table_item_->task;
// if this resource is disk, assign it to smallest cost resource
if (self->Type() == ResourceType::DISK) {
if (self->type() == ResourceType::DISK) {
// step 1: calculate shortest path per resource, from disk to compute resource
auto compute_resources = res_mgr_.lock()->GetComputeResource();
std::vector<std::vector<std::string>> paths;
......@@ -176,11 +176,11 @@ Scheduler::OnLoadCompleted(const EventPtr &event) {
task->path() = task_path;
}
if(self->Name() == task->path().Last()) {
if(self->name() == task->path().Last()) {
self->WakeupLoader();
} else {
auto next_res_name = task->path().Next();
auto next_res = res_mgr_.lock()->GetResourceByName(next_res_name);
auto next_res = res_mgr_.lock()->GetResource(next_res_name);
load_completed_event->task_table_item_->Move();
next_res->task_table().Put(task);
}
......
......@@ -6,6 +6,8 @@
#include "TaskTable.h"
#include "event/TaskTableUpdatedEvent.h"
#include "Utils.h"
#include <vector>
#include <sstream>
#include <ctime>
......@@ -15,14 +17,6 @@ namespace zilliz {
namespace milvus {
namespace engine {
uint64_t
get_now_timestamp() {
std::chrono::time_point<std::chrono::system_clock> now = std::chrono::system_clock::now();
auto duration = now.time_since_epoch();
auto millis = std::chrono::duration_cast<std::chrono::milliseconds>(duration).count();
return millis;
}
std::string
ToString(TaskTableItemState state) {
switch (state) {
......@@ -64,7 +58,7 @@ TaskTableItem::Load() {
if (state == TaskTableItemState::START) {
state = TaskTableItemState::LOADING;
lock.unlock();
timestamp.load = get_now_timestamp();
timestamp.load = get_current_timestamp();
return true;
}
return false;
......@@ -75,7 +69,7 @@ TaskTableItem::Loaded() {
if (state == TaskTableItemState::LOADING) {
state = TaskTableItemState::LOADED;
lock.unlock();
timestamp.loaded = get_now_timestamp();
timestamp.loaded = get_current_timestamp();
return true;
}
return false;
......@@ -86,7 +80,7 @@ TaskTableItem::Execute() {
if (state == TaskTableItemState::LOADED) {
state = TaskTableItemState::EXECUTING;
lock.unlock();
timestamp.execute = get_now_timestamp();
timestamp.execute = get_current_timestamp();
return true;
}
return false;
......@@ -97,8 +91,8 @@ TaskTableItem::Executed() {
if (state == TaskTableItemState::EXECUTING) {
state = TaskTableItemState::EXECUTED;
lock.unlock();
timestamp.executed = get_now_timestamp();
timestamp.finish = get_now_timestamp();
timestamp.executed = get_current_timestamp();
timestamp.finish = get_current_timestamp();
return true;
}
return false;
......@@ -109,7 +103,7 @@ TaskTableItem::Move() {
if (state == TaskTableItemState::LOADED) {
state = TaskTableItemState::MOVING;
lock.unlock();
timestamp.move = get_now_timestamp();
timestamp.move = get_current_timestamp();
return true;
}
return false;
......@@ -120,8 +114,8 @@ TaskTableItem::Moved() {
if (state == TaskTableItemState::MOVING) {
state = TaskTableItemState::MOVED;
lock.unlock();
timestamp.moved = get_now_timestamp();
timestamp.finish = get_now_timestamp();
timestamp.moved = get_current_timestamp();
timestamp.finish = get_current_timestamp();
return true;
}
return false;
......@@ -177,7 +171,7 @@ TaskTable::Put(TaskPtr task) {
item->id = id_++;
item->task = std::move(task);
item->state = TaskTableItemState::START;
item->timestamp.start = get_now_timestamp();
item->timestamp.start = get_current_timestamp();
table_.push_back(item);
if (subscriber_) {
subscriber_();
......@@ -192,7 +186,7 @@ TaskTable::Put(std::vector<TaskPtr> &tasks) {
item->id = id_++;
item->task = std::move(task);
item->state = TaskTableItemState::START;
item->timestamp.start = get_now_timestamp();
item->timestamp.start = get_current_timestamp();
table_.push_back(item);
}
if (subscriber_) {
......
......@@ -40,20 +40,17 @@ struct TaskTimestamp {
};
struct TaskTableItem {
TaskTableItem() : id(0), state(TaskTableItemState::INVALID), mutex(), priority(0) {}
TaskTableItem() : id(0), state(TaskTableItemState::INVALID), mutex() {}
TaskTableItem(const TaskTableItem &src)
: id(src.id), state(src.state), mutex(), priority(src.priority) {}
: id(src.id), state(src.state), mutex() {}
uint64_t id; // auto increment from 0;
// TODO: add tag into task
TaskPtr task; // the task;
TaskTableItemState state; // the state;
std::mutex mutex;
TaskTimestamp timestamp;
uint8_t priority; // just a number, meaningless;
bool
IsFinish();
......@@ -113,7 +110,7 @@ public:
Get(uint64_t index);
/*
* TODO
* TODO(wxyu): BIG GC
* Remove sequence task which is DONE or MOVED from front;
* Called by ?
*/
......@@ -135,6 +132,7 @@ public:
Size() {
return table_.size();
}
public:
TaskTableItemPtr &
operator[](uint64_t index) {
......@@ -225,7 +223,6 @@ public:
Dump();
private:
// TODO: map better ?
std::uint64_t id_ = 0;
mutable std::mutex id_mutex_;
std::deque<TaskTableItemPtr> table_;
......
......@@ -4,16 +4,17 @@
* Proprietary and confidential.
******************************************************************************/
#include <chrono>
#include "Utils.h"
#include <chrono>
namespace zilliz {
namespace milvus {
namespace engine {
uint64_t
get_current_timestamp()
{
get_current_timestamp() {
std::chrono::time_point<std::chrono::system_clock> now = std::chrono::system_clock::now();
auto duration = now.time_since_epoch();
auto millis = std::chrono::duration_cast<std::chrono::milliseconds>(duration).count();
......
......@@ -3,6 +3,7 @@
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include <cstdint>
......
......@@ -15,6 +15,7 @@ namespace engine {
class Connection {
public:
// TODO: update construct function, speed: double->uint64_t
Connection(std::string name, double speed)
: name_(std::move(name)), speed_(speed) {}
......
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include <memory>
namespace zilliz {
namespace milvus {
namespace engine {
class RegisterHandler {
public:
virtual void Exec() = 0;
};
using RegisterHandlerPtr = std::shared_ptr<RegisterHandler>;
}
}
}
\ No newline at end of file
......@@ -12,7 +12,8 @@ namespace zilliz {
namespace milvus {
namespace engine {
std::ostream &operator<<(std::ostream &out, const Resource &resource) {
std::ostream &
operator<<(std::ostream &out, const Resource &resource) {
out << resource.Dump();
return out;
}
......@@ -25,11 +26,9 @@ Resource::Resource(std::string name,
: name_(std::move(name)),
type_(type),
device_id_(device_id),
running_(false),
enable_loader_(enable_loader),
enable_executor_(enable_executor),
load_flag_(false),
exec_flag_(false) {
enable_executor_(enable_executor) {
// register subscriber in tasktable
task_table_.RegisterSubscriber([&] {
if (subscriber_) {
auto event = std::make_shared<TaskTableUpdatedEvent>(shared_from_this());
......@@ -38,7 +37,8 @@ Resource::Resource(std::string name,
});
}
void Resource::Start() {
void
Resource::Start() {
running_ = true;
if (enable_loader_) {
loader_thread_ = std::thread(&Resource::loader_function, this);
......@@ -48,7 +48,8 @@ void Resource::Start() {
}
}
void Resource::Stop() {
void
Resource::Stop() {
running_ = false;
if (enable_loader_) {
WakeupLoader();
......@@ -60,11 +61,8 @@ void Resource::Stop() {
}
}
TaskTable &Resource::task_table() {
return task_table_;
}
void Resource::WakeupLoader() {
void
Resource::WakeupLoader() {
{
std::lock_guard<std::mutex> lock(load_mutex_);
load_flag_ = true;
......@@ -72,7 +70,8 @@ void Resource::WakeupLoader() {
load_cv_.notify_one();
}
void Resource::WakeupExecutor() {
void
Resource::WakeupExecutor() {
{
std::lock_guard<std::mutex> lock(exec_mutex_);
exec_flag_ = true;
......@@ -80,6 +79,15 @@ void Resource::WakeupExecutor() {
exec_cv_.notify_one();
}
uint64_t
Resource::NumOfTaskToExec() {
uint64_t count = 0;
for (auto &task : task_table_) {
if (task->state == TaskTableItemState::LOADED) ++count;
}
return count;
}
TaskTableItemPtr Resource::pick_task_load() {
auto indexes = task_table_.PickToLoad(10);
for (auto index : indexes) {
......@@ -156,11 +164,6 @@ void Resource::executor_function() {
}
}
RegisterHandlerPtr Resource::GetRegisterFunc(const RegisterType &type) {
// construct object each time.
return register_table_[type]();
}
}
}
}
\ No newline at end of file
......@@ -21,7 +21,6 @@
#include "../task/Task.h"
#include "Connection.h"
#include "Node.h"
#include "RegisterHandler.h"
namespace zilliz {
......@@ -35,13 +34,6 @@ enum class ResourceType {
GPU = 2
};
enum class RegisterType {
START_UP,
ON_FINISH_TASK,
ON_COPY_COMPLETED,
ON_TASK_TABLE_UPDATED,
};
class Resource : public Node, public std::enable_shared_from_this<Resource> {
public:
/*
......@@ -68,56 +60,51 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> {
void
WakeupExecutor();
public:
template<typename T>
void Register_T(const RegisterType &type) {
register_table_.emplace(type, [] { return std::make_shared<T>(); });
}
RegisterHandlerPtr
GetRegisterFunc(const RegisterType &type);
inline void
RegisterSubscriber(std::function<void(EventPtr)> subscriber) {
subscriber_ = std::move(subscriber);
}
inline virtual std::string
Dump() const {
return "<Resource>";
}
public:
inline std::string
Name() const {
name() const {
return name_;
}
inline ResourceType
Type() const {
type() const {
return type_;
}
inline uint64_t
DeviceId() {
device_id() const {
return device_id_;
}
// TODO: better name?
TaskTable &
task_table() {
return task_table_;
}
public:
inline bool
HasLoader() {
HasLoader() const {
return enable_loader_;
}
// TODO: better name?
inline bool
HasExecutor() {
HasExecutor() const {
return enable_executor_;
}
// TODO: const
uint64_t
NumOfTaskToExec() {
uint64_t count = 0;
for (auto &task : task_table_) {
if (task->state == TaskTableItemState::LOADED) ++count;
}
return count;
}
NumOfTaskToExec();
// TODO: need double ?
inline uint64_t
......@@ -130,14 +117,6 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> {
return total_task_;
}
TaskTable &
task_table();
inline virtual std::string
Dump() const {
return "<Resource>";
}
friend std::ostream &operator<<(std::ostream &out, const Resource &resource);
protected:
......@@ -198,6 +177,7 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> {
protected:
uint64_t device_id_;
std::string name_;
private:
ResourceType type_;
......@@ -206,17 +186,16 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> {
uint64_t total_cost_ = 0;
uint64_t total_task_ = 0;
std::map<RegisterType, std::function<RegisterHandlerPtr()>> register_table_;
std::function<void(EventPtr)> subscriber_ = nullptr;
bool running_;
bool running_ = false;
bool enable_loader_ = true;
bool enable_executor_ = true;
std::thread loader_thread_;
std::thread executor_thread_;
bool load_flag_;
bool exec_flag_;
bool load_flag_ = false;
bool exec_flag_ = false;
std::mutex load_mutex_;
std::mutex exec_mutex_;
std::condition_variable load_cv_;
......
......@@ -24,12 +24,6 @@ XDeleteTask::Execute() {
delete_context_ptr_->ResourceDone();
}
TaskPtr
XDeleteTask::Clone() {
auto task = std::make_shared<XDeleteTask>(delete_context_ptr_);
return task;
}
}
}
}
......@@ -24,9 +24,6 @@ public:
void
Execute() override;
TaskPtr
Clone() override;
public:
DeleteContextPtr delete_context_ptr_;
};
......
......@@ -193,16 +193,6 @@ XSearchTask::Execute() {
index_engine_ = nullptr;
}
TaskPtr
XSearchTask::Clone() {
auto ret = std::make_shared<XSearchTask>(file_);
ret->index_id_ = index_id_;
ret->index_engine_ = index_engine_->Clone();
ret->search_contexts_ = search_contexts_;
ret->metric_l2 = metric_l2;
return ret;
}
Status XSearchTask::ClusterResult(const std::vector<long> &output_ids,
const std::vector<float> &output_distence,
uint64_t nq,
......
......@@ -23,9 +23,6 @@ public:
void
Execute() override;
TaskPtr
Clone() override;
public:
static Status ClusterResult(const std::vector<long> &output_ids,
const std::vector<float> &output_distence,
......
......@@ -68,14 +68,9 @@ public:
virtual void
Execute() = 0;
// TODO: dont use this method to support task move
virtual TaskPtr
Clone() = 0;
public:
Path task_path_;
std::vector<SearchContextPtr> search_contexts_;
ScheduleTaskPtr task_;
TaskType type_;
TaskLabelPtr label_ = nullptr;
};
......
......@@ -21,7 +21,6 @@ TaskConvert(const ScheduleTaskPtr &schedule_task) {
auto task = std::make_shared<XSearchTask>(load_task->file_);
task->label() = std::make_shared<DefaultLabel>();
task->search_contexts_ = load_task->search_contexts_;
task->task_ = schedule_task;
return task;
}
case ScheduleTaskType::kDelete: {
......
......@@ -27,15 +27,6 @@ TestTask::Execute() {
done_ = true;
}
TaskPtr
TestTask::Clone() {
TableFileSchemaPtr dummy = nullptr;
auto ret = std::make_shared<TestTask>(dummy);
ret->load_count_ = load_count_;
ret->exec_count_ = exec_count_;
return ret;
}
void
TestTask::Wait() {
std::unique_lock<std::mutex> lock(mutex_);
......
......@@ -23,9 +23,6 @@ public:
void
Execute() override;
TaskPtr
Clone() override;
void
Wait();
......
......@@ -25,6 +25,7 @@ public:
}
protected:
explicit
TaskLabel(TaskLabelType type) : type_(type) {}
private:
......
......@@ -96,6 +96,7 @@ TableSchema BuildTableSchema() {
tb_schema.table_name = TABLE_NAME;
tb_schema.dimension = TABLE_DIMENSION;
tb_schema.index_file_size = TABLE_INDEX_FILE_SIZE;
tb_schema.metric_type = MetricType::L2;
return tb_schema;
}
......@@ -291,7 +292,6 @@ ClientTest::Test(const std::string& address, const std::string& port) {
index.table_name = TABLE_NAME;
index.index_type = IndexType::gpu_ivfflat;
index.nlist = 16384;
index.metric_type = 1;
Status stat = conn->CreateIndex(index);
std::cout << "CreateIndex function call status: " << stat.ToString() << std::endl;
......
......@@ -84,6 +84,7 @@ ClientProxy::CreateTable(const TableSchema &param) {
schema.mutable_table_name()->set_table_name(param.table_name);
schema.set_dimension(param.dimension);
schema.set_index_file_size(param.index_file_size);
schema.set_metric_type((int32_t)param.metric_type);
return client_ptr_->CreateTable(schema);
} catch (std::exception &ex) {
......@@ -116,11 +117,9 @@ ClientProxy::CreateIndex(const IndexParam &index_param) {
try {
//TODO:add index params
::milvus::grpc::IndexParam grpc_index_param;
grpc_index_param.mutable_table_name()->set_table_name(
index_param.table_name);
grpc_index_param.mutable_table_name()->set_table_name(index_param.table_name);
grpc_index_param.mutable_index()->set_index_type((int32_t)index_param.index_type);
grpc_index_param.mutable_index()->set_nlist(index_param.nlist);
grpc_index_param.mutable_index()->set_metric_type(index_param.metric_type);
return client_ptr_->CreateIndex(grpc_index_param);
} catch (std::exception &ex) {
......@@ -273,6 +272,7 @@ ClientProxy::DescribeTable(const std::string &table_name, TableSchema &table_sch
table_schema.table_name = grpc_schema.table_name().table_name();
table_schema.dimension = grpc_schema.dimension();
table_schema.index_file_size = grpc_schema.index_file_size();
table_schema.metric_type = (MetricType)grpc_schema.metric_type();
return status;
} catch (std::exception &ex) {
......@@ -378,7 +378,6 @@ ClientProxy::DescribeIndex(const std::string &table_name, IndexParam &index_para
Status status = client_ptr_->DescribeIndex(grpc_table_name, grpc_index_param);
index_param.index_type = (IndexType)(grpc_index_param.mutable_index()->index_type());
index_param.nlist = grpc_index_param.mutable_index()->nlist();
index_param.metric_type = grpc_index_param.mutable_index()->metric_type();
return status;
......
......@@ -22,6 +22,11 @@ enum class IndexType {
mix_nsg,
};
enum class MetricType {
L2 = 1,
IP = 2,
};
/**
* @brief Connect API parameter
*/
......@@ -37,6 +42,7 @@ struct TableSchema {
std::string table_name; ///< Table name
int64_t dimension = 0; ///< Vector dimension, must be a positive value
int64_t index_file_size = 0; ///< Index file size, must be a positive value
MetricType metric_type = MetricType::L2; ///< Index metric type
};
/**
......@@ -77,7 +83,6 @@ struct IndexParam {
std::string table_name;
IndexType index_type;
int32_t nlist;
int32_t metric_type;
};
/**
......
......@@ -145,11 +145,17 @@ CreateTableTask::OnExecute() {
return SetError(res, "Invalid index file size: " + std::to_string(schema_->index_file_size()));
}
res = ValidationUtil::ValidateTableIndexMetricType(schema_->metric_type());
if(res != SERVER_SUCCESS) {
return SetError(res, "Invalid index metric type: " + std::to_string(schema_->metric_type()));
}
//step 2: construct table schema
engine::meta::TableSchema table_info;
table_info.table_id_ = schema_->table_name().table_name();
table_info.dimension_ = (uint16_t) schema_->dimension();
table_info.index_file_size_ = schema_->index_file_size();
table_info.metric_type_ = schema_->metric_type();
//step 3: create table
engine::Status stat = DBWrapper::DB()->CreateTable(table_info);
......@@ -204,6 +210,7 @@ DescribeTableTask::OnExecute() {
schema_->mutable_table_name()->set_table_name(table_info.table_id_);
schema_->set_dimension(table_info.dimension_);
schema_->set_index_file_size(table_info.index_file_size_);
schema_->set_metric_type(table_info.metric_type_);
} catch (std::exception &ex) {
return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
......@@ -262,16 +269,10 @@ CreateIndexTask::OnExecute() {
return SetError(res, "Invalid index nlist: " + std::to_string(grpc_index.nlist()));
}
res = ValidationUtil::ValidateTableIndexMetricType(grpc_index.metric_type());
if(res != SERVER_SUCCESS) {
return SetError(res, "Invalid index metric type: " + std::to_string(grpc_index.metric_type()));
}
//step 2: check table existence
engine::TableIndex index;
index.engine_type_ = grpc_index.index_type();
index.nlist_ = grpc_index.nlist();
index.metric_type_ = grpc_index.metric_type();
stat = DBWrapper::DB()->CreateIndex(table_name_, index);
if (!stat.ok()) {
return SetError(SERVER_BUILD_INDEX_ERROR, stat.ToString());
......@@ -919,7 +920,6 @@ DescribeIndexTask::OnExecute() {
index_param_->mutable_table_name()->set_table_name(table_name_);
index_param_->mutable_index()->set_index_type(index.engine_type_);
index_param_->mutable_index()->set_nlist(index.nlist_);
index_param_->mutable_index()->set_metric_type(index.metric_type_);
rc.ElapseFromBegin("totally cost");
} catch (std::exception &ex) {
......
......@@ -216,7 +216,7 @@ TEST_F(MemManagerTest, MEM_TABLE_TEST) {
ASSERT_TRUE(status.ok());
}
TEST_F(MemManagerTest, SERIAL_INSERT_SEARCH_TEST) {
TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) {
engine::meta::TableSchema table_info = BuildTableSchema();
engine::Status stat = db_->CreateTable(table_info);
......@@ -262,7 +262,7 @@ TEST_F(MemManagerTest, SERIAL_INSERT_SEARCH_TEST) {
}
}
TEST_F(MemManagerTest, INSERT_TEST) {
TEST_F(MemManagerTest2, INSERT_TEST) {
engine::meta::TableSchema table_info = BuildTableSchema();
engine::Status stat = db_->CreateTable(table_info);
......@@ -288,7 +288,7 @@ TEST_F(MemManagerTest, INSERT_TEST) {
LOG(DEBUG) << "total_time spent in INSERT_TEST (ms) : " << total_time;
}
TEST_F(MemManagerTest, CONCURRENT_INSERT_SEARCH_TEST) {
TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) {
engine::meta::TableSchema table_info = BuildTableSchema();
engine::Status stat = db_->CreateTable(table_info);
......@@ -359,7 +359,7 @@ TEST_F(MemManagerTest, CONCURRENT_INSERT_SEARCH_TEST) {
search.join();
};
TEST_F(MemManagerTest, VECTOR_IDS_TEST) {
TEST_F(MemManagerTest2, VECTOR_IDS_TEST) {
engine::meta::TableSchema table_info = BuildTableSchema();
engine::Status stat = db_->CreateTable(table_info);
......
......@@ -51,6 +51,10 @@ void BaseTest::InitLog() {
el::Loggers::reconfigureLogger("default", defaultConf);
}
void BaseTest::SetUp() {
InitLog();
}
engine::Options BaseTest::GetOptions() {
auto options = engine::OptionsFactory::Build();
options.meta.path = "/tmp/milvus_test";
......@@ -60,7 +64,7 @@ engine::Options BaseTest::GetOptions() {
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
void DBTest::SetUp() {
InitLog();
BaseTest::SetUp();
server::ConfigNode& config = server::ServerConfig::GetInstance().GetConfig(server::CONFIG_CACHE);
config.AddSequenceItem(server::CONFIG_GPU_IDS, "0");
......@@ -104,7 +108,8 @@ engine::Options DBTest2::GetOptions() {
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
void MetaTest::SetUp() {
InitLog();
BaseTest::SetUp();
impl_ = engine::DBMetaImplFactory::Build();
}
......@@ -127,7 +132,7 @@ engine::Options MySqlDBTest::GetOptions() {
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
void MySqlMetaTest::SetUp() {
InitLog();
BaseTest::SetUp();
engine::DBMetaOptions options = GetOptions().meta;
int mode = engine::Options::MODE::SINGLE;
......
......@@ -37,6 +37,8 @@ void ASSERT_STATS(zilliz::milvus::engine::Status &stat);
class BaseTest : public ::testing::Test {
protected:
void InitLog();
virtual void SetUp() override;
virtual zilliz::milvus::engine::Options GetOptions();
};
......@@ -82,5 +84,9 @@ class MySqlMetaTest : public BaseTest {
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class MemManagerTest : public DBTest {
class MemManagerTest : public BaseTest {
};
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class MemManagerTest2 : public DBTest {
};
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册