提交 63c8f60c 编写于 作者: F FluorineDog 提交者: yefu.chen

Enable term parser and executor

Signed-off-by: NFluorineDog <guilin.gou@zilliz.com>
上级 6412ebc0
......@@ -56,3 +56,6 @@ cmake_build/
.DS_Store
*.swp
cwrapper_build
**/.clangd/*
**/compile_commands.json
**/.lint
## Binlog
InsertBinlog、DeleteBinlog、DDLBinlog
Binlog is stored in a columnar storage format, every column in schema should be stored in a individual file. Timestamp, schema, row id and primary key allocated by system are four special columns. Schema column records the DDL of the collection.
## Event format
Binlog file consists of 4 bytes magic number and a series of events. The first event must be descriptor event.
### Event format
```
+=====================================+
| event | timestamp 0 : 8 | create timestamp
| header +----------------------------+
| | type_code 8 : 1 | event type code
| +----------------------------+
| | server_id 9 : 4 | write node id
| +----------------------------+
| | event_length 13 : 4 | length of event, including header and data
| +----------------------------+
| | next_position 17 : 4 | offset of next event from the start of file
| +----------------------------+
| | extra_headers 21 : x-21 | reserved part
+=====================================+
| event | fixed part x : y |
| data +----------------------------+
| | variable part |
+=====================================+
```
### Descriptor Event format
```
+=====================================+
| event | timestamp 0 : 8 | create timestamp
| header +----------------------------+
| | type_code 8 : 1 | event type code
| +----------------------------+
| | server_id 9 : 4 | write node id
| +----------------------------+
| | event_length 13 : 4 | length of event, including header and data
| +----------------------------+
| | next_position 17 : 4 | offset of next event from the start of file
+=====================================+
| event | binlog_version 21 : 2 | binlog version
| data +----------------------------+
| | server_version 23 : 8 | write node version
| +----------------------------+
| | commit_id 31 : 8 | commit id of the programe in git
| +----------------------------+
| | header_length 39 : 1 | header length of other event
| +----------------------------+
| | collection_id 40 : 8 | collection id
| +----------------------------+
| | partition_id 48 : 8 | partition id (schema column does not need)
| +----------------------------+
| | segment_id 56 : 8 | segment id (schema column does not need)
| +----------------------------+
| | start_timestamp 64 : 1 | minimum timestamp allocated by master of all events in this file
| +----------------------------+
| | end_timestamp 65 : 1 | maximum timestamp allocated by master of all events in this file
| +----------------------------+
| | post-header 66 : n | array of n bytes, one byte per event type that the server knows about
| | lengths for all |
| | event types |
+=====================================+
```
### Type code
```
DESCRIPTOR_EVENT
INSERT_EVENT
DELETE_EVENT
CREATE_COLLECTION_EVENT
DROP_COLLECTION_EVENT
CREATE_PARTITION_EVENT
DROP_PARTITION_EVENT
```
DESCRIPTOR_EVENT must appear in all column files and always be the first event.
INSERT_EVENT 可以出现在除DDL binlog文件外的其他列的binlog
DELETE_EVENT 只能用于primary key 的binlog文件(目前只有按照primary key删除)
CREATE_COLLECTION_EVENT、DROP_COLLECTION_EVENT、CREATE_PARTITION_EVENT、DROP_PARTITION_EVENT 只出现在DDL binlog文件
### Event data part
```
event data part
INSERT_EVENT:
+================================================+
| event | fixed | start_timestamp x : 8 | min timestamp in this event
| data | part +------------------------------+
| | | end_timestamp x+8 : 8 | max timestamp in this event
| | +------------------------------+
| | | reserved x+16 : y-x-16 | reserved part
| +--------+------------------------------+
| |variable| parquet payloI ad | payload in parquet format
| |part | |
+================================================+
other events is similar with INSERT_EVENT
```
### Example
Schema
​ string | int | float(optional) | vector(512)
Request:
​ InsertRequest rows(1W)
​ DeleteRequest pk=1
​ DropPartition partitionTag="abc"
insert binlogs:
​ rowid, pk, ts, string, int, float, vector 6 files
​ all events are INSERT_EVENT
​ float column file contains some NULL value
delete binlogs:
​ pk, ts 2 files
​ pk's events are DELETE_EVENT, ts's events are INSERT_EVENT
DDL binlogs:
​ ddl, ts
​ ddl's event is DROP_PARTITION_EVENT, ts's event is INSERT_EVENT
C++ interface
```c++
typedef void* CPayloadWriter
typedef struct CBuffer {
char* data;
int length;
} CBuffer
typedef struct CStatus {
int error_code;
const char* error_msg;
} CStatus
// C++ interface
// writer
CPayloadWriter NewPayloadWriter(int columnType);
CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length);
CStatus AddInt8ToPayload(CPayloadWriter payloadWriter, int8_t *values, int length);
CStatus AddInt16ToPayload(CPayloadWriter payloadWriter, int16_t *values, int length);
CStatus AddInt32ToPayload(CPayloadWriter payloadWriter, int32_t *values, int length);
CStatus AddInt64ToPayload(CPayloadWriter payloadWriter, int64_t *values, int length);
CStatus AddFloatToPayload(CPayloadWriter payloadWriter, float *values, int length);
CStatus AddDoubleToPayload(CPayloadWriter payloadWriter, double *values, int length);
CStatus AddOneStringToPayload(CPayloadWriter payloadWriter, char *cstr, int str_size);
CStatus AddBinaryVectorToPayload(CPayloadWriter payloadWriter, uint8_t *values, int dimension, int length);
CStatus AddFloatVectorToPayload(CPayloadWriter payloadWriter, float *values, int dimension, int length);
CStatus FinishPayloadWriter(CPayloadWriter payloadWriter);
CBuffer GetPayloadBufferFromWriter(CPayloadWriter payloadWriter);
int GetPayloadLengthFromWriter(CPayloadWriter payloadWriter);
CStatus ReleasePayloadWriter(CPayloadWriter handler);
// reader
CPayloadReader NewPayloadReader(int columnType, uint8_t *buffer, int64_t buf_size);
CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length);
CStatus GetInt8FromPayload(CPayloadReader payloadReader, int8_t **values, int *length);
CStatus GetInt16FromPayload(CPayloadReader payloadReader, int16_t **values, int *length);
CStatus GetInt32FromPayload(CPayloadReader payloadReader, int32_t **values, int *length);
CStatus GetInt64FromPayload(CPayloadReader payloadReader, int64_t **values, int *length);
CStatus GetFloatFromPayload(CPayloadReader payloadReader, float **values, int *length);
CStatus GetDoubleFromPayload(CPayloadReader payloadReader, double **values, int *length);
CStatus GetOneStringFromPayload(CPayloadReader payloadReader, int idx, char **cstr, int *str_size);
CStatus GetBinaryVectorFromPayload(CPayloadReader payloadReader, uint8_t **values, int *dimension, int *length);
CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader, float **values, int *dimension, int *length);
int GetPayloadLengthFromReader(CPayloadReader payloadReader);
CStatus ReleasePayloadReader(CPayloadReader payloadReader);
```
......@@ -38,7 +38,7 @@ static auto map = [] {
MetricType
GetMetricType(const std::string& type_name) {
auto real_name = to_lower_copy(type_name);
AssertInfo(map.left.count(real_name), "metric type not found: " + type_name);
AssertInfo(map.left.count(real_name), "metric type not found: (" + type_name + ")");
return map.left.at(real_name);
}
......
......@@ -13,6 +13,8 @@
#include "utils/Types.h"
#include <faiss/MetricType.h>
#include <string>
#include <boost/align/aligned_allocator.hpp>
#include <vector>
namespace milvus {
using Timestamp = uint64_t; // TODO: use TiKV-like timestamp
......@@ -24,4 +26,15 @@ using MetricType = faiss::MetricType;
faiss::MetricType
GetMetricType(const std::string& type);
// NOTE: dependent type
// used at meta-template programming
template <class...>
constexpr std::true_type always_true{};
template <class...>
constexpr std::false_type always_false{};
template <typename T>
using aligned_vector = std::vector<T, boost::alignment::aligned_allocator<T, 512>>;
} // namespace milvus
......@@ -70,8 +70,6 @@ to_lower(const std::string& raw) {
return data;
}
template <class...>
constexpr std::false_type always_false{};
template <typename T>
std::unique_ptr<Expr>
ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) {
......@@ -85,31 +83,62 @@ ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Js
AssertInfo(RangeExpr::mapping_.count(op_name), "op(" + op_name + ") not found");
auto op = RangeExpr::mapping_.at(op_name);
if constexpr (std::is_integral_v<T>) {
if constexpr (std::is_same_v<T, bool>) {
Assert(item.value().is_boolean());
} else if constexpr (std::is_integral_v<T>) {
Assert(item.value().is_number_integer());
} else if constexpr (std::is_floating_point_v<T>) {
Assert(item.value().is_number());
} else {
static_assert(always_false<T>, "unsupported type");
__builtin_unreachable();
}
T value = item.value();
expr->conditions_.emplace_back(op, value);
}
std::sort(expr->conditions_.begin(), expr->conditions_.end());
return expr;
}
template <typename T>
std::unique_ptr<Expr>
ParseTermNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) {
auto expr = std::make_unique<TermExprImpl<T>>();
auto data_type = schema[field_name].get_data_type();
Assert(body.is_array());
expr->field_id_ = field_name;
expr->data_type_ = data_type;
for (auto& value : body) {
if constexpr (std::is_same_v<T, bool>) {
Assert(value.is_boolean());
} else if constexpr (std::is_integral_v<T>) {
Assert(value.is_number_integer());
} else if constexpr (std::is_floating_point_v<T>) {
Assert(value.is_number());
} else {
static_assert(always_false<T>, "unsupported type");
__builtin_unreachable();
}
T real_value = value;
expr->terms_.push_back(real_value);
}
std::sort(expr->terms_.begin(), expr->terms_.end());
return expr;
}
std::unique_ptr<Expr>
ParseRangeNode(const Schema& schema, const Json& out_body) {
Assert(out_body.is_object());
Assert(out_body.size() == 1);
auto out_iter = out_body.begin();
auto field_name = out_iter.key();
auto body = out_iter.value();
auto data_type = schema[field_name].get_data_type();
Assert(!field_is_vector(data_type));
switch (data_type) {
case DataType::BOOL: {
PanicInfo("bool is not supported in Range node");
// return ParseRangeNodeImpl<bool>(schema, field_name, body);
return ParseRangeNodeImpl<bool>(schema, field_name, body);
}
case DataType::INT8:
return ParseRangeNodeImpl<int8_t>(schema, field_name, body);
......@@ -128,6 +157,42 @@ ParseRangeNode(const Schema& schema, const Json& out_body) {
}
}
static std::unique_ptr<Expr>
ParseTermNode(const Schema& schema, const Json& out_body) {
Assert(out_body.size() == 1);
auto out_iter = out_body.begin();
auto field_name = out_iter.key();
auto body = out_iter.value();
auto data_type = schema[field_name].get_data_type();
Assert(!field_is_vector(data_type));
switch (data_type) {
case DataType::BOOL: {
return ParseTermNodeImpl<bool>(schema, field_name, body);
}
case DataType::INT8: {
return ParseTermNodeImpl<int8_t>(schema, field_name, body);
}
case DataType::INT16: {
return ParseTermNodeImpl<int16_t>(schema, field_name, body);
}
case DataType::INT32: {
return ParseTermNodeImpl<int32_t>(schema, field_name, body);
}
case DataType::INT64: {
return ParseTermNodeImpl<int64_t>(schema, field_name, body);
}
case DataType::FLOAT: {
return ParseTermNodeImpl<float>(schema, field_name, body);
}
case DataType::DOUBLE: {
return ParseTermNodeImpl<double>(schema, field_name, body);
}
default: {
PanicInfo("unsupported data_type");
}
}
}
static std::unique_ptr<Plan>
CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) {
auto plan = std::make_unique<Plan>(schema);
......@@ -143,6 +208,10 @@ CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) {
if (pack.contains("vector")) {
auto& out_body = pack.at("vector");
plan->plan_node_ = ParseVecNode(plan.get(), out_body);
} else if (pack.contains("term")) {
AssertInfo(!predicate, "unsupported complex DSL");
auto& out_body = pack.at("term");
predicate = ParseTermNode(schema, out_body);
} else if (pack.contains("range")) {
AssertInfo(!predicate, "unsupported complex DSL");
auto& out_body = pack.at("range");
......
......@@ -20,7 +20,6 @@
#include <map>
#include <string>
#include <vector>
#include <boost/align/aligned_allocator.hpp>
namespace milvus::query {
using Json = nlohmann::json;
......@@ -39,9 +38,6 @@ struct Plan {
// TODO: add move extra info
};
template <typename T>
using aligned_vector = std::vector<T, boost::alignment::aligned_allocator<T, 512>>;
struct Placeholder {
// milvus::proto::service::PlaceholderGroup group_;
std::string tag_;
......
......@@ -27,7 +27,7 @@ create_bitmap_view(std::optional<const BitmapSimple*> bitmaps_opt, int64_t chunk
return nullptr;
}
auto& bitmaps = *bitmaps_opt.value();
auto& src_vec = bitmaps.at(chunk_id);
auto src_vec = ~bitmaps.at(chunk_id);
auto dst = std::make_shared<faiss::ConcurrentBitset>(src_vec.size());
auto iter = reinterpret_cast<BitmapChunk::block_type*>(dst->mutable_data());
......
......@@ -58,6 +58,10 @@ class ExecExprVisitor : ExprVisitor {
auto
ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType;
template <typename T>
auto
ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType;
private:
segcore::SegmentSmallIndex& segment_;
std::optional<RetType> ret_;
......
......@@ -46,6 +46,10 @@ class ExecExprVisitor : ExprVisitor {
auto
ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType;
template <typename T>
auto
ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType;
private:
segcore::SegmentSmallIndex& segment_;
std::optional<RetType> ret_;
......@@ -63,11 +67,6 @@ ExecExprVisitor::visit(BoolBinaryExpr& expr) {
PanicInfo("unimplemented");
}
void
ExecExprVisitor::visit(TermExpr& expr) {
PanicInfo("unimplemented");
}
template <typename T, typename IndexFunc, typename ElementFunc>
auto
ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl<T>& expr, IndexFunc index_func, ElementFunc element_func)
......@@ -84,17 +83,17 @@ ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl<T>& expr, IndexFunc index_fu
auto& indexing_record = segment_.get_indexing_record();
const segcore::ScalarIndexingEntry<T>& entry = indexing_record.get_scalar_entry<T>(field_offset);
RetType results(vec.chunk_size());
RetType results(vec.num_chunk());
auto indexing_barrier = indexing_record.get_finished_ack();
for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) {
auto& result = results[chunk_id];
auto indexing = entry.get_indexing(chunk_id);
auto data = index_func(indexing);
result = ~std::move(*data);
result = std::move(*data);
Assert(result.size() == segcore::DefaultElementPerChunk);
}
for (auto chunk_id = indexing_barrier; chunk_id < vec.chunk_size(); ++chunk_id) {
for (auto chunk_id = indexing_barrier; chunk_id < vec.num_chunk(); ++chunk_id) {
auto& result = results[chunk_id];
result.resize(segcore::DefaultElementPerChunk);
auto chunk = vec.get_chunk(chunk_id);
......@@ -126,32 +125,32 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType {
switch (op) {
case OpType::Equal: {
auto index_func = [val](Index* index) { return index->In(1, &val); };
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x == val); });
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x == val); });
}
case OpType::NotEqual: {
auto index_func = [val](Index* index) { return index->NotIn(1, &val); };
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x != val); });
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x != val); });
}
case OpType::GreaterEqual: {
auto index_func = [val](Index* index) { return index->Range(val, Operator::GE); };
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x >= val); });
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x >= val); });
}
case OpType::GreaterThan: {
auto index_func = [val](Index* index) { return index->Range(val, Operator::GT); };
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x > val); });
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x > val); });
}
case OpType::LessEqual: {
auto index_func = [val](Index* index) { return index->Range(val, Operator::LE); };
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x <= val); });
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x <= val); });
}
case OpType::LessThan: {
auto index_func = [val](Index* index) { return index->Range(val, Operator::LT); };
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x < val); });
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x < val); });
}
default: {
PanicInfo("unsupported range node");
......@@ -167,16 +166,16 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType {
if (false) {
} else if (ops == std::make_tuple(OpType::GreaterThan, OpType::LessThan)) {
auto index_func = [val1, val2](Index* index) { return index->Range(val1, false, val2, false); };
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 < x && x < val2); });
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 < x && x < val2); });
} else if (ops == std::make_tuple(OpType::GreaterThan, OpType::LessEqual)) {
auto index_func = [val1, val2](Index* index) { return index->Range(val1, false, val2, true); };
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 < x && x <= val2); });
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 < x && x <= val2); });
} else if (ops == std::make_tuple(OpType::GreaterEqual, OpType::LessThan)) {
auto index_func = [val1, val2](Index* index) { return index->Range(val1, true, val2, false); };
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 <= x && x < val2); });
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 <= x && x < val2); });
} else if (ops == std::make_tuple(OpType::GreaterEqual, OpType::LessEqual)) {
auto index_func = [val1, val2](Index* index) { return index->Range(val1, true, val2, true); };
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 <= x && x <= val2); });
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 <= x && x <= val2); });
} else {
PanicInfo("unsupported range node");
}
......@@ -226,4 +225,79 @@ ExecExprVisitor::visit(RangeExpr& expr) {
ret_ = std::move(ret);
}
template <typename T>
auto
ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType {
auto& expr = static_cast<TermExprImpl<T>&>(expr_raw);
auto& records = segment_.get_insert_record();
auto data_type = expr.data_type_;
auto& schema = segment_.get_schema();
auto field_offset_opt = schema.get_offset(expr.field_id_);
Assert(field_offset_opt);
auto field_offset = field_offset_opt.value();
auto& field_meta = schema[field_offset];
auto vec_ptr = records.get_entity<T>(field_offset);
auto& vec = *vec_ptr;
auto num_chunk = vec.num_chunk();
RetType bitsets;
auto N = records.ack_responder_.GetAck();
// small batch
for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
auto& chunk = vec.get_chunk(chunk_id);
auto size = chunk_id == num_chunk - 1 ? N - chunk_id * segcore::DefaultElementPerChunk
: segcore::DefaultElementPerChunk;
boost::dynamic_bitset<> bitset(segcore::DefaultElementPerChunk);
for (int i = 0; i < size; ++i) {
auto value = chunk[i];
bool is_in = std::binary_search(expr.terms_.begin(), expr.terms_.end(), value);
bitset[i] = is_in;
}
bitsets.emplace_back(std::move(bitset));
}
return bitsets;
}
void
ExecExprVisitor::visit(TermExpr& expr) {
auto& field_meta = segment_.get_schema()[expr.field_id_];
Assert(expr.data_type_ == field_meta.get_data_type());
RetType ret;
switch (expr.data_type_) {
case DataType::BOOL: {
ret = ExecTermVisitorImpl<bool>(expr);
break;
}
case DataType::INT8: {
ret = ExecTermVisitorImpl<int8_t>(expr);
break;
}
case DataType::INT16: {
ret = ExecTermVisitorImpl<int16_t>(expr);
break;
}
case DataType::INT32: {
ret = ExecTermVisitorImpl<int32_t>(expr);
break;
}
case DataType::INT64: {
ret = ExecTermVisitorImpl<int64_t>(expr);
break;
}
case DataType::FLOAT: {
ret = ExecTermVisitorImpl<float>(expr);
break;
}
case DataType::DOUBLE: {
ret = ExecTermVisitorImpl<double>(expr);
break;
}
default:
PanicInfo("unsupported");
}
ret_ = std::move(ret);
}
} // namespace milvus::query
......@@ -196,7 +196,7 @@ class ConcurrentVectorImpl : public VectorBase {
}
ssize_t
chunk_size() const {
num_chunk() const {
return chunks_.size();
}
......
......@@ -24,7 +24,7 @@ VecIndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const Vector
auto source = dynamic_cast<const ConcurrentVector<FloatVector>*>(vec_base);
Assert(source);
auto chunk_size = source->chunk_size();
auto chunk_size = source->num_chunk();
assert(ack_end <= chunk_size);
auto conf = get_build_conf();
data_.grow_to_at_least(ack_end);
......@@ -87,7 +87,7 @@ void
ScalarIndexingEntry<T>::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) {
auto source = dynamic_cast<const ConcurrentVector<T>*>(vec_base);
Assert(source);
auto chunk_size = source->chunk_size();
auto chunk_size = source->num_chunk();
assert(ack_end <= chunk_size);
data_.grow_to_at_least(ack_end);
for (int chunk_id = ack_beg; chunk_id < ack_end; chunk_id++) {
......
......@@ -467,16 +467,16 @@ SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry& entry) {
auto dim = field.get_dim();
auto indexing = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(entry.type, entry.mode);
auto chunk_size = record_.uids_.chunk_size();
auto chunk_size = record_.uids_.num_chunk();
auto& uids = record_.uids_;
auto entities = record_.get_entity<FloatVector>(offset);
std::vector<knowhere::DatasetPtr> datasets;
for (int chunk_id = 0; chunk_id < uids.chunk_size(); ++chunk_id) {
for (int chunk_id = 0; chunk_id < uids.num_chunk(); ++chunk_id) {
auto entities_chunk = entities->get_chunk(chunk_id).data();
int64_t count = chunk_id == uids.chunk_size() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
: DefaultElementPerChunk;
int64_t count = chunk_id == uids.num_chunk() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
: DefaultElementPerChunk;
datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk));
}
for (auto& ds : datasets) {
......
......@@ -241,10 +241,10 @@ SegmentSmallIndex::BuildVecIndexImpl(const IndexMeta::Entry& entry) {
auto entities = record_.get_entity<FloatVector>(offset);
std::vector<knowhere::DatasetPtr> datasets;
for (int chunk_id = 0; chunk_id < uids.chunk_size(); ++chunk_id) {
for (int chunk_id = 0; chunk_id < uids.num_chunk(); ++chunk_id) {
auto entities_chunk = entities->get_chunk(chunk_id).data();
int64_t count = chunk_id == uids.chunk_size() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
: DefaultElementPerChunk;
int64_t count = chunk_id == uids.num_chunk() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
: DefaultElementPerChunk;
datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk));
}
for (auto& ds : datasets) {
......
......@@ -26,4 +26,5 @@ target_link_libraries(all_tests
pthread
milvus_utils
)
install (TARGETS all_tests DESTINATION unittest)
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
......@@ -52,7 +52,7 @@ TEST(ConcurrentVector, TestSingle) {
c_vec.set_data(total_count, vec.data(), insert_size);
total_count += insert_size;
}
ASSERT_EQ(c_vec.chunk_size(), (total_count + 31) / 32);
ASSERT_EQ(c_vec.num_chunk(), (total_count + 31) / 32);
for (int i = 0; i < total_count; ++i) {
for (int d = 0; d < dim; ++d) {
auto std_data = d + i * dim;
......
......@@ -321,7 +321,88 @@ TEST(Expr, TestRange) {
auto ans = final[vec_id][offset];
auto val = age_col[i];
auto ref = !ref_func(val);
auto ref = ref_func(val);
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
}
}
}
TEST(Expr, TestTerm) {
using namespace milvus::query;
using namespace milvus::segcore;
auto vec_2k_3k = [] {
std::string buf = "[";
for (int i = 2000; i < 3000 - 1; ++i) {
buf += std::to_string(i) + ", ";
}
buf += std::to_string(2999) + "]";
return buf;
}();
std::vector<std::tuple<std::string, std::function<bool(int)>>> testcases = {
{R"([2000, 3000])", [](int v) { return v == 2000 || v == 3000; }},
{R"([2000])", [](int v) { return v == 2000; }},
{R"([3000])", [](int v) { return v == 3000; }},
{vec_2k_3k, [](int v) { return 2000 <= v && v < 3000; }},
};
std::string dsl_string_tmp = R"(
{
"bool": {
"must": [
{
"term": {
"age": @@@@
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
]
}
})";
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("age", DataType::INT32);
auto seg = CreateSegment(schema);
int N = 10000;
std::vector<int> age_col;
int num_iters = 100;
for (int iter = 0; iter < num_iters; ++iter) {
auto raw_data = DataGen(schema, N, iter);
auto new_age_col = raw_data.get_col<int>(1);
age_col.insert(age_col.end(), new_age_col.begin(), new_age_col.end());
seg->PreInsert(N);
seg->Insert(iter * N, N, raw_data.row_ids_.data(), raw_data.timestamps_.data(), raw_data.raw_);
}
auto seg_promote = dynamic_cast<SegmentSmallIndex*>(seg.get());
ExecExprVisitor visitor(*seg_promote);
for (auto [clause, ref_func] : testcases) {
auto loc = dsl_string_tmp.find("@@@@");
auto dsl_string = dsl_string_tmp;
dsl_string.replace(loc, 4, clause);
auto plan = CreatePlan(*schema, dsl_string);
auto final = visitor.call_child(*plan->plan_node_->predicate_.value());
EXPECT_EQ(final.size(), upper_div(N * num_iters, DefaultElementPerChunk));
for (int i = 0; i < N * num_iters; ++i) {
auto vec_id = i / DefaultElementPerChunk;
auto offset = i % DefaultElementPerChunk;
auto ans = final[vec_id][offset];
auto val = age_col[i];
auto ref = ref_func(val);
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
}
}
......
......@@ -31,6 +31,14 @@ struct GeneratedData {
memcpy(ret.data(), target.data(), target.size());
return ret;
}
template <typename T>
auto
get_mutable_col(int index) {
auto& target = cols_.at(index);
assert(target.size() == row_ids_.size() * sizeof(T));
auto ptr = reinterpret_cast<T*>(target.data());
return ptr;
}
private:
GeneratedData() = default;
......@@ -58,6 +66,9 @@ GeneratedData::generate_rows(int N, SchemaPtr schema) {
}
}
rows_ = std::move(result);
raw_.raw_data = rows_.data();
raw_.sizeof_per_row = schema->get_total_sizeof();
raw_.count = N;
}
inline GeneratedData
......@@ -129,14 +140,12 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42) {
}
GeneratedData res;
res.cols_ = std::move(cols);
res.generate_rows(N, schema);
for (int i = 0; i < N; ++i) {
res.row_ids_.push_back(i);
res.timestamps_.push_back(i);
}
res.raw_.raw_data = res.rows_.data();
res.raw_.sizeof_per_row = schema->get_total_sizeof();
res.raw_.count = N;
res.generate_rows(N, schema);
return std::move(res);
}
......
......@@ -206,7 +206,7 @@ extern "C" CStatus AddBinaryVectorToPayload(CPayloadWriter payloadWriter, uint8_
st.error_msg = ErrorMsg("payload has finished");
return st;
}
auto ast = builder->AppendValues(values, (dimension / 8) * length);
auto ast = builder->AppendValues(values, length);
if (!ast.ok()) {
st.error_code = static_cast<int>(ErrorCode::UNEXPECTED_ERROR);
st.error_msg = ErrorMsg(ast.message());
......@@ -249,7 +249,7 @@ extern "C" CStatus AddFloatVectorToPayload(CPayloadWriter payloadWriter, float *
st.error_msg = ErrorMsg("payload has finished");
return st;
}
auto ast = builder->AppendValues(reinterpret_cast<const uint8_t *>(values), dimension * length * sizeof(float));
auto ast = builder->AppendValues(reinterpret_cast<const uint8_t *>(values), length);
if (!ast.ok()) {
st.error_code = static_cast<int>(ErrorCode::UNEXPECTED_ERROR);
st.error_msg = ErrorMsg(ast.message());
......@@ -451,7 +451,7 @@ extern "C" CStatus GetBinaryVectorFromPayload(CPayloadReader payloadReader,
return st;
}
*dimension = array->byte_width() * 8;
*length = array->length() / array->byte_width();
*length = array->length();
*values = (uint8_t *) array->raw_values();
return st;
}
......@@ -470,7 +470,7 @@ extern "C" CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader,
return st;
}
*dimension = array->byte_width() / sizeof(float);
*length = array->length() / array->byte_width();
*length = array->length();
*values = (float *) array->raw_values();
return st;
}
......@@ -478,12 +478,7 @@ extern "C" CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader,
extern "C" int GetPayloadLengthFromReader(CPayloadReader payloadReader) {
auto p = reinterpret_cast<wrapper::PayloadReader *>(payloadReader);
if (p->array == nullptr) return 0;
auto ba = std::dynamic_pointer_cast<arrow::FixedSizeBinaryArray>(p->array);
if (ba == nullptr) {
return p->array->length();
} else {
return ba->length() / ba->byte_width();
}
return p->array->length();
}
extern "C" CStatus ReleasePayloadReader(CPayloadReader payloadReader) {
......
......@@ -5,6 +5,7 @@ extern "C" {
#endif
#include <stdint.h>
#include <stdbool.h>
typedef void *CPayloadWriter;
......@@ -19,7 +20,7 @@ typedef struct CStatus {
} CStatus;
CPayloadWriter NewPayloadWriter(int columnType);
//CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length);
CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length);
CStatus AddInt8ToPayload(CPayloadWriter payloadWriter, int8_t *values, int length);
CStatus AddInt16ToPayload(CPayloadWriter payloadWriter, int16_t *values, int length);
CStatus AddInt32ToPayload(CPayloadWriter payloadWriter, int32_t *values, int length);
......@@ -39,7 +40,7 @@ CStatus ReleasePayloadWriter(CPayloadWriter handler);
typedef void *CPayloadReader;
CPayloadReader NewPayloadReader(int columnType, uint8_t *buffer, int64_t buf_size);
//CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length);
CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length);
CStatus GetInt8FromPayload(CPayloadReader payloadReader, int8_t **values, int *length);
CStatus GetInt16FromPayload(CPayloadReader payloadReader, int16_t **values, int *length);
CStatus GetInt32FromPayload(CPayloadReader payloadReader, int32_t **values, int *length);
......@@ -55,4 +56,4 @@ CStatus ReleasePayloadReader(CPayloadReader payloadReader);
#ifdef __cplusplus
}
#endif
\ No newline at end of file
#endif
......@@ -70,38 +70,38 @@ TEST(wrapper, inoutstream) {
ASSERT_EQ(inarray->Value(4), 5);
}
//TEST(wrapper, boolean) {
// auto payload = NewPayloadWriter(ColumnType::BOOL);
// bool data[] = {true, false, true, false};
//
// auto st = AddBooleanToPayload(payload, data, 4);
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
// st = FinishPayloadWriter(payload);
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
// auto cb = GetPayloadBufferFromWriter(payload);
// ASSERT_GT(cb.length, 0);
// ASSERT_NE(cb.data, nullptr);
// auto nums = GetPayloadLengthFromWriter(payload);
// ASSERT_EQ(nums, 4);
//
// auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length);
// bool *values;
// int length;
// st = GetBoolFromPayload(reader, &values, &length);
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
// ASSERT_NE(values, nullptr);
// ASSERT_EQ(length, 4);
// length = GetPayloadLengthFromReader(reader);
// ASSERT_EQ(length, 4);
// for (int i = 0; i < length; i++) {
// ASSERT_EQ(data[i], values[i]);
// }
//
// st = ReleasePayloadWriter(payload);
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
// st = ReleasePayloadReader(reader);
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
//}
TEST(wrapper, boolean) {
auto payload = NewPayloadWriter(ColumnType::BOOL);
bool data[] = {true, false, true, false};
auto st = AddBooleanToPayload(payload, data, 4);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
st = FinishPayloadWriter(payload);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
auto cb = GetPayloadBufferFromWriter(payload);
ASSERT_GT(cb.length, 0);
ASSERT_NE(cb.data, nullptr);
auto nums = GetPayloadLengthFromWriter(payload);
ASSERT_EQ(nums, 4);
auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length);
bool *values;
int length;
st = GetBoolFromPayload(reader, &values, &length);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
ASSERT_NE(values, nullptr);
ASSERT_EQ(length, 4);
length = GetPayloadLengthFromReader(reader);
ASSERT_EQ(length, 4);
for (int i = 0; i < length; i++) {
ASSERT_EQ(data[i], values[i]);
}
st = ReleasePayloadWriter(payload);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
st = ReleasePayloadReader(reader);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
}
#define NUMERIC_TEST(TEST_NAME, COLUMN_TYPE, DATA_TYPE, ADD_FUNC, GET_FUNC, ARRAY_TYPE) TEST(wrapper, TEST_NAME) { \
auto payload = NewPayloadWriter(COLUMN_TYPE); \
......
......@@ -16,25 +16,311 @@ import (
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
type PayloadWriter struct {
payloadWriterPtr C.CPayloadWriter
}
type (
PayloadWriter struct {
payloadWriterPtr C.CPayloadWriter
colType schemapb.DataType
}
PayloadReader struct {
payloadReaderPtr C.CPayloadReader
colType schemapb.DataType
}
)
func NewPayloadWriter(colType schemapb.DataType) (*PayloadWriter, error) {
w := C.NewPayloadWriter(C.int(colType))
if w == nil {
return nil, errors.New("create Payload writer failed")
}
return &PayloadWriter{payloadWriterPtr: w}, nil
return &PayloadWriter{payloadWriterPtr: w, colType: colType}, nil
}
func (w *PayloadWriter) AddDataToPayload(msgs interface{}, dim ...int) error {
switch len(dim) {
case 0:
switch w.colType {
case schemapb.DataType_BOOL:
val, ok := msgs.([]bool)
if !ok {
return errors.New("incorrect data type")
}
return w.AddBoolToPayload(val)
case schemapb.DataType_INT8:
val, ok := msgs.([]int8)
if !ok {
return errors.New("incorrect data type")
}
return w.AddInt8ToPayload(val)
case schemapb.DataType_INT16:
val, ok := msgs.([]int16)
if !ok {
return errors.New("incorrect data type")
}
return w.AddInt16ToPayload(val)
case schemapb.DataType_INT32:
val, ok := msgs.([]int32)
if !ok {
return errors.New("incorrect data type")
}
return w.AddInt32ToPayload(val)
case schemapb.DataType_INT64:
val, ok := msgs.([]int64)
if !ok {
return errors.New("incorrect data type")
}
return w.AddInt64ToPayload(val)
case schemapb.DataType_FLOAT:
val, ok := msgs.([]float32)
if !ok {
return errors.New("incorrect data type")
}
return w.AddFloatToPayload(val)
case schemapb.DataType_DOUBLE:
val, ok := msgs.([]float64)
if !ok {
return errors.New("incorrect data type")
}
return w.AddDoubleToPayload(val)
case schemapb.DataType_STRING:
val, ok := msgs.(string)
if !ok {
return errors.New("incorrect data type")
}
return w.AddOneStringToPayload(val)
}
case 1:
switch w.colType {
case schemapb.DataType_VECTOR_BINARY:
val, ok := msgs.([]byte)
if !ok {
return errors.New("incorrect data type")
}
return w.AddBinaryVectorToPayload(val, dim[0])
case schemapb.DataType_VECTOR_FLOAT:
val, ok := msgs.([]float32)
if !ok {
return errors.New("incorrect data type")
}
return w.AddFloatVectorToPayload(val, dim[0])
}
default:
return errors.New("incorrect input numbers")
}
return nil
}
func (w *PayloadWriter) AddBoolToPayload(msgs []bool) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.bool)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddBooleanToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddInt8ToPayload(msgs []int8) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.int8_t)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddInt8ToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddInt16ToPayload(msgs []int16) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.int16_t)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddInt16ToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddInt32ToPayload(msgs []int32) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.int32_t)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddInt32ToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddInt64ToPayload(msgs []int64) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.int64_t)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddInt64ToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddFloatToPayload(msgs []float32) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.float)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddFloatToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddDoubleToPayload(msgs []float64) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.double)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddDoubleToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddOneStringToPayload(msg string) error {
if len(msg) == 0 {
length := len(msg)
if length == 0 {
return errors.New("can't add empty string into payload")
}
cstr := C.CString(msg)
defer C.free(unsafe.Pointer(cstr))
st := C.AddOneStringToPayload(w.payloadWriterPtr, cstr, C.int(len(msg)))
cmsg := C.CString(msg)
clength := C.int(length)
defer C.free(unsafe.Pointer(cmsg))
st := C.AddOneStringToPayload(w.payloadWriterPtr, cmsg, clength)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return errors.New(msg)
}
return nil
}
// dimension > 0 && (%8 == 0)
func (w *PayloadWriter) AddBinaryVectorToPayload(binVec []byte, dim int) error {
length := len(binVec)
if length <= 0 {
return errors.New("can't add empty binVec into payload")
}
if dim <= 0 {
return errors.New("dimension should be greater than 0")
}
cBinVec := (*C.uint8_t)(&binVec[0])
cDim := C.int(dim)
cLength := C.int(length / (dim / 8))
st := C.AddBinaryVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return errors.New(msg)
}
return nil
}
// dimension > 0 && (%8 == 0)
func (w *PayloadWriter) AddFloatVectorToPayload(floatVec []float32, dim int) error {
length := len(floatVec)
if length <= 0 {
return errors.New("can't add empty floatVec into payload")
}
if dim <= 0 {
return errors.New("dimension should be greater than 0")
}
cBinVec := (*C.float)(&floatVec[0])
cDim := C.int(dim)
cLength := C.int(length / dim)
st := C.AddFloatVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
......@@ -56,13 +342,13 @@ func (w *PayloadWriter) FinishPayloadWriter() error {
}
func (w *PayloadWriter) GetPayloadBufferFromWriter() ([]byte, error) {
//See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
cb := C.GetPayloadBufferFromWriter(w.payloadWriterPtr)
pointer := unsafe.Pointer(cb.data)
length := int(cb.length)
if length <= 0 {
return nil, errors.New("empty buffer")
}
// refer to: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
slice := (*[1 << 28]byte)(pointer)[:length:length]
return slice, nil
}
......@@ -87,16 +373,71 @@ func (w *PayloadWriter) Close() error {
return w.ReleasePayloadWriter()
}
type PayloadReader struct {
payloadReaderPtr C.CPayloadReader
}
func NewPayloadReader(colType schemapb.DataType, buf []byte) (*PayloadReader, error) {
if len(buf) == 0 {
return nil, errors.New("create Payload reader failed, buffer is empty")
}
r := C.NewPayloadReader(C.int(colType), (*C.uchar)(unsafe.Pointer(&buf[0])), C.long(len(buf)))
return &PayloadReader{payloadReaderPtr: r}, nil
return &PayloadReader{payloadReaderPtr: r, colType: colType}, nil
}
// Params:
// `idx`: String index
// Return:
// `interface{}`: all types.
// `int`: length, only meaningful to FLOAT/BINARY VECTOR type.
// `error`: error.
func (r *PayloadReader) GetDataFromPayload(idx ...int) (interface{}, int, error) {
switch len(idx) {
case 1:
switch r.colType {
case schemapb.DataType_STRING:
val, err := r.GetOneStringFromPayload(idx[0])
return val, 0, err
}
case 0:
switch r.colType {
case schemapb.DataType_BOOL:
val, err := r.GetBoolFromPayload()
return val, 0, err
case schemapb.DataType_INT8:
val, err := r.GetInt8FromPayload()
return val, 0, err
case schemapb.DataType_INT16:
val, err := r.GetInt16FromPayload()
return val, 0, err
case schemapb.DataType_INT32:
val, err := r.GetInt32FromPayload()
return val, 0, err
case schemapb.DataType_INT64:
val, err := r.GetInt64FromPayload()
return val, 0, err
case schemapb.DataType_FLOAT:
val, err := r.GetFloatFromPayload()
return val, 0, err
case schemapb.DataType_DOUBLE:
val, err := r.GetDoubleFromPayload()
return val, 0, err
case schemapb.DataType_VECTOR_BINARY:
return r.GetBinaryVectorFromPayload()
case schemapb.DataType_VECTOR_FLOAT:
return r.GetFloatVectorFromPayload()
default:
return nil, 0, errors.New("Unknown type")
}
default:
return nil, 0, errors.New("incorrect number of index")
}
return nil, 0, errors.New("unknown error")
}
func (r *PayloadReader) ReleasePayloadReader() error {
......@@ -110,18 +451,169 @@ func (r *PayloadReader) ReleasePayloadReader() error {
return nil
}
func (r *PayloadReader) GetBoolFromPayload() ([]bool, error) {
var cMsg *C.bool
var cSize C.int
st := C.GetBoolFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]bool)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetInt8FromPayload() ([]int8, error) {
var cMsg *C.int8_t
var cSize C.int
st := C.GetInt8FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]int8)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetInt16FromPayload() ([]int16, error) {
var cMsg *C.int16_t
var cSize C.int
st := C.GetInt16FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]int16)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetInt32FromPayload() ([]int32, error) {
var cMsg *C.int32_t
var cSize C.int
st := C.GetInt32FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]int32)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetInt64FromPayload() ([]int64, error) {
var cMsg *C.int64_t
var cSize C.int
st := C.GetInt64FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]int64)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetFloatFromPayload() ([]float32, error) {
var cMsg *C.float
var cSize C.int
st := C.GetFloatFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetDoubleFromPayload() ([]float64, error) {
var cMsg *C.double
var cSize C.int
st := C.GetDoubleFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]float64)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetOneStringFromPayload(idx int) (string, error) {
var cStr *C.char
var strSize C.int
var cSize C.int
st := C.GetOneStringFromPayload(r.payloadReaderPtr, C.int(idx), &cStr, &cSize)
st := C.GetOneStringFromPayload(r.payloadReaderPtr, C.int(idx), &cStr, &strSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return "", errors.New(msg)
}
return C.GoStringN(cStr, strSize), nil
return C.GoStringN(cStr, cSize), nil
}
// ,dimension, error
func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, error) {
var cMsg *C.uint8_t
var cDim C.int
var cLen C.int
st := C.GetBinaryVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, 0, errors.New(msg)
}
length := (cDim / 8) * cLen
slice := (*[1 << 28]byte)(unsafe.Pointer(cMsg))[:length:length]
return slice, int(cDim), nil
}
// ,dimension, error
func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) {
var cMsg *C.float
var cDim C.int
var cLen C.int
st := C.GetFloatVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, 0, errors.New(msg)
}
length := cDim * cLen
slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:length:length]
return slice, int(cDim), nil
}
func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) {
......
package storage
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
func TestNewPayloadWriter(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_STRING)
assert.Nil(t, err)
assert.NotNil(t, w)
err = w.Close()
assert.Nil(t, err)
}
func TestPayload_ReaderandWriter(t *testing.T) {
t.Run("TestBool", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_BOOL)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddBoolToPayload([]bool{false, false, false, false})
assert.Nil(t, err)
err = w.AddDataToPayload([]bool{false, false, false, false})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 8, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_BOOL, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 8)
bools, err := r.GetBoolFromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools)
ibools, _, err := r.GetDataFromPayload()
bools = ibools.([]bool)
assert.Nil(t, err)
assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools)
defer r.ReleasePayloadReader()
})
t.Run("TestInt8", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_INT8)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddInt8ToPayload([]int8{1, 2, 3})
assert.Nil(t, err)
err = w.AddDataToPayload([]int8{4, 5, 6})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_INT8, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
int8s, err := r.GetInt8FromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []int8{1, 2, 3, 4, 5, 6}, int8s)
iint8s, _, err := r.GetDataFromPayload()
int8s = iint8s.([]int8)
assert.Nil(t, err)
assert.ElementsMatch(t, []int8{1, 2, 3, 4, 5, 6}, int8s)
defer r.ReleasePayloadReader()
})
t.Run("TestInt16", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_INT16)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddInt16ToPayload([]int16{1, 2, 3})
assert.Nil(t, err)
err = w.AddDataToPayload([]int16{1, 2, 3})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_INT16, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
int16s, err := r.GetInt16FromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []int16{1, 2, 3, 1, 2, 3}, int16s)
iint16s, _, err := r.GetDataFromPayload()
int16s = iint16s.([]int16)
assert.Nil(t, err)
assert.ElementsMatch(t, []int16{1, 2, 3, 1, 2, 3}, int16s)
defer r.ReleasePayloadReader()
})
t.Run("TestInt32", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_INT32)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddInt32ToPayload([]int32{1, 2, 3})
assert.Nil(t, err)
err = w.AddDataToPayload([]int32{1, 2, 3})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_INT32, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
int32s, err := r.GetInt32FromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []int32{1, 2, 3, 1, 2, 3}, int32s)
iint32s, _, err := r.GetDataFromPayload()
int32s = iint32s.([]int32)
assert.Nil(t, err)
assert.ElementsMatch(t, []int32{1, 2, 3, 1, 2, 3}, int32s)
defer r.ReleasePayloadReader()
})
t.Run("TestInt64", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_INT64)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddInt64ToPayload([]int64{1, 2, 3})
assert.Nil(t, err)
err = w.AddDataToPayload([]int64{1, 2, 3})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_INT64, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
int64s, err := r.GetInt64FromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []int64{1, 2, 3, 1, 2, 3}, int64s)
iint64s, _, err := r.GetDataFromPayload()
int64s = iint64s.([]int64)
assert.Nil(t, err)
assert.ElementsMatch(t, []int64{1, 2, 3, 1, 2, 3}, int64s)
defer r.ReleasePayloadReader()
})
t.Run("TestFloat32", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_FLOAT)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddFloatToPayload([]float32{1.0, 2.0, 3.0})
assert.Nil(t, err)
err = w.AddDataToPayload([]float32{1.0, 2.0, 3.0})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_FLOAT, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
float32s, err := r.GetFloatFromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float32s)
ifloat32s, _, err := r.GetDataFromPayload()
float32s = ifloat32s.([]float32)
assert.Nil(t, err)
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float32s)
defer r.ReleasePayloadReader()
})
t.Run("TestDouble", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_DOUBLE)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddDoubleToPayload([]float64{1.0, 2.0, 3.0})
assert.Nil(t, err)
err = w.AddDataToPayload([]float64{1.0, 2.0, 3.0})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_DOUBLE, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
float64s, err := r.GetDoubleFromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float64s)
ifloat64s, _, err := r.GetDataFromPayload()
float64s = ifloat64s.([]float64)
assert.Nil(t, err)
assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float64s)
defer r.ReleasePayloadReader()
})
t.Run("TestAddOneString", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_STRING)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddOneStringToPayload("hello0")
assert.Nil(t, err)
err = w.AddOneStringToPayload("hello1")
assert.Nil(t, err)
err = w.AddOneStringToPayload("hello2")
assert.Nil(t, err)
err = w.AddDataToPayload("hello3")
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, length, 4)
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_STRING, buffer)
assert.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 4)
str0, err := r.GetOneStringFromPayload(0)
assert.Nil(t, err)
assert.Equal(t, str0, "hello0")
str1, err := r.GetOneStringFromPayload(1)
assert.Nil(t, err)
assert.Equal(t, str1, "hello1")
str2, err := r.GetOneStringFromPayload(2)
assert.Nil(t, err)
assert.Equal(t, str2, "hello2")
str3, err := r.GetOneStringFromPayload(3)
assert.Nil(t, err)
assert.Equal(t, str3, "hello3")
istr0, _, err := r.GetDataFromPayload(0)
str0 = istr0.(string)
assert.Nil(t, err)
assert.Equal(t, str0, "hello0")
istr1, _, err := r.GetDataFromPayload(1)
str1 = istr1.(string)
assert.Nil(t, err)
assert.Equal(t, str1, "hello1")
istr2, _, err := r.GetDataFromPayload(2)
str2 = istr2.(string)
assert.Nil(t, err)
assert.Equal(t, str2, "hello2")
istr3, _, err := r.GetDataFromPayload(3)
str3 = istr3.(string)
assert.Nil(t, err)
assert.Equal(t, str3, "hello3")
err = r.ReleasePayloadReader()
assert.Nil(t, err)
err = w.ReleasePayloadWriter()
assert.Nil(t, err)
})
t.Run("TestBinaryVector", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_VECTOR_BINARY)
require.Nil(t, err)
require.NotNil(t, w)
in := make([]byte, 16)
for i := 0; i < 16; i++ {
in[i] = 1
}
in2 := make([]byte, 8)
for i := 0; i < 8; i++ {
in2[i] = 1
}
err = w.AddBinaryVectorToPayload(in, 8)
assert.Nil(t, err)
err = w.AddDataToPayload(in2, 8)
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 24, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_VECTOR_BINARY, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 24)
binVecs, dim, err := r.GetBinaryVectorFromPayload()
assert.Nil(t, err)
assert.Equal(t, 8, dim)
assert.Equal(t, 24, len(binVecs))
fmt.Println(binVecs)
ibinVecs, dim, err := r.GetDataFromPayload()
assert.Nil(t, err)
binVecs = ibinVecs.([]byte)
assert.Equal(t, 8, dim)
assert.Equal(t, 24, len(binVecs))
defer r.ReleasePayloadReader()
})
t.Run("TestFloatVector", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_VECTOR_FLOAT)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddFloatVectorToPayload([]float32{1.0, 2.0}, 1)
assert.Nil(t, err)
err = w.AddDataToPayload([]float32{3.0, 4.0}, 1)
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 4, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_VECTOR_FLOAT, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 4)
floatVecs, dim, err := r.GetFloatVectorFromPayload()
assert.Nil(t, err)
assert.Equal(t, 1, dim)
assert.Equal(t, 4, len(floatVecs))
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 4.0}, floatVecs)
func TestPayLoadString(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_STRING)
assert.Nil(t, err)
err = w.AddOneStringToPayload("hello0")
assert.Nil(t, err)
err = w.AddOneStringToPayload("hello1")
assert.Nil(t, err)
err = w.AddOneStringToPayload("hello2")
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, length, 3)
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_STRING, buffer)
assert.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 3)
str0, err := r.GetOneStringFromPayload(0)
assert.Nil(t, err)
assert.Equal(t, str0, "hello0")
str1, err := r.GetOneStringFromPayload(1)
assert.Nil(t, err)
assert.Equal(t, str1, "hello1")
str2, err := r.GetOneStringFromPayload(2)
assert.Nil(t, err)
assert.Equal(t, str2, "hello2")
err = r.ReleasePayloadReader()
assert.Nil(t, err)
err = w.ReleasePayloadWriter()
assert.Nil(t, err)
ifloatVecs, dim, err := r.GetDataFromPayload()
assert.Nil(t, err)
floatVecs = ifloatVecs.([]float32)
assert.Equal(t, 1, dim)
assert.Equal(t, 4, len(floatVecs))
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 4.0}, floatVecs)
defer r.ReleasePayloadReader()
})
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册