未验证 提交 66add7b9 编写于 作者: G groot 提交者: GitHub

remove old api (#2833)

* remove old api
Signed-off-by: Nyhmo <yihua.mo@zilliz.com>

* refine insert code
Signed-off-by: Nyhmo <yihua.mo@zilliz.com>

* refine code
Signed-off-by: Nyhmo <yihua.mo@zilliz.com>
上级 dc317425
......@@ -1056,47 +1056,19 @@ SSDBImpl::ExecWalRecord(const wal::MXLogRecord& record) {
return status;
}
status =
mem_mgr_->InsertEntities(collection_id, partition_id, record.length, record.ids,
(record.data_size / record.length / sizeof(float)), (const float*)record.data,
record.attr_nbytes, record.attr_data_size, record.attr_data, record.lsn);
force_flush_if_mem_full();
// metrics
milvus::server::CollectInsertMetrics metrics(record.length, status);
break;
}
case wal::MXLogType::InsertBinary: {
int64_t collection_id = 0, partition_id = 0;
auto status = get_collection_partition_id(record, collection_id, partition_id);
if (!status.ok()) {
LOG_WAL_ERROR_ << LogOut("[%s][%ld] ", "insert", 0)
<< "Get collection/partition id fail: " << status.message();
return status;
}
status = mem_mgr_->InsertVectors(collection_id, partition_id, record.length, record.ids,
(record.data_size / record.length / sizeof(uint8_t)),
(const u_int8_t*)record.data, record.lsn);
force_flush_if_mem_full();
// metrics
milvus::server::CollectInsertMetrics metrics(record.length, status);
break;
}
case wal::MXLogType::InsertVector: {
int64_t collection_id = 0, partition_id = 0;
auto status = get_collection_partition_id(record, collection_id, partition_id);
if (!status.ok()) {
LOG_WAL_ERROR_ << LogOut("[%s][%ld] ", "insert", 0)
<< "Get collection/partition id fail: " << status.message();
return status;
}
status = mem_mgr_->InsertVectors(collection_id, partition_id, record.length, record.ids,
(record.data_size / record.length / sizeof(float)),
(const float*)record.data, record.lsn);
// construct chunk data
DataChunkPtr chunk = std::make_shared<DataChunk>();
chunk->count_ = record.length;
chunk->fields_data_ = record.attr_data;
std::vector<uint8_t> uid_data;
uid_data.resize(record.length * sizeof(int64_t));
memcpy(uid_data.data(), record.ids, record.length * sizeof(int64_t));
std::vector<uint8_t> vector_data;
vector_data.resize(record.data_size);
memcpy(vector_data.data(), record.data, record.data_size);
chunk->fields_data_.insert(std::make_pair(VECTOR_FIELD, vector_data));
status = mem_mgr_->InsertEntities(collection_id, partition_id, chunk, record.lsn);
force_flush_if_mem_full();
// metrics
......@@ -1113,12 +1085,12 @@ SSDBImpl::ExecWalRecord(const wal::MXLogRecord& record) {
}
if (record.length == 1) {
status = mem_mgr_->DeleteVector(ss->GetCollectionId(), *record.ids, record.lsn);
status = mem_mgr_->DeleteEntity(ss->GetCollectionId(), *record.ids, record.lsn);
if (!status.ok()) {
return status;
}
} else {
status = mem_mgr_->DeleteVectors(ss->GetCollectionId(), record.length, record.ids, record.lsn);
status = mem_mgr_->DeleteEntities(ss->GetCollectionId(), record.length, record.ids, record.lsn);
if (!status.ok()) {
return status;
}
......
......@@ -33,7 +33,7 @@ SSMemCollection::SSMemCollection(int64_t collection_id, int64_t partition_id, co
}
Status
SSMemCollection::Add(const SSVectorSourcePtr& source) {
SSMemCollection::Add(const milvus::engine::SSVectorSourcePtr& source) {
while (!source->AllAdded()) {
SSMemSegmentPtr current_mem_segment;
if (!mem_segment_list_.empty()) {
......@@ -60,34 +60,6 @@ SSMemCollection::Add(const SSVectorSourcePtr& source) {
return Status::OK();
}
/**
 * Feeds every entity held by `source` into this collection's in-memory segments.
 *
 * Each pass hands the source to exactly one segment: the newest segment when it
 * still has room, otherwise a freshly created one. A new segment is appended to
 * mem_segment_list_ only after its add succeeded.
 *
 * @param source  entity source to drain; the loop runs until source->AllAdded().
 * @return Status::OK() once the source is drained, or a DB_ERROR status carrying
 *         "Insert failed: <reason>" as soon as any segment add fails.
 */
Status
SSMemCollection::AddEntities(const milvus::engine::SSVectorSourcePtr& source) {
    while (!source->AllAdded()) {
        // Only the newest segment may still have room; anything else means a new segment.
        const bool need_new_segment = mem_segment_list_.empty() || mem_segment_list_.back()->IsFull();

        Status add_status;
        if (need_new_segment) {
            SSMemSegmentPtr fresh_segment = std::make_shared<SSMemSegment>(collection_id_, partition_id_, options_);
            add_status = fresh_segment->AddEntities(source);
            if (add_status.ok()) {
                mem_segment_list_.emplace_back(fresh_segment);
            }
        } else {
            add_status = mem_segment_list_.back()->AddEntities(source);
        }

        if (add_status.ok()) {
            continue;
        }

        const std::string err_msg = "Insert failed: " + add_status.ToString();
        LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] ", "insert", 0) << err_msg;
        return Status(DB_ERROR, err_msg);
    }

    return Status::OK();
}
Status
SSMemCollection::Delete(segment::doc_id_t doc_id) {
// Locate which collection file the doc id lands in
......
......@@ -35,9 +35,6 @@ class SSMemCollection : public server::CacheConfigHandler {
Status
Add(const SSVectorSourcePtr& source);
Status
AddEntities(const SSVectorSourcePtr& source);
Status
Delete(segment::doc_id_t doc_id);
......
......@@ -23,27 +23,27 @@
namespace milvus {
namespace engine {
class SSMemManager {
public:
virtual Status
InsertVectors(int64_t collection_id, int64_t partition_id, int64_t length, const IDNumber* vector_ids, int64_t dim,
const float* vectors, uint64_t lsn) = 0;
extern const char* ENTITY_ID_FIELD;
extern const char* VECTOR_DIMENSION_PARAM;
extern const char* VECTOR_FIELD;
virtual Status
InsertVectors(int64_t collection_id, int64_t partition_id, int64_t length, const IDNumber* vector_ids, int64_t dim,
const uint8_t* vectors, uint64_t lsn) = 0;
// A batch of entities passed to the insert path, stored per field.
struct DataChunk {
// Number of entities (rows) in this chunk.
uint64_t count_ = 0;
// Raw bytes of each field, keyed by field name. Chunk validation in the insert
// path expects each buffer to hold exactly count_ * <per-row size of the field> bytes.
std::unordered_map<std::string, std::vector<uint8_t>> fields_data_;
};
// Shared-ownership handle to a DataChunk.
using DataChunkPtr = std::shared_ptr<DataChunk>;
class SSMemManager {
public:
virtual Status
InsertEntities(int64_t collection_id, int64_t partition_id, int64_t length, const IDNumber* vector_ids, int64_t dim,
const float* vectors, const std::unordered_map<std::string, uint64_t>& attr_nbytes,
const std::unordered_map<std::string, uint64_t>& attr_size,
const std::unordered_map<std::string, std::vector<uint8_t>>& attr_data, uint64_t lsn) = 0;
InsertEntities(int64_t collection_id, int64_t partition_id, const DataChunkPtr& chunk, uint64_t lsn) = 0;
virtual Status
DeleteVector(int64_t collection_id, IDNumber vector_id, uint64_t lsn) = 0;
DeleteEntity(int64_t collection_id, IDNumber vector_id, uint64_t lsn) = 0;
virtual Status
DeleteVectors(int64_t collection_id, int64_t length, const IDNumber* vector_ids, uint64_t lsn) = 0;
DeleteEntities(int64_t collection_id, int64_t length, const IDNumber* vector_ids, uint64_t lsn) = 0;
virtual Status
Flush(int64_t collection_id) = 0;
......
......@@ -16,11 +16,16 @@
#include "SSVectorSource.h"
#include "db/Constants.h"
#include "db/snapshot/Snapshots.h"
#include "utils/Log.h"
namespace milvus {
namespace engine {
// Hard-coded names used by the insert path (definitions of the extern declarations).
const char* ENTITY_ID_FIELD = "id";                // hard code; presumably the entity-id field name — TODO confirm usage
const char* VECTOR_DIMENSION_PARAM = "dimension";  // hard code; key looked up in a vector field's params
const char* VECTOR_FIELD = "vector";               // hard code; key under which vector bytes are stored in a chunk
SSMemCollectionPtr
SSMemManagerImpl::GetMemByTable(int64_t collection_id, int64_t partition_id) {
auto mem_collection = mem_map_.find(collection_id);
......@@ -49,65 +54,105 @@ SSMemManagerImpl::GetMemByTable(int64_t collection_id) {
}
/**
 * Copies `length` float vectors of `dim` components, together with their ids,
 * into a VectorsData, wraps it in an SSVectorSource and inserts it while
 * holding mutex_.
 *
 * @param collection_id  target collection.
 * @param partition_id   target partition.
 * @param length         number of vectors (and ids) to copy.
 * @param vector_ids     array of `length` ids.
 * @param dim            components per vector.
 * @param vectors        array of `length * dim` floats.
 * @param lsn            log sequence number forwarded to InsertVectorsNoLock.
 * @return the status produced by InsertVectorsNoLock.
 */
Status
SSMemManagerImpl::InsertVectors(int64_t collection_id, int64_t partition_id, int64_t length, const IDNumber* vector_ids,
                                int64_t dim, const float* vectors, uint64_t lsn) {
    const int64_t float_count = length * dim;

    VectorsData batch;
    batch.vector_count_ = length;

    // Both input buffers are copied, so the source owns its own data.
    batch.id_array_.resize(length);
    memcpy(batch.id_array_.data(), vector_ids, length * sizeof(IDNumber));

    batch.float_data_.resize(float_count);
    memcpy(batch.float_data_.data(), vectors, float_count * sizeof(float));

    SSVectorSourcePtr source = std::make_shared<SSVectorSource>(batch);

    // The copies above are done before locking; only the insert itself is serialized.
    std::unique_lock<std::mutex> lock(mutex_);
    return InsertVectorsNoLock(collection_id, partition_id, source, lsn);
}
Status
SSMemManagerImpl::InsertVectors(int64_t collection_id, int64_t partition_id, int64_t length, const IDNumber* vector_ids,
int64_t dim, const uint8_t* vectors, uint64_t lsn) {
VectorsData vectors_data;
vectors_data.vector_count_ = length;
vectors_data.binary_data_.resize(length * dim);
memcpy(vectors_data.binary_data_.data(), vectors, length * dim * sizeof(uint8_t));
vectors_data.id_array_.resize(length);
memcpy(vectors_data.id_array_.data(), vector_ids, length * sizeof(IDNumber));
SSVectorSourcePtr source = std::make_shared<SSVectorSource>(vectors_data);
SSMemManagerImpl::InsertEntities(int64_t collection_id, int64_t partition_id, const DataChunkPtr& chunk, uint64_t lsn) {
auto status = ValidateChunk(collection_id, partition_id, chunk);
if (!status.ok()) {
return status;
}
SSVectorSourcePtr source = std::make_shared<SSVectorSource>(chunk);
std::unique_lock<std::mutex> lock(mutex_);
return InsertVectorsNoLock(collection_id, partition_id, source, lsn);
return InsertEntitiesNoLock(collection_id, partition_id, source, lsn);
}
Status
SSMemManagerImpl::InsertEntities(int64_t collection_id, int64_t partition_id, int64_t length,
const IDNumber* vector_ids, int64_t dim, const float* vectors,
const std::unordered_map<std::string, uint64_t>& attr_nbytes,
const std::unordered_map<std::string, uint64_t>& attr_size,
const std::unordered_map<std::string, std::vector<uint8_t>>& attr_data, uint64_t lsn) {
VectorsData vectors_data;
vectors_data.vector_count_ = length;
vectors_data.float_data_.resize(length * dim);
memcpy(vectors_data.float_data_.data(), vectors, length * dim * sizeof(float));
vectors_data.id_array_.resize(length);
memcpy(vectors_data.id_array_.data(), vector_ids, length * sizeof(IDNumber));
SSVectorSourcePtr source = std::make_shared<SSVectorSource>(vectors_data, attr_nbytes, attr_size, attr_data);
SSMemManagerImpl::ValidateChunk(int64_t collection_id, int64_t partition_id, const DataChunkPtr& chunk) {
if (chunk == nullptr) {
return Status(DB_ERROR, "Null chunk pointer");
}
std::unique_lock<std::mutex> lock(mutex_);
snapshot::ScopedSnapshotT ss;
auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_id);
if (!status.ok()) {
std::string err_msg = "Could not get snapshot: " + status.ToString();
LOG_ENGINE_ERROR_ << err_msg;
return status;
}
return InsertEntitiesNoLock(collection_id, partition_id, source, lsn);
}
std::vector<std::string> field_names = ss->GetFieldNames();
for (auto& name : field_names) {
auto iter = chunk->fields_data_.find(name);
if (iter == chunk->fields_data_.end()) {
std::string err_msg = "Missed chunk field: " + name;
LOG_ENGINE_ERROR_ << err_msg;
return Status(DB_ERROR, err_msg);
}
Status
SSMemManagerImpl::InsertVectorsNoLock(int64_t collection_id, int64_t partition_id, const SSVectorSourcePtr& source,
uint64_t lsn) {
SSMemCollectionPtr mem = GetMemByTable(collection_id, partition_id);
mem->SetLSN(lsn);
size_t data_size = iter->second.size();
snapshot::FieldPtr field = ss->GetField(name);
meta::hybrid::DataType ftype = static_cast<meta::hybrid::DataType>(field->GetFtype());
std::string err_msg = "Illegal data size for chunk field: ";
switch (ftype) {
case meta::hybrid::DataType::BOOL:
if (data_size != chunk->count_ * sizeof(bool)) {
return Status(DB_ERROR, err_msg + name);
}
break;
case meta::hybrid::DataType::DOUBLE:
if (data_size != chunk->count_ * sizeof(double)) {
return Status(DB_ERROR, err_msg + name);
}
break;
case meta::hybrid::DataType::FLOAT:
if (data_size != chunk->count_ * sizeof(float)) {
return Status(DB_ERROR, err_msg + name);
}
break;
case meta::hybrid::DataType::INT8:
if (data_size != chunk->count_ * sizeof(uint8_t)) {
return Status(DB_ERROR, err_msg + name);
}
break;
case meta::hybrid::DataType::INT16:
if (data_size != chunk->count_ * sizeof(uint16_t)) {
return Status(DB_ERROR, err_msg + name);
}
break;
case meta::hybrid::DataType::INT32:
if (data_size != chunk->count_ * sizeof(uint32_t)) {
return Status(DB_ERROR, err_msg + name);
}
break;
case meta::hybrid::DataType::UID:
case meta::hybrid::DataType::INT64:
if (data_size != chunk->count_ * sizeof(uint64_t)) {
return Status(DB_ERROR, err_msg + name);
}
break;
case meta::hybrid::DataType::VECTOR:
case meta::hybrid::DataType::VECTOR_FLOAT:
case meta::hybrid::DataType::VECTOR_BINARY: {
json params = field->GetParams();
if (params.find(VECTOR_DIMENSION_PARAM) == params.end()) {
std::string msg = "Vector field params must contain: dimension";
LOG_SERVER_ERROR_ << msg;
return Status(DB_ERROR, msg);
}
int64_t dimension = params[VECTOR_DIMENSION_PARAM];
int64_t row_size =
(ftype == meta::hybrid::DataType::VECTOR_BINARY) ? dimension / 8 : dimension * sizeof(float);
if (data_size != chunk->count_ * row_size) {
return Status(DB_ERROR, err_msg + name);
}
break;
}
}
}
auto status = mem->Add(source);
return status;
return Status::OK();
}
Status
......@@ -116,12 +161,12 @@ SSMemManagerImpl::InsertEntitiesNoLock(int64_t collection_id, int64_t partition_
SSMemCollectionPtr mem = GetMemByTable(collection_id, partition_id);
mem->SetLSN(lsn);
auto status = mem->AddEntities(source);
auto status = mem->Add(source);
return status;
}
Status
SSMemManagerImpl::DeleteVector(int64_t collection_id, IDNumber vector_id, uint64_t lsn) {
SSMemManagerImpl::DeleteEntity(int64_t collection_id, IDNumber vector_id, uint64_t lsn) {
std::unique_lock<std::mutex> lock(mutex_);
std::vector<SSMemCollectionPtr> mems = GetMemByTable(collection_id);
......@@ -137,7 +182,7 @@ SSMemManagerImpl::DeleteVector(int64_t collection_id, IDNumber vector_id, uint64
}
Status
SSMemManagerImpl::DeleteVectors(int64_t collection_id, int64_t length, const IDNumber* vector_ids, uint64_t lsn) {
SSMemManagerImpl::DeleteEntities(int64_t collection_id, int64_t length, const IDNumber* vector_ids, uint64_t lsn) {
std::unique_lock<std::mutex> lock(mutex_);
std::vector<SSMemCollectionPtr> mems = GetMemByTable(collection_id);
......
......@@ -42,24 +42,13 @@ class SSMemManagerImpl : public SSMemManager, public server::CacheConfigHandler
}
Status
InsertVectors(int64_t collection_id, int64_t partition_id, int64_t length, const IDNumber* vector_ids, int64_t dim,
const float* vectors, uint64_t lsn) override;
InsertEntities(int64_t collection_id, int64_t partition_id, const DataChunkPtr& chunk, uint64_t lsn) override;
Status
InsertVectors(int64_t collection_id, int64_t partition_id, int64_t length, const IDNumber* vector_ids, int64_t dim,
const uint8_t* vectors, uint64_t lsn) override;
DeleteEntity(int64_t collection_id, IDNumber vector_id, uint64_t lsn) override;
Status
InsertEntities(int64_t collection_id, int64_t partition_id, int64_t length, const IDNumber* vector_ids, int64_t dim,
const float* vectors, const std::unordered_map<std::string, uint64_t>& attr_nbytes,
const std::unordered_map<std::string, uint64_t>& attr_size,
const std::unordered_map<std::string, std::vector<uint8_t>>& attr_data, uint64_t lsn) override;
Status
DeleteVector(int64_t collection_id, IDNumber vector_id, uint64_t lsn) override;
Status
DeleteVectors(int64_t collection_id, int64_t length, const IDNumber* vector_ids, uint64_t lsn) override;
DeleteEntities(int64_t collection_id, int64_t length, const IDNumber* vector_ids, uint64_t lsn) override;
Status
Flush(int64_t collection_id) override;
......@@ -94,7 +83,7 @@ class SSMemManagerImpl : public SSMemManager, public server::CacheConfigHandler
GetMemByTable(int64_t collection_id);
Status
InsertVectorsNoLock(int64_t collection_id, int64_t partition_id, const SSVectorSourcePtr& source, uint64_t lsn);
ValidateChunk(int64_t collection_id, int64_t partition_id, const DataChunkPtr& chunk);
Status
InsertEntitiesNoLock(int64_t collection_id, int64_t partition_id, const SSVectorSourcePtr& source, uint64_t lsn);
......
......@@ -20,6 +20,7 @@
#include "db/Constants.h"
#include "db/Utils.h"
#include "db/engine/EngineFactory.h"
#include "db/meta/MetaTypes.h"
#include "db/snapshot/Operations.h"
#include "db/snapshot/Snapshots.h"
#include "metrics/Metrics.h"
......@@ -66,73 +67,83 @@ SSMemSegment::CreateSegment() {
return status;
}
int64_t
SSMemSegment::GetDimension() {
Status
SSMemSegment::GetSingleEntitySize(size_t& single_size) {
snapshot::ScopedSnapshotT ss;
auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_id_);
if (!status.ok()) {
std::string err_msg = "SSMemSegment::GetDimension failed: " + status.ToString();
std::string err_msg = "SSMemSegment::SingleEntitySize failed: " + status.ToString();
LOG_ENGINE_ERROR_ << err_msg;
return 0;
return status;
}
const std::string hard_code_vector_field = "vector";
const std::string hard_code_dimension = "dimension";
snapshot::FieldPtr field = ss->GetField(hard_code_vector_field);
json params = field->GetParams();
if (params.find(hard_code_dimension) == params.end()) {
std::string msg = "Vector field params must contain: dimension";
LOG_SERVER_ERROR_ << msg;
return 0;
single_size = 0;
std::vector<std::string> field_names = ss->GetFieldNames();
for (auto& name : field_names) {
snapshot::FieldPtr field = ss->GetField(name);
meta::hybrid::DataType ftype = static_cast<meta::hybrid::DataType>(field->GetFtype());
switch (ftype) {
case meta::hybrid::DataType::BOOL:
single_size += sizeof(bool);
break;
case meta::hybrid::DataType::DOUBLE:
single_size += sizeof(double);
break;
case meta::hybrid::DataType::FLOAT:
single_size += sizeof(float);
break;
case meta::hybrid::DataType::INT8:
single_size += sizeof(uint8_t);
break;
case meta::hybrid::DataType::INT16:
single_size += sizeof(uint16_t);
break;
case meta::hybrid::DataType::INT32:
single_size += sizeof(uint32_t);
break;
case meta::hybrid::DataType::UID:
case meta::hybrid::DataType::INT64:
single_size += sizeof(uint64_t);
break;
case meta::hybrid::DataType::VECTOR:
case meta::hybrid::DataType::VECTOR_FLOAT:
case meta::hybrid::DataType::VECTOR_BINARY: {
json params = field->GetParams();
if (params.find(VECTOR_DIMENSION_PARAM) == params.end()) {
std::string msg = "Vector field params must contain: dimension";
LOG_SERVER_ERROR_ << msg;
return Status(DB_ERROR, msg);
}
int64_t dimension = params[VECTOR_DIMENSION_PARAM];
if (ftype == meta::hybrid::DataType::VECTOR_BINARY) {
single_size += (dimension / 8);
} else {
single_size += (dimension * sizeof(float));
}
break;
}
}
}
int64_t dimension = params[hard_code_dimension];
return dimension;
return Status::OK();
}
Status
SSMemSegment::Add(const SSVectorSourcePtr& source) {
int64_t dimension = GetDimension();
if (dimension <= 0) {
std::string err_msg = "SSMemSegment::Add: table_file_schema dimension = " + std::to_string(dimension) +
", collection_id = " + std::to_string(collection_id_);
LOG_ENGINE_ERROR_ << LogOut("[%s][%ld]", "insert", 0) << err_msg;
return Status(DB_ERROR, "Not able to create collection file");
}
size_t single_vector_mem_size = source->SingleVectorSize(dimension);
size_t mem_left = GetMemLeft();
if (mem_left >= single_vector_mem_size) {
size_t num_vectors_to_add = std::ceil(mem_left / single_vector_mem_size);
size_t num_vectors_added;
auto status =
source->Add(/*execution_engine_,*/ segment_writer_ptr_, dimension, num_vectors_to_add, num_vectors_added);
if (status.ok()) {
current_mem_ += (num_vectors_added * single_vector_mem_size);
}
size_t single_entity_mem_size = 0;
auto status = GetSingleEntitySize(single_entity_mem_size);
if (!status.ok()) {
return status;
}
return Status::OK();
}
Status
SSMemSegment::AddEntities(const SSVectorSourcePtr& source) {
int64_t dimension = GetDimension();
if (dimension <= 0) {
std::string err_msg = "SSMemSegment::Add: table_file_schema dimension = " + std::to_string(dimension) +
", collection_id = " + std::to_string(collection_id_);
LOG_ENGINE_ERROR_ << LogOut("[%s][%ld]", "insert", 0) << err_msg;
return Status(DB_ERROR, "Not able to create collection file");
}
size_t single_entity_mem_size = source->SingleEntitySize(dimension);
size_t mem_left = GetMemLeft();
if (mem_left >= single_entity_mem_size) {
size_t num_entities_to_add = std::ceil(mem_left / single_entity_mem_size);
size_t num_entities_added;
auto status = source->AddEntities(segment_writer_ptr_, dimension, num_entities_to_add, num_entities_added);
auto status = source->Add(segment_writer_ptr_, num_entities_to_add, num_entities_added);
if (status.ok()) {
current_mem_ += (num_entities_added * single_entity_mem_size);
......@@ -206,8 +217,13 @@ SSMemSegment::GetMemLeft() {
bool
SSMemSegment::IsFull() {
size_t single_vector_mem_size = GetDimension() * FLOAT_TYPE_SIZE;
return (GetMemLeft() < single_vector_mem_size);
size_t single_entity_mem_size = 0;
auto status = GetSingleEntitySize(single_entity_mem_size);
if (!status.ok()) {
return true;
}
return (GetMemLeft() < single_entity_mem_size);
}
Status
......
......@@ -36,9 +36,6 @@ class SSMemSegment : public server::CacheConfigHandler {
Status
Add(const SSVectorSourcePtr& source);
Status
AddEntities(const SSVectorSourcePtr& source);
Status
Delete(segment::doc_id_t doc_id);
......@@ -68,8 +65,8 @@ class SSMemSegment : public server::CacheConfigHandler {
Status
CreateSegment();
int64_t
GetDimension();
Status
GetSingleEntitySize(size_t& single_size);
private:
int64_t collection_id_;
......
......@@ -23,157 +23,64 @@
namespace milvus {
namespace engine {
// Builds a source over vector data only; takes ownership of `vectors` and starts with nothing added.
SSVectorSource::SSVectorSource(VectorsData vectors) : vectors_(std::move(vectors)), current_num_vectors_added(0) {
}
// Builds a source over vector data plus attribute maps; the maps are copied and both progress counters start at zero.
SSVectorSource::SSVectorSource(milvus::engine::VectorsData vectors,
                               const std::unordered_map<std::string, uint64_t>& attr_nbytes,
                               const std::unordered_map<std::string, uint64_t>& attr_size,
                               const std::unordered_map<std::string, std::vector<uint8_t>>& attr_data)
    : vectors_(std::move(vectors)),
      attr_nbytes_(attr_nbytes),
      attr_size_(attr_size),
      attr_data_(attr_data),
      current_num_vectors_added(0),
      current_num_attrs_added(0) {
}
Status
SSVectorSource::Add(const segment::SegmentWriterPtr& segment_writer_ptr, int64_t dimension,
const size_t& num_vectors_to_add, size_t& num_vectors_added) {
uint64_t n = vectors_.vector_count_;
server::CollectAddMetrics metrics(n, dimension);
num_vectors_added =
current_num_vectors_added + num_vectors_to_add <= n ? num_vectors_to_add : n - current_num_vectors_added;
IDNumbers vector_ids_to_add;
if (vectors_.id_array_.empty()) {
SafeIDGenerator& id_generator = SafeIDGenerator::GetInstance();
Status status = id_generator.GetNextIDNumbers(num_vectors_added, vector_ids_to_add);
if (!status.ok()) {
LOG_ENGINE_ERROR_ << LogOut("[%s][%ld]", "insert", 0) << "Generate ids fail: " << status.message();
return status;
}
} else {
vector_ids_to_add.resize(num_vectors_added);
for (size_t pos = current_num_vectors_added; pos < current_num_vectors_added + num_vectors_added; pos++) {
vector_ids_to_add[pos - current_num_vectors_added] = vectors_.id_array_[pos];
}
}
Status status;
if (!vectors_.float_data_.empty()) {
LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld]", "insert", 0) << "Insert float data into segment";
auto size = num_vectors_added * dimension * sizeof(float);
float* ptr = vectors_.float_data_.data() + current_num_vectors_added * dimension;
status = segment_writer_ptr->AddVectors("", (uint8_t*)ptr, size, vector_ids_to_add);
} else if (!vectors_.binary_data_.empty()) {
LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld]", "insert", 0) << "Insert binary data into segment";
std::vector<uint8_t> vectors;
auto size = num_vectors_added * SingleVectorSize(dimension) * sizeof(uint8_t);
uint8_t* ptr = vectors_.binary_data_.data() + current_num_vectors_added * SingleVectorSize(dimension);
status = segment_writer_ptr->AddVectors("", ptr, size, vector_ids_to_add);
}
// Clear vector data
if (status.ok()) {
current_num_vectors_added += num_vectors_added;
// TODO(zhiru): remove
vector_ids_.insert(vector_ids_.end(), std::make_move_iterator(vector_ids_to_add.begin()),
std::make_move_iterator(vector_ids_to_add.end()));
} else {
LOG_ENGINE_ERROR_ << LogOut("[%s][%ld]", "insert", 0) << "SSVectorSource::Add failed: " + status.ToString();
}
return status;
// Builds a source over a data chunk. Only the DataChunkPtr handle is copied into chunk_;
// the chunk itself is not checked for null here.
SSVectorSource::SSVectorSource(const DataChunkPtr& chunk) : chunk_(chunk) {
}
Status
SSVectorSource::AddEntities(const milvus::segment::SegmentWriterPtr& segment_writer_ptr, int64_t dimension,
const size_t& num_entities_to_add, size_t& num_entities_added) {
SSVectorSource::Add(const milvus::segment::SegmentWriterPtr& segment_writer_ptr, const size_t& num_entities_to_add,
size_t& num_entities_added) {
// TODO: n = vectors_.vector_count_;???
uint64_t n = vectors_.vector_count_;
num_entities_added =
current_num_attrs_added + num_entities_to_add <= n ? num_entities_to_add : n - current_num_attrs_added;
IDNumbers vector_ids_to_add;
if (vectors_.id_array_.empty()) {
SafeIDGenerator& id_generator = SafeIDGenerator::GetInstance();
Status status = id_generator.GetNextIDNumbers(num_entities_added, vector_ids_to_add);
if (!status.ok()) {
return status;
}
} else {
vector_ids_to_add.resize(num_entities_added);
for (size_t pos = current_num_attrs_added; pos < current_num_attrs_added + num_entities_added; pos++) {
vector_ids_to_add[pos - current_num_attrs_added] = vectors_.id_array_[pos];
}
}
Status status;
status = segment_writer_ptr->AddAttrs("", attr_size_, attr_data_, vector_ids_to_add);
if (status.ok()) {
current_num_attrs_added += num_entities_added;
} else {
LOG_ENGINE_ERROR_ << LogOut("[%s][%ld]", "insert", 0) << "Generate ids fail: " << status.message();
return status;
}
std::vector<uint8_t> vectors;
auto size = num_entities_added * dimension * sizeof(float);
vectors.resize(size);
memcpy(vectors.data(), vectors_.float_data_.data() + current_num_vectors_added * dimension, size);
LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld]", "insert", 0) << "Insert into segment";
status = segment_writer_ptr->AddVectors("", vectors, vector_ids_to_add);
if (status.ok()) {
current_num_vectors_added += num_entities_added;
vector_ids_.insert(vector_ids_.end(), std::make_move_iterator(vector_ids_to_add.begin()),
std::make_move_iterator(vector_ids_to_add.end()));
}
// don't need to add current_num_attrs_added again
if (!status.ok()) {
LOG_ENGINE_ERROR_ << LogOut("[%s][%ld]", "insert", 0) << "SSVectorSource::Add failed: " + status.ToString();
return status;
}
return status;
}
// Returns the current value of the current_num_vectors_added counter.
size_t
SSVectorSource::GetNumVectorsAdded() {
return current_num_vectors_added;
}
/**
 * Size in bytes of one vector of the given dimension, based on which buffer of
 * vectors_ is populated.
 *
 * @param dimension  number of components per vector.
 * @return dimension * FLOAT_TYPE_SIZE when float data is present; otherwise
 *         dimension / 8 (integer division) when binary data is present;
 *         0 when neither buffer holds data.
 */
size_t
SSVectorSource::SingleVectorSize(uint16_t dimension) {
    // Float data takes precedence over binary data.
    const bool has_float_data = !vectors_.float_data_.empty();
    if (has_float_data) {
        return dimension * FLOAT_TYPE_SIZE;
    }

    const bool has_binary_data = !vectors_.binary_data_.empty();
    if (has_binary_data) {
        return dimension / 8;
    }

    return 0;
}
size_t
SSVectorSource::SingleEntitySize(uint16_t dimension) {
// TODO(yukun) add entity type and size compute
size_t size = 0;
size += dimension * FLOAT_TYPE_SIZE;
auto nbyte_it = attr_nbytes_.begin();
for (; nbyte_it != attr_nbytes_.end(); ++nbyte_it) {
size += nbyte_it->second;
}
return size;
uint64_t n = chunk_->count_;
num_entities_added = current_num_added_ + num_entities_to_add <= n ? num_entities_to_add : n - current_num_added_;
// IDNumbers vector_ids_to_add;
// if (vectors_.id_array_.empty()) {
// SafeIDGenerator& id_generator = SafeIDGenerator::GetInstance();
// Status status = id_generator.GetNextIDNumbers(num_entities_added, vector_ids_to_add);
// if (!status.ok()) {
// return status;
// }
// } else {
// vector_ids_to_add.resize(num_entities_added);
// for (size_t pos = current_num_attrs_added; pos < current_num_attrs_added + num_entities_added; pos++) {
// vector_ids_to_add[pos - current_num_attrs_added] = vectors_.id_array_[pos];
// }
// }
//
// Status status;
// status = segment_writer_ptr->AddAttrs("", attr_size_, attr_data_, vector_ids_to_add);
//
// if (status.ok()) {
// current_num_attrs_added += num_entities_added;
// } else {
// LOG_ENGINE_ERROR_ << LogOut("[%s][%ld]", "insert", 0) << "Generate ids fail: " << status.message();
// return status;
// }
//
// std::vector<uint8_t> vectors;
// auto size = num_entities_added * dimension * sizeof(float);
// vectors.resize(size);
// memcpy(vectors.data(), vectors_.float_data_.data() + current_num_vectors_added * dimension, size);
// LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld]", "insert", 0) << "Insert into segment";
// status = segment_writer_ptr->AddVectors("", vectors, vector_ids_to_add);
// if (status.ok()) {
// current_num_vectors_added += num_entities_added;
// vector_ids_.insert(vector_ids_.end(), std::make_move_iterator(vector_ids_to_add.begin()),
// std::make_move_iterator(vector_ids_to_add.end()));
// }
//
// // don't need to add current_num_attrs_added again
// if (!status.ok()) {
// LOG_ENGINE_ERROR_ << LogOut("[%s][%ld]", "insert", 0) << "SSVectorSource::Add failed: " +
// status.ToString(); return status;
// }
//
// return status;
return Status::OK();
}
bool
SSVectorSource::AllAdded() {
return (current_num_vectors_added == vectors_.vector_count_);
}
IDNumbers
SSVectorSource::GetVectorIds() {
return vector_ids_;
return (current_num_added_ >= chunk_->count_);
}
} // namespace engine
......
......@@ -18,6 +18,7 @@
#include "db/IDGenerator.h"
#include "db/engine/ExecutionEngine.h"
#include "db/insert/SSMemManager.h"
#include "segment/SegmentWriter.h"
#include "utils/Status.h"
......@@ -28,44 +29,18 @@ namespace engine {
class SSVectorSource {
public:
explicit SSVectorSource(VectorsData vectors);
SSVectorSource(VectorsData vectors, const std::unordered_map<std::string, uint64_t>& attr_nbytes,
const std::unordered_map<std::string, uint64_t>& attr_size,
const std::unordered_map<std::string, std::vector<uint8_t>>& attr_data);
Status
Add(const segment::SegmentWriterPtr& segment_writer_ptr, int64_t dimension, const size_t& num_vectors_to_add,
size_t& num_vectors_added);
explicit SSVectorSource(const DataChunkPtr& chunk);
Status
AddEntities(const segment::SegmentWriterPtr& segment_writer_ptr, int64_t dimension, const size_t& num_attrs_to_add,
size_t& num_attrs_added);
size_t
GetNumVectorsAdded();
size_t
SingleVectorSize(uint16_t dimension);
size_t
SingleEntitySize(uint16_t dimension);
Add(const segment::SegmentWriterPtr& segment_writer_ptr, const size_t& num_attrs_to_add, size_t& num_attrs_added);
bool
AllAdded();
IDNumbers
GetVectorIds();
private:
VectorsData vectors_;
IDNumbers vector_ids_;
const std::unordered_map<std::string, uint64_t> attr_nbytes_;
std::unordered_map<std::string, uint64_t> attr_size_;
std::unordered_map<std::string, std::vector<uint8_t>> attr_data_;
DataChunkPtr chunk_;
size_t current_num_vectors_added;
size_t current_num_attrs_added;
size_t current_num_added_ = 0;
}; // SSVectorSource
using SSVectorSourcePtr = std::shared_ptr<SSVectorSource>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册