提交 5b23b442 编写于 作者: J jinhai

Merge branch 'origin/branch-0.3.1' into 'branch-0.3.1'

MS-133 Change score to distance

See merge request megasearch/vecwise_engine!134

Former-commit-id: 1e071ddf6096234d368c88dfabcb28a6e371aaee
......@@ -9,7 +9,7 @@
#include "EngineFactory.h"
#include "metrics/Metrics.h"
#include "scheduler/TaskScheduler.h"
#include "scheduler/context/SearchContext.h"
#include "scheduler/context/DeleteContext.h"
#include "utils/TimeRecorder.h"
......@@ -27,9 +27,9 @@ namespace engine {
namespace {
static constexpr uint64_t METRIC_ACTION_INTERVAL = 1;
static constexpr uint64_t COMPACT_ACTION_INTERVAL = 1;
static constexpr uint64_t INDEX_ACTION_INTERVAL = 1;
constexpr uint64_t METRIC_ACTION_INTERVAL = 1;
constexpr uint64_t COMPACT_ACTION_INTERVAL = 1;
constexpr uint64_t INDEX_ACTION_INTERVAL = 1;
void CollectInsertMetrics(double total_time, size_t n, bool succeed) {
double avg_time = total_time / n;
......@@ -76,56 +76,6 @@ void CollectFileMetrics(int file_type, size_t file_size, double total_time) {
}
}
}
void CalcScore(uint64_t vector_count,
const float *vectors_data,
uint64_t dimension,
const SearchContext::ResultSet &result_src,
SearchContext::ResultSet &result_target) {
result_target.clear();
if(result_src.empty()){
return;
}
server::TimeRecorder rc("Calculate Score");
int vec_index = 0;
for(auto& result : result_src) {
const float * vec_data = vectors_data + vec_index*dimension;
double vec_len = 0;
for(uint64_t i = 0; i < dimension; i++) {
vec_len += vec_data[i]*vec_data[i];
}
vec_index++;
double max_score = 0.0;
for(auto& pair : result) {
if(max_score < pair.second) {
max_score = pair.second;
}
}
//makesure socre is less than 100
if(max_score > vec_len) {
vec_len = max_score;
}
//avoid divided by zero
static constexpr double TOLERANCE = std::numeric_limits<float>::epsilon();
if(vec_len < TOLERANCE) {
vec_len = TOLERANCE;
}
SearchContext::Id2ScoreMap score_array;
double vec_len_inverse = 1.0/vec_len;
for(auto& pair : result) {
score_array.push_back(std::make_pair(pair.first, (1 - pair.second*vec_len_inverse)*100.0));
}
result_target.emplace_back(score_array);
}
rc.Elapse("totally cost");
}
}
......@@ -232,7 +182,7 @@ Status DBImpl::Query(const std::string& table_id, const std::vector<std::string>
meta::TableFileSchema table_file;
table_file.table_id_ = table_id;
std::string::size_type sz;
ids.push_back(std::stol(id, &sz));
ids.push_back(std::stoul(id, &sz));
}
meta::TableFilesSchema files_array;
......@@ -380,10 +330,6 @@ Status DBImpl::QuerySync(const std::string& table_id, uint64_t k, uint64_t nq,
return Status::NotFound("Group " + table_id + ", search result not found!");
}
QueryResults temp_results;
CalcScore(nq, vectors, dim, results, temp_results);
results.swap(temp_results);
return Status::OK();
}
......@@ -405,13 +351,8 @@ Status DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSch
context->WaitResult();
//step 3: construct results, calculate score between 0 ~ 100
auto& context_result = context->GetResult();
meta::TableSchema table_schema;
table_schema.table_id_ = table_id;
meta_ptr_->DescribeTable(table_schema);
CalcScore(context->nq(), context->vectors(), table_schema.dimension_, context_result, results);
//step 3: construct results
results = context->GetResult();
return Status::OK();
}
......@@ -575,7 +516,7 @@ Status DBImpl::BackgroundMergeFiles(const std::string& table_id) {
void DBImpl::BackgroundCompaction(std::set<std::string> table_ids) {
Status status;
for (auto table_id : table_ids) {
for (auto& table_id : table_ids) {
status = BackgroundMergeFiles(table_id);
if (!status.ok()) {
bg_error_ = status;
......
......@@ -17,6 +17,8 @@
#include <thread>
#include <list>
#include <set>
#include "scheduler/context/SearchContext.h"
namespace zilliz {
namespace milvus {
......@@ -25,49 +27,80 @@ namespace engine {
class Env;
namespace meta {
class Meta;
class Meta;
}
class DBImpl : public DB {
public:
public:
using MetaPtr = meta::Meta::Ptr;
using MemManagerPtr = typename MemManager::Ptr;
DBImpl(const Options& options);
explicit DBImpl(const Options &options);
Status
CreateTable(meta::TableSchema &table_schema) override;
Status
DeleteTable(const std::string &table_id, const meta::DatesT &dates) override;
Status
DescribeTable(meta::TableSchema &table_schema) override;
virtual Status CreateTable(meta::TableSchema& table_schema) override;
virtual Status DeleteTable(const std::string& table_id, const meta::DatesT& dates) override;
virtual Status DescribeTable(meta::TableSchema& table_schema) override;
virtual Status HasTable(const std::string& table_id, bool& has_or_not) override;
virtual Status AllTables(std::vector<meta::TableSchema>& table_schema_array) override;
virtual Status GetTableRowCount(const std::string& table_id, uint64_t& row_count) override;
Status
HasTable(const std::string &table_id, bool &has_or_not) override;
virtual Status InsertVectors(const std::string& table_id,
uint64_t n, const float* vectors, IDNumbers& vector_ids) override;
Status
AllTables(std::vector<meta::TableSchema> &table_schema_array) override;
virtual Status Query(const std::string& table_id, uint64_t k, uint64_t nq,
const float* vectors, QueryResults& results) override;
Status
GetTableRowCount(const std::string &table_id, uint64_t &row_count) override;
virtual Status Query(const std::string& table_id, uint64_t k, uint64_t nq,
const float* vectors, const meta::DatesT& dates, QueryResults& results) override;
Status
InsertVectors(const std::string &table_id, uint64_t n, const float *vectors, IDNumbers &vector_ids) override;
virtual Status Query(const std::string& table_id, const std::vector<std::string>& file_ids,
uint64_t k, uint64_t nq, const float* vectors,
const meta::DatesT& dates, QueryResults& results) override;
Status
Query(const std::string &table_id, uint64_t k, uint64_t nq, const float *vectors, QueryResults &results) override;
virtual Status DropAll() override;
Status
Query(const std::string &table_id,
uint64_t k,
uint64_t nq,
const float *vectors,
const meta::DatesT &dates,
QueryResults &results) override;
virtual Status Size(uint64_t& result) override;
Status
Query(const std::string &table_id,
const std::vector<std::string> &file_ids,
uint64_t k,
uint64_t nq,
const float *vectors,
const meta::DatesT &dates,
QueryResults &results) override;
virtual ~DBImpl();
Status DropAll() override;
private:
Status QuerySync(const std::string& table_id, uint64_t k, uint64_t nq,
const float* vectors, const meta::DatesT& dates, QueryResults& results);
Status Size(uint64_t &result) override;
Status QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files,
uint64_t k, uint64_t nq, const float* vectors,
const meta::DatesT& dates, QueryResults& results);
~DBImpl() override;
private:
Status
QuerySync(const std::string &table_id,
uint64_t k,
uint64_t nq,
const float *vectors,
const meta::DatesT &dates,
QueryResults &results);
Status
QueryAsync(const std::string &table_id,
const meta::TableFilesSchema &files,
uint64_t k,
uint64_t nq,
const float *vectors,
const meta::DatesT &dates,
QueryResults &results);
void StartTimerTasks();
......@@ -76,15 +109,19 @@ private:
void StartMetricTask();
void StartCompactionTask();
Status MergeFiles(const std::string& table_id,
const meta::DateT& date,
const meta::TableFilesSchema& files);
Status BackgroundMergeFiles(const std::string& table_id);
Status MergeFiles(const std::string &table_id,
const meta::DateT &date,
const meta::TableFilesSchema &files);
Status BackgroundMergeFiles(const std::string &table_id);
void BackgroundCompaction(std::set<std::string> table_ids);
void StartBuildIndexTask();
void BackgroundBuildIndex();
Status BuildIndex(const meta::TableFileSchema&);
Status
BuildIndex(const meta::TableFileSchema &);
private:
const Options options_;
......
......@@ -15,33 +15,46 @@ namespace engine {
class Status {
public:
Status() noexcept : state_(nullptr) {}
~Status() { delete[] state_; }
Status(const Status &rhs);
Status &operator=(const Status &rhs);
Status &
operator=(const Status &rhs);
Status(Status &&rhs) noexcept : state_(rhs.state_) { rhs.state_ = nullptr; }
Status &operator=(Status &&rhs_) noexcept;
static Status OK() { return Status(); }
static Status NotFound(const std::string &msg, const std::string &msg2 = "") {
Status &
operator=(Status &&rhs_) noexcept;
static Status
OK() { return Status(); }
static Status
NotFound(const std::string &msg, const std::string &msg2 = "") {
return Status(kNotFound, msg, msg2);
}
static Status Error(const std::string &msg, const std::string &msg2 = "") {
static Status
Error(const std::string &msg, const std::string &msg2 = "") {
return Status(kError, msg, msg2);
}
static Status InvalidDBPath(const std::string &msg, const std::string &msg2 = "") {
static Status
InvalidDBPath(const std::string &msg, const std::string &msg2 = "") {
return Status(kInvalidDBPath, msg, msg2);
}
static Status GroupError(const std::string &msg, const std::string &msg2 = "") {
static Status
GroupError(const std::string &msg, const std::string &msg2 = "") {
return Status(kGroupError, msg, msg2);
}
static Status DBTransactionError(const std::string &msg, const std::string &msg2 = "") {
static Status
DBTransactionError(const std::string &msg, const std::string &msg2 = "") {
return Status(kDBTransactionError, msg, msg2);
}
static Status AlreadyExist(const std::string &msg, const std::string &msg2 = "") {
static Status
AlreadyExist(const std::string &msg, const std::string &msg2 = "") {
return Status(kAlreadyExist, msg, msg2);
}
......
......@@ -56,7 +56,7 @@ namespace {
<< std::to_string(result.query_result_arrays.size())
<< " search result:" << std::endl;
for(auto& item : result.query_result_arrays) {
std::cout << "\t" << std::to_string(item.id) << "\tscore:" << std::to_string(item.score);
std::cout << "\t" << std::to_string(item.id) << "\tdistance:" << std::to_string(item.distance);
std::cout << std::endl;
}
}
......
......@@ -59,7 +59,7 @@ struct RowRecord {
*/
struct QueryResult {
int64_t id; ///< Output result
double score; ///< Vector similarity score: 0 ~ 100
double distance; ///< Vector similarity distance
};
/**
......
......@@ -203,7 +203,7 @@ ClientProxy::SearchVector(const std::string &table_name,
for(auto& thrift_query_result : thrift_topk_result.query_result_arrays) {
QueryResult query_result;
query_result.id = thrift_query_result.id;
query_result.score = thrift_query_result.score;
query_result.distance = thrift_query_result.distance;
result.query_result_arrays.emplace_back(query_result);
}
......
......@@ -514,7 +514,7 @@ ServerError SearchVectorTask::OnExecute() {
for(auto& pair : result) {
thrift::QueryResult thrift_result;
thrift_result.__set_id(pair.first);
thrift_result.__set_score(pair.second);
thrift_result.__set_distance(pair.second);
thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
}
......
......@@ -564,8 +564,8 @@ void QueryResult::__set_id(const int64_t val) {
this->id = val;
}
void QueryResult::__set_score(const double val) {
this->score = val;
void QueryResult::__set_distance(const double val) {
this->distance = val;
}
std::ostream& operator<<(std::ostream& out, const QueryResult& obj)
{
......@@ -605,8 +605,8 @@ uint32_t QueryResult::read(::apache::thrift::protocol::TProtocol* iprot) {
break;
case 2:
if (ftype == ::apache::thrift::protocol::T_DOUBLE) {
xfer += iprot->readDouble(this->score);
this->__isset.score = true;
xfer += iprot->readDouble(this->distance);
this->__isset.distance = true;
} else {
xfer += iprot->skip(ftype);
}
......@@ -632,8 +632,8 @@ uint32_t QueryResult::write(::apache::thrift::protocol::TProtocol* oprot) const
xfer += oprot->writeI64(this->id);
xfer += oprot->writeFieldEnd();
xfer += oprot->writeFieldBegin("score", ::apache::thrift::protocol::T_DOUBLE, 2);
xfer += oprot->writeDouble(this->score);
xfer += oprot->writeFieldBegin("distance", ::apache::thrift::protocol::T_DOUBLE, 2);
xfer += oprot->writeDouble(this->distance);
xfer += oprot->writeFieldEnd();
xfer += oprot->writeFieldStop();
......@@ -644,18 +644,18 @@ uint32_t QueryResult::write(::apache::thrift::protocol::TProtocol* oprot) const
void swap(QueryResult &a, QueryResult &b) {
using ::std::swap;
swap(a.id, b.id);
swap(a.score, b.score);
swap(a.distance, b.distance);
swap(a.__isset, b.__isset);
}
QueryResult::QueryResult(const QueryResult& other9) {
id = other9.id;
score = other9.score;
distance = other9.distance;
__isset = other9.__isset;
}
QueryResult& QueryResult::operator=(const QueryResult& other10) {
id = other10.id;
score = other10.score;
distance = other10.distance;
__isset = other10.__isset;
return *this;
}
......@@ -663,7 +663,7 @@ void QueryResult::printTo(std::ostream& out) const {
using ::apache::thrift::to_string;
out << "QueryResult(";
out << "id=" << to_string(id);
out << ", " << "score=" << to_string(score);
out << ", " << "distance=" << to_string(distance);
out << ")";
}
......
......@@ -256,9 +256,9 @@ void swap(RowRecord &a, RowRecord &b);
std::ostream& operator<<(std::ostream& out, const RowRecord& obj);
typedef struct _QueryResult__isset {
_QueryResult__isset() : id(false), score(false) {}
_QueryResult__isset() : id(false), distance(false) {}
bool id :1;
bool score :1;
bool distance :1;
} _QueryResult__isset;
class QueryResult : public virtual ::apache::thrift::TBase {
......@@ -266,24 +266,24 @@ class QueryResult : public virtual ::apache::thrift::TBase {
QueryResult(const QueryResult&);
QueryResult& operator=(const QueryResult&);
QueryResult() : id(0), score(0) {
QueryResult() : id(0), distance(0) {
}
virtual ~QueryResult() throw();
int64_t id;
double score;
double distance;
_QueryResult__isset __isset;
void __set_id(const int64_t val);
void __set_score(const double val);
void __set_distance(const double val);
bool operator == (const QueryResult & rhs) const
{
if (!(id == rhs.id))
return false;
if (!(score == rhs.score))
if (!(distance == rhs.distance))
return false;
return true;
}
......
......@@ -73,7 +73,7 @@ struct RowRecord {
*/
struct QueryResult {
1: i64 id; ///< Output result
2: double score; ///< Vector similarity score: 0 ~ 100
2: double distance; ///< Vector similarity distance
}
/**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册