Commit 457688d7 authored by xj.lin

add search support


Former-commit-id: 4aae0710844eff71fb83f3d73f3fa7463b00e99b
Parent d7c7720e
......@@ -7,6 +7,8 @@
#include <faiss/index_io.h>
#include <faiss/AutoTune.h>
#include <wrapper/IndexBuilder.h>
#include <cstring>
#include <wrapper/Topk.h>
#include "DBImpl.h"
#include "DBMetaImpl.h"
#include "Env.h"
......@@ -53,9 +55,82 @@ Status DBImpl::add_vectors(const std::string& group_id_,
}
}
Status DBImpl::search(const std::string& group_id, size_t k, size_t nq,
const float* vectors, QueryResults& results) {
// PXU TODO
Status DBImpl::search(const std::string &group_id, size_t k, size_t nq,
                      const float *vectors, QueryResults &results) {
    meta::DatePartionedGroupFilesSchema files;
    std::vector<meta::DateT> partition;
    auto status = _pMeta->files_to_search(group_id, partition, files);
    if (!status.ok()) { return status; }

    // TODO: optimize
    // Split the candidate files into raw files and already-built index files.
    meta::GroupFilesSchema index_files;
    meta::GroupFilesSchema raw_files;
    for (auto &day_files : files) {
        for (auto &file : day_files.second) {
            file.file_type == meta::GroupFileSchema::RAW ?
                raw_files.push_back(file) :
                index_files.push_back(file);
        }
    }

    int dim = raw_files[0].dimension;

    // Merge all raw files into one temporary "IDMap,Flat" index.
    faiss::Index *index(faiss::index_factory(dim, "IDMap,Flat"));
    for (auto &file : raw_files) {
        auto file_index = dynamic_cast<faiss::IndexIDMap *>(faiss::read_index(file.location.c_str()));
        index->add_with_ids(file_index->ntotal,
                            dynamic_cast<faiss::IndexFlat *>(file_index->index)->xb.data(),
                            file_index->id_map.data());
    }

    float *xb = dynamic_cast<faiss::IndexFlat *>(dynamic_cast<faiss::IndexIDMap *>(index)->index)->xb.data();
    int64_t *ids = dynamic_cast<faiss::IndexIDMap *>(index)->id_map.data();
    long total = index->ntotal;

    std::vector<float> distance;
    std::vector<long> result_ids;
    {
        // Allocate result buffers: search() writes nq * k distances and labels per call.
        float *output_distance = (float *) malloc(nq * k * sizeof(float));
        long *output_ids = (long *) malloc(nq * k * sizeof(long));

        // Build and search in the merged raw data.
        // TODO: index type is hard-coded for now
        auto opd = std::make_shared<Operand>();
        opd->index_type = "IDMap,Flat";
        IndexBuilderPtr builder = GetIndexBuilder(opd);
        auto index = builder->build_all(total, xb, ids);
        index->search(nq, vectors, k, output_distance, output_ids);
        distance.insert(distance.begin(), output_distance, output_distance + nq * k);
        result_ids.insert(result_ids.begin(), output_ids, output_ids + nq * k);
        memset(output_distance, 0, nq * k * sizeof(float));
        memset(output_ids, 0, nq * k * sizeof(long));

        // Search in each index file.
        for (auto &file : index_files) {
            auto index = read_index(file.location.c_str());
            index->search(nq, vectors, k, output_distance, output_ids);
            distance.insert(distance.begin(), output_distance, output_distance + nq * k);
            result_ids.insert(result_ids.begin(), output_ids, output_ids + nq * k);
            memset(output_distance, 0, nq * k * sizeof(float));
            memset(output_ids, 0, nq * k * sizeof(long));
        }

        // Merge the collected candidates down to the global top-k.
        TopK(distance.data(), distance.size(), k, output_distance, output_ids);
        distance.clear();
        result_ids.clear();
        distance.insert(distance.begin(), output_distance, output_distance + k);
        result_ids.insert(result_ids.begin(), output_ids, output_ids + k);

        // Free the result buffers.
        free(output_distance);
        free(output_ids);
    }

    // TODO: copy the merged ids/distances into `results`
    return Status::OK();
}
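
For reference, the brute-force path above is the standard FAISS flow: create an "IDMap,Flat" index with index_factory, feed it vectors plus explicit ids via add_with_ids, and let search() fill nq * k distance/label slots. Below is a minimal standalone sketch of that flow, assuming the older FAISS header layout used in this repo (index_factory declared in faiss/AutoTune.h); the sizes, ids, and random data are illustrative only.

#include <faiss/AutoTune.h>   // index_factory lives here in the older FAISS used by this repo
#include <faiss/Index.h>
#include <cstdio>
#include <random>
#include <vector>

int main() {
    const int d = 128;     // vector dimension (illustrative)
    const int nb = 1000;   // number of stored vectors (illustrative)
    const int nq = 3;      // number of queries
    const int k = 5;       // results per query

    // Random base/query vectors and explicit ids for the IDMap layer.
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(0.f, 1.f);
    std::vector<float> xb(nb * d), xq(nq * d);
    for (auto &v : xb) v = dist(rng);
    for (auto &v : xq) v = dist(rng);
    std::vector<faiss::Index::idx_t> ids(nb);
    for (int i = 0; i < nb; ++i) ids[i] = 100000 + i;   // arbitrary external ids

    // Same factory string the engine uses for raw files.
    faiss::Index *index = faiss::index_factory(d, "IDMap,Flat");
    index->add_with_ids(nb, xb.data(), ids.data());

    // search() fills nq * k distances and labels, one k-list per query.
    std::vector<float> distances(nq * k);
    std::vector<faiss::Index::idx_t> labels(nq * k);
    index->search(nq, xq.data(), k, distances.data(), labels.data());

    for (int q = 0; q < nq; ++q) {
        std::printf("query %d:", q);
        for (int j = 0; j < k; ++j) {
            std::printf(" (%lld, %.3f)",
                        (long long) labels[q * k + j], distances[q * k + j]);
        }
        std::printf("\n");
    }

    delete index;
    return 0;
}
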
......
......@@ -226,6 +226,48 @@ Status DBMetaImpl::files_to_index(GroupFilesSchema& files) {
return Status::OK();
}
Status DBMetaImpl::files_to_search(const std::string &group_id,
                                   std::vector<DateT> partition,
                                   DatePartionedGroupFilesSchema &files) {
    // TODO: support data partition
    files.clear();

    auto selected = ConnectorPtr->select(columns(&GroupFileSchema::id,
                                                 &GroupFileSchema::group_id,
                                                 &GroupFileSchema::file_id,
                                                 &GroupFileSchema::file_type,
                                                 &GroupFileSchema::rows,
                                                 &GroupFileSchema::date),
                                         where(c(&GroupFileSchema::group_id) == group_id and
                                               (c(&GroupFileSchema::file_type) == (int) GroupFileSchema::RAW or
                                                c(&GroupFileSchema::file_type) == (int) GroupFileSchema::INDEX)));

    GroupSchema group_info;
    group_info.group_id = group_id;
    auto status = get_group_no_lock(group_info);
    if (!status.ok()) {
        return status;
    }

    for (auto& file : selected) {
        GroupFileSchema group_file;
        group_file.id = std::get<0>(file);
        group_file.group_id = std::get<1>(file);
        group_file.file_id = std::get<2>(file);
        group_file.file_type = std::get<3>(file);
        group_file.rows = std::get<4>(file);
        group_file.date = std::get<5>(file);
        group_file.dimension = group_info.dimension;
        GetGroupFilePath(group_file);

        auto dateItr = files.find(group_file.date);
        if (dateItr == files.end()) {
            files[group_file.date] = GroupFilesSchema();
        }
        files[group_file.date].push_back(group_file);
    }

    return Status::OK();
}
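
The result container here is keyed by day: DatePartionedGroupFilesSchema acts as a map from DateT to a GroupFilesSchema (a vector of GroupFileSchema), so every selected RAW or INDEX file is appended to the bucket for its date. A minimal sketch of how a caller walks that result, mirroring the loop in DBImpl::search above; metaPtr and the group name are placeholders.

// Sketch only: metaPtr stands for the DBMetaImpl instance (or any other meta implementation) the caller holds.
meta::DatePartionedGroupFilesSchema files;
std::vector<meta::DateT> partition;   // data partitioning is still a TODO, so this stays empty
auto status = metaPtr->files_to_search("example_group", partition, files);
if (status.ok()) {
    for (auto &day_files : files) {               // day_files.first: the date bucket
        for (auto &file : day_files.second) {     // each RAW or INDEX file for that date
            // file.location was filled in by GetGroupFilePath(),
            // file.dimension was copied from the group's schema.
        }
    }
}
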
Status DBMetaImpl::files_to_merge(const std::string& group_id,
DatePartionedGroupFilesSchema& files) {
files.clear();
......
......@@ -38,6 +38,10 @@ public:
    virtual Status files_to_merge(const std::string& group_id,
                                  DatePartionedGroupFilesSchema& files) override;

    virtual Status files_to_search(const std::string& group_id,
                                   std::vector<DateT> partition,
                                   DatePartionedGroupFilesSchema& files) override;

    virtual Status files_to_index(GroupFilesSchema&) override;

    virtual Status cleanup() override;
......
......@@ -70,6 +70,10 @@ public:
    virtual Status update_files(const GroupFilesSchema& files) = 0;

    virtual Status files_to_search(const std::string& group_id,
                                   std::vector<DateT> partition,
                                   DatePartionedGroupFilesSchema& files) = 0;

    virtual Status files_to_merge(const std::string& group_id,
                                  DatePartionedGroupFilesSchema& files) = 0;
......
......@@ -11,6 +11,7 @@
#endif
#include "Index.h"
#include "faiss/index_io.h"
namespace zilliz {
namespace vecwise {
......@@ -66,6 +67,12 @@ void write_index(const Index_ptr &index, const std::string &file_name) {
    write_index(index->index_.get(), file_name.c_str());
}

Index_ptr read_index(const std::string &file_name) {
    std::shared_ptr<faiss::Index> raw_index = nullptr;
    raw_index.reset(faiss::read_index(file_name.c_str()));
    return std::make_shared<Index>(raw_index);
}
}
}
}
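
For context, the new read_index pairs with the existing write_index wrapper above; DBImpl::search relies on it to reload each persisted index file before querying it. A hedged round-trip sketch follows, where the wrapper/Index.h path and the "engine" namespace are assumptions based on this diff, and search() is called with the same argument shapes as in DBImpl::search.

#include <wrapper/Index.h>   // assumed header for the write_index / read_index wrappers
#include <vector>

// Persist a previously built index, load it back, and query it.
void round_trip(const zilliz::vecwise::engine::Index_ptr &built,   // "engine" namespace is an assumption
                const float *queries, size_t nq, size_t k) {
    using namespace zilliz::vecwise::engine;

    write_index(built, "/tmp/example_group_file.index");           // path is illustrative only
    Index_ptr loaded = read_index("/tmp/example_group_file.index");

    std::vector<float> distances(nq * k);   // one k-list of distances per query
    std::vector<long> labels(nq * k);       // matching external vector ids
    loaded->search(nq, queries, k, distances.data(), labels.data());
}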