提交 7d812225 编写于 作者: F FluorineDog 提交者: yefu.chen

Support span in SegmentGrowing, refine vector_trait

Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
上级 04e20627
......@@ -86,14 +86,8 @@ class Span<T, typename std::enable_if_t<std::is_fundamental_v<T>>> {
const int64_t row_count_;
};
// Forward declarations of the vector trait classes inside namespace segcore.
// Declarations only; the class definitions are not part of this snippet.
namespace segcore {
class VectorTrait;
class FloatVector;
class BinaryVector;
} // namespace segcore
template <typename VectorType>
class Span<VectorType, typename std::enable_if_t<std::is_base_of_v<segcore::VectorTrait, VectorType>>> {
class Span<VectorType, typename std::enable_if_t<std::is_base_of_v<VectorTrait, VectorType>>> {
public:
using embedded_type = typename VectorType::embedded_type;
......
......@@ -76,3 +76,5 @@ using FieldName = fluent::NamedType<std::string, struct FieldNameTag, fluent::Co
using FieldOffset = fluent::NamedType<int64_t, struct FieldOffsetTag, fluent::Comparable, fluent::Hashable>;
} // namespace milvus
#include "VectorTrait.h"
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "Types.h"
namespace milvus {
// Empty tag base class. Vector field trait types derive from it so that
// templates can detect "is a vector type" with std::is_base_of.
class VectorTrait {};
// Trait tag describing a float vector field.
class FloatVector : public VectorTrait {
 public:
    // NOTE(review): named metric_type, but it is initialised with a DataType value.
    static constexpr auto metric_type = DataType::VECTOR_FLOAT;

    // Element type stored for each dimension of the vector.
    using embedded_type = float;
};
// Trait tag describing a binary vector field.
class BinaryVector : public VectorTrait {
 public:
    // NOTE(review): named metric_type, but it is initialised with a DataType value.
    static constexpr auto metric_type = DataType::VECTOR_BINARY;

    // Storage unit of the raw data: one byte.
    using embedded_type = uint8_t;
};
// Returns the size in bytes of a single vector with `dim` dimensions.
//
// Template parameter:
//   VectorType - a trait type derived from VectorTrait.
// Parameter:
//   dim - vector dimension.
// Returns:
//   FloatVector      -> dim * sizeof(float)
//   any other trait  -> dim / 8 (integer division; assumes dim is a multiple
//                       of 8 -- TODO confirm that callers guarantee this)
template <typename VectorType>
inline constexpr int64_t
get_element_sizeof(int64_t dim) {
    // std::is_base_of takes <Base, Derived>. The arguments used to be swapped,
    // which made this assertion fail for every real trait type.
    static_assert(std::is_base_of_v<VectorTrait, VectorType>, "VectorType must derive from VectorTrait");
    if constexpr (std::is_same_v<VectorType, FloatVector>) {
        // Keep the arithmetic signed instead of silently promoting to size_t.
        return dim * static_cast<int64_t>(sizeof(float));
    } else {
        return dim / 8;
    }
}
template <typename T>
constexpr bool IsVector = std::is_base_of_v<VectorTrait, T>;
// True when T is a fundamental type (arithmetic types, void, std::nullptr_t).
template <typename T>
constexpr bool IsScalar = std::is_fundamental<T>::value;
template <typename T, typename Enabled = void>
struct EmbeddedTypeImpl;
template <typename T>
struct EmbeddedTypeImpl<T, std::enable_if_t<IsScalar<T>>> {
using type = T;
};
template <typename T>
struct EmbeddedTypeImpl<T, std::enable_if_t<IsVector<T>>> {
using type = std::conditional_t<std::is_same_v<T, FloatVector>, float, uint8_t>;
};
template <typename T>
using EmbeddedType = typename EmbeddedTypeImpl<T>::type;
} // namespace milvus
......@@ -92,7 +92,6 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment,
final_qr.merge(sub_qr);
}
using segcore::FloatVector;
auto vec_ptr = record.get_entity<FloatVector>(vecfield_offset);
// step 4: brute force search where small indexing is unavailable
......@@ -165,7 +164,6 @@ BinarySearch(const segcore::SegmentGrowingImpl& segment,
// TODO: use QuerySubResult instead
query::dataset::BinaryQueryDataset query_dataset{metric_type, num_queries, topK, dim, query_data};
using segcore::BinaryVector;
auto vec_ptr = record.get_entity<BinaryVector>(vecfield_offset);
auto max_indexed_id = 0;
......
......@@ -23,6 +23,7 @@
#include "utils/tools.h"
#include <boost/container/vector.hpp>
#include "common/Types.h"
#include "common/Span.h"
namespace milvus::segcore {
......@@ -82,6 +83,9 @@ class VectorBase {
virtual void
set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) = 0;
// Returns the chunk identified by chunk_id as a SpanBase.
// Pure virtual: concrete vector containers provide the implementation.
virtual SpanBase
get_span_base(int64_t chunk_id) const = 0;
int64_t
get_chunk_size() const {
return chunk_size_;
......@@ -104,6 +108,9 @@ class ConcurrentVectorImpl : public VectorBase {
ConcurrentVectorImpl&
operator=(const ConcurrentVectorImpl&) = delete;
// Type used to parameterise Span for this container:
//   - scalar containers: the element type itself;
//   - vector containers: FloatVector when Type is float, BinaryVector otherwise.
using TraitType =
std::conditional_t<is_scalar, Type, std::conditional_t<std::is_same_v<Type, float>, FloatVector, BinaryVector>>;
public:
explicit ConcurrentVectorImpl(ssize_t dim, int64_t chunk_size) : VectorBase(chunk_size), Dim(is_scalar ? 1 : dim) {
Assert(is_scalar ? dim == 1 : dim != 1);
......@@ -115,6 +122,25 @@ class ConcurrentVectorImpl : public VectorBase {
chunks_.emplace_to_at_least(chunk_count, Dim * chunk_size_);
}
// Returns a typed Span over the chunk with index chunk_id.
// Scalar containers build the span from (data, chunk_size_); vector containers
// additionally pass Dim. Non-scalar int64_t/int instantiations exist only for
// tests and panic instead of returning.
Span<TraitType>
get_span(int64_t chunk_id) const {
    using SpanType = Span<TraitType>;
    auto& chunk_ref = get_chunk(chunk_id);
    if constexpr (is_scalar) {
        return SpanType(chunk_ref.data(), chunk_size_);
    } else if constexpr (std::is_same_v<Type, int64_t> || std::is_same_v<Type, int>) {
        // only for testing
        PanicInfo("unimplemented");
    } else {
        // The span's element type must match what this container stores.
        static_assert(std::is_same_v<typename TraitType::embedded_type, Type>);
        return SpanType(chunk_ref.data(), chunk_size_, Dim);
    }
}
// VectorBase override: forwards to get_span(chunk_id); the typed span is
// converted to SpanBase by the return statement.
SpanBase
get_span_base(int64_t chunk_id) const override {
return get_span(chunk_id);
}
void
set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) override {
set_data(element_offset, static_cast<const Type*>(source), element_count);
......@@ -206,25 +232,12 @@ class ConcurrentVectorImpl : public VectorBase {
// Scalar flavour of the concurrent vector: a ConcurrentVectorImpl in scalar
// mode, constructed with a dimension of 1.
template <typename Type>
class ConcurrentVector : public ConcurrentVectorImpl<Type, true> {
    // The primary template only accepts fundamental element types.
    static_assert(std::is_fundamental_v<Type>);

 public:
    explicit ConcurrentVector(int64_t chunk_size) : ConcurrentVectorImpl<Type, true>(1, chunk_size) {
    }
};
// Empty tag base class for the vector trait types declared right after it.
class VectorTrait {};
// Trait tag describing a float vector field.
class FloatVector : public VectorTrait {
 public:
    // NOTE(review): named metric_type, but it is initialised with a DataType value.
    static constexpr auto metric_type = DataType::VECTOR_FLOAT;

    // Element type stored for each dimension of the vector.
    using embedded_type = float;
};
// Trait tag describing a binary vector field.
class BinaryVector : public VectorTrait {
 public:
    // NOTE(review): named metric_type, but it is initialised with a DataType value.
    static constexpr auto metric_type = DataType::VECTOR_BINARY;

    // Storage unit of the raw data: one byte.
    using embedded_type = uint8_t;
};
template <>
class ConcurrentVector<FloatVector> : public ConcurrentVectorImpl<float, false> {
public:
......
......@@ -76,7 +76,7 @@ IndexingRecord::UpdateResourceAck(int64_t chunk_ack, const InsertRecord& record)
// std::thread([this, old_ack, chunk_ack, &record] {
for (auto& [field_offset, entry] : entries_) {
auto vec_base = record.get_base_entity(field_offset);
entry->BuildIndexRange(old_ack, chunk_ack, vec_base.get());
entry->BuildIndexRange(old_ack, chunk_ack, vec_base);
}
finished_ack_.AddSegment(old_ack, chunk_ack);
// }).detach();
......
......@@ -28,7 +28,7 @@ struct InsertRecord {
auto
get_base_entity(FieldOffset field_offset) const {
auto ptr = entity_vec_[field_offset.get()];
auto ptr = entity_vec_[field_offset.get()].get();
return ptr;
}
......@@ -36,7 +36,7 @@ struct InsertRecord {
auto
get_entity(FieldOffset field_offset) const {
auto base_ptr = get_base_entity(field_offset);
auto ptr = std::dynamic_pointer_cast<const ConcurrentVector<Type>>(base_ptr);
auto ptr = dynamic_cast<const ConcurrentVector<Type>*>(base_ptr);
Assert(ptr);
return ptr;
}
......@@ -45,7 +45,7 @@ struct InsertRecord {
auto
get_entity(FieldOffset field_offset) {
auto base_ptr = get_base_entity(field_offset);
auto ptr = std::dynamic_pointer_cast<ConcurrentVector<Type>>(base_ptr);
auto ptr = dynamic_cast<ConcurrentVector<Type>*>(base_ptr);
Assert(ptr);
return ptr;
}
......@@ -54,17 +54,17 @@ struct InsertRecord {
void
insert_entity(int64_t chunk_size) {
static_assert(std::is_fundamental_v<Type>);
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<Type>>(chunk_size));
entity_vec_.emplace_back(std::make_unique<ConcurrentVector<Type>>(chunk_size));
}
template <typename VectorType>
void
insert_entity(int64_t dim, int64_t chunk_size) {
static_assert(std::is_base_of_v<VectorTrait, VectorType>);
entity_vec_.emplace_back(std::make_shared<ConcurrentVector<VectorType>>(dim, chunk_size));
entity_vec_.emplace_back(std::make_unique<ConcurrentVector<VectorType>>(dim, chunk_size));
}
private:
std::vector<std::shared_ptr<VectorBase>> entity_vec_;
std::vector<std::unique_ptr<VectorBase>> entity_vec_;
};
} // namespace milvus::segcore
......@@ -299,4 +299,16 @@ SegmentGrowingImpl::LoadIndexing(const LoadIndexInfo& info) {
return Status::OK();
}
// Fetches the column container registered for field_offset in the insert
// record and returns its chunk chunk_id as a SpanBase.
SpanBase
SegmentGrowingImpl::chunk_data_impl(FieldOffset field_offset, int64_t chunk_id) const {
    return get_insert_record().get_base_entity(field_offset)->get_span_base(chunk_id);
}
// Computes the chunk count from the value reported by the insert record's
// ack_responder_ (presumably the number of acknowledged rows -- confirm) and
// chunk_size_, using upper_div.
int64_t
SegmentGrowingImpl::get_safe_num_chunk() const {
    const auto acked_rows = get_insert_record().ack_responder_.GetAck();
    return upper_div(acked_rows, chunk_size_);
}
} // namespace milvus::segcore
......@@ -112,9 +112,7 @@ class SegmentGrowingImpl : public SegmentGrowing {
}
int64_t
get_num_chunk() const override {
PanicInfo("unimplemented");
}
get_safe_num_chunk() const override;
Status
LoadIndexing(const LoadIndexInfo& info) override;
......@@ -139,9 +137,7 @@ class SegmentGrowingImpl : public SegmentGrowing {
protected:
SpanBase
chunk_data_impl(FieldOffset field_offset, int64_t chunk_id) const override {
PanicInfo("unimplemented");
}
chunk_data_impl(FieldOffset field_offset, int64_t chunk_id) const override;
private:
int64_t chunk_size_;
......
......@@ -44,16 +44,16 @@ class SegmentInternalInterface : public SegmentInterface {
get_schema() const = 0;
virtual int64_t
get_num_chunk() const = 0;
get_safe_num_chunk() const = 0;
template <typename T>
Span<T>
chunk_data(FieldOffset field_offset, int64_t chunk_id) const {
auto span = chunk_data_impl(field_offset, chunk_id);
return static_cast<Span<T>>(span);
return static_cast<Span<T>>(chunk_data_impl(field_offset, chunk_id));
}
protected:
// blob and row_count
virtual SpanBase
chunk_data_impl(FieldOffset field_offset, int64_t chunk_id) const = 0;
};
......
......@@ -17,6 +17,7 @@ set(MILVUS_TEST_FILES
test_sealed.cpp
test_reduce.cpp
test_interface.cpp
test_span.cpp
)
add_executable(all_tests
${MILVUS_TEST_FILES}
......
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
#include "utils/tools.h"
#include "test_utils/DataGen.h"
#include "segcore/SegmentGrowing.h"
// Inserts N generated rows (one binary vector field, one float field) into a
// growing segment, then checks chunk by chunk that the spans returned by
// chunk_data() hold the same bytes/values as the generated columns.
TEST(Span, Naive) {
    using namespace milvus;
    using namespace milvus::query;
    using namespace milvus::segcore;
    int64_t N = 1000 * 1000;
    constexpr int64_t chunk_size = 32 * 1024;
    // Dimension of the binary vector field, in bits; 8 bits are packed per byte.
    constexpr int dim = 512;
    constexpr int bytes_per_vec = dim / 8;

    auto schema = std::make_shared<Schema>();
    schema->AddDebugField("fakevec", DataType::VECTOR_BINARY, dim, MetricType::METRIC_Jaccard);
    schema->AddDebugField("age", DataType::FLOAT);
    auto dataset = DataGen(schema, N);
    auto segment = CreateGrowingSegment(schema, chunk_size);
    segment->PreInsert(N);
    segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);
    auto vec_ptr = dataset.get_col<uint8_t>(0);
    auto age_ptr = dataset.get_col<float>(1);

    SegmentInternalInterface& interface = *segment;
    auto num_chunk = interface.get_safe_num_chunk();
    ASSERT_EQ(num_chunk, upper_div(N, chunk_size));
    auto row_count = interface.get_row_count();
    ASSERT_EQ(N, row_count);

    for (auto chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
        auto vec_span = interface.chunk_data<BinaryVector>(FieldOffset(0), chunk_id);
        auto age_span = interface.chunk_data<float>(FieldOffset(1), chunk_id);
        auto begin = chunk_id * chunk_size;
        auto end = std::min((chunk_id + 1) * chunk_size, N);
        // Row count of this chunk (the last one may be partial). It has its own
        // name so it no longer shadows the constexpr chunk_size above.
        auto rows_in_chunk = end - begin;
        for (int i = 0; i < rows_in_chunk * bytes_per_vec; ++i) {
            ASSERT_EQ(vec_span.data()[i], vec_ptr[i + begin * bytes_per_vec]);
        }
        for (int i = 0; i < rows_in_chunk; ++i) {
            ASSERT_EQ(age_span.data()[i], age_ptr[i + begin]);
        }
    }
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册