utils.cpp 5.5 KB
Newer Older
J
jinhai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18
#include "unittest/utils.h"
X
xj.lin 已提交
19

X
xiaojun.lin 已提交
20
#include <gtest/gtest.h>
S
starlord 已提交
21 22 23
#include <memory>
#include <string>
#include <utility>
X
xj.lin 已提交
24

H
Heisenberg 已提交
25 26
INITIALIZE_EASYLOGGINGPP

S
starlord 已提交
27 28
void
InitLog() {
H
Heisenberg 已提交
29 30
    el::Configurations defaultConf;
    defaultConf.setToDefault();
S
starlord 已提交
31
    defaultConf.set(el::Level::Debug, el::ConfigurationType::Format, "[%thread-%datetime-%level]: %msg (%fbase:%line)");
H
Heisenberg 已提交
32 33
    el::Loggers::reconfigureLogger("default", defaultConf);
}
X
xj.lin 已提交
34

S
starlord 已提交
35 36
void
DataGen::Init_with_default() {
X
xj.lin 已提交
37 38 39
    Generate(dim, nb, nq);
}

S
starlord 已提交
40 41
void
DataGen::Generate(const int& dim, const int& nb, const int& nq) {
X
xj.lin 已提交
42 43 44 45 46
    this->nb = nb;
    this->nq = nq;
    this->dim = dim;

    GenAll(dim, nb, xb, ids, nq, xq);
S
starlord 已提交
47 48
    assert(xb.size() == (size_t)dim * nb);
    assert(xq.size() == (size_t)dim * nq);
X
xj.lin 已提交
49 50 51 52

    base_dataset = generate_dataset(nb, dim, xb.data(), ids.data());
    query_dataset = generate_query_dataset(nq, dim, xq.data());
}
S
starlord 已提交
53

S
starlord 已提交
54
knowhere::DatasetPtr
DataGen::GenQuery(const int& nq) {
    // Builds nq query vectors by copying the first nq base vectors (member xb) into the
    // member buffer xq, then wraps xq as a query dataset. The member nq is not updated.
    const size_t value_count = static_cast<size_t>(nq) * dim;

    // The queries are copied from the base set, so it must hold at least nq vectors;
    // previously nq > nb silently read past the end of xb.
    assert(nq >= 0 && value_count <= xb.size());

    xq.resize(value_count);
    for (size_t i = 0; i < value_count; ++i) {
        xq[i] = xb[i];
    }
    return generate_query_dataset(nq, dim, xq.data());
}

S
starlord 已提交
63 64 65
void
GenAll(const int64_t dim, const int64_t& nb, std::vector<float>& xb, std::vector<int64_t>& ids, const int64_t& nq,
       std::vector<float>& xq) {
    // Sizes the output vectors (xb: nb*dim, xq: nq*dim, ids: nb) and fills them through the
    // pointer overload. That overload copies xq from the first nq rows of xb, so nq must not
    // exceed nb; otherwise it would read past the end of xb.
    assert(nq <= nb);
    xb.resize(nb * dim);
    xq.resize(nq * dim);
    ids.resize(nb);
    GenAll(dim, nb, xb.data(), ids.data(), nq, xq.data());
}

S
starlord 已提交
72 73
void
GenAll(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids, const int64_t& nq, float* xq) {
X
xj.lin 已提交
74
    GenBase(dim, nb, xb, ids);
S
starlord 已提交
75
    for (int64_t i = 0; i < nq * dim; ++i) {
X
xj.lin 已提交
76 77 78 79
        xq[i] = xb[i];
    }
}

S
starlord 已提交
80 81
void
GenBase(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids) {
    // Fills nb rows of dim floats in xb (row-major) with drand48() values and assigns
    // ids[i] = i. The first coordinate of row i is additionally offset by i / 1000.
    // Loop counters are int64_t to match the bounds: the previous `auto i = 0` deduced
    // int and could overflow for large nb or dim.
    for (int64_t i = 0; i < nb; ++i) {
        float* row = xb + i * dim;
        for (int64_t j = 0; j < dim; ++j) {
            row[j] = drand48();
        }
        // With dim == 0 there is no first coordinate; writing row[0] would be out of bounds.
        if (dim > 0) {
            row[0] += i / 1000.;
        }
        ids[i] = i;
    }
}

S
starlord 已提交
92
FileIOReader::FileIOReader(const std::string& fname) {
    // Remember the path and open it for binary reading; open failures are not
    // reported here and surface only as a failed stream state.
    name = fname;
    fs.open(name, std::ios::in | std::ios::binary);
}

FileIOReader::~FileIOReader() {
    // Release the underlying file handle explicitly.
    if (fs.is_open()) {
        fs.close();
    }
}

S
starlord 已提交
101 102 103
size_t
FileIOReader::operator()(void* ptr, size_t size) {
    // Reads up to `size` bytes into ptr and returns the number of bytes actually read.
    // Previously this always returned `size`, hiding short reads at EOF or on a stream
    // that failed to open; a successful full read still returns `size`.
    fs.read(static_cast<char*>(ptr), static_cast<std::streamsize>(size));
    return static_cast<size_t>(fs.gcount());
}

S
starlord 已提交
107
FileIOWriter::FileIOWriter(const std::string& fname) {
    // Remember the path and open it for binary writing; open failures are not
    // reported here and surface only as a failed stream state.
    name = fname;
    fs.open(name, std::ios::out | std::ios::binary);
}

FileIOWriter::~FileIOWriter() {
    // Flush and release the underlying file handle explicitly.
    if (fs.is_open()) {
        fs.close();
    }
}

S
starlord 已提交
116 117 118
size_t
FileIOWriter::operator()(void* ptr, size_t size) {
    // Writes `size` bytes from ptr. Returns `size` on success and 0 if the stream is in a
    // failed state afterwards (e.g. the file could not be opened); previously a failed
    // write was still reported as fully successful.
    fs.write(static_cast<const char*>(ptr), static_cast<std::streamsize>(size));
    return fs.fail() ? 0 : size;
}

S
starlord 已提交
122
knowhere::DatasetPtr
generate_dataset(int64_t nb, int64_t dim, float* xb, int64_t* ids) {
    // Builds a dataset that pairs an nb x dim float tensor (field "data") built from xb with
    // an nb-long int64 id array (field "id") built from ids. Whether knowhere::Construct*
    // copies or references these buffers is not visible here; presumably callers must keep
    // xb/ids alive while the dataset is in use -- TODO confirm.
    std::vector<int64_t> shape{nb, dim};
    auto tensor = knowhere::ConstructFloatTensor(reinterpret_cast<uint8_t*>(xb), nb * dim * sizeof(float), shape);
    std::vector<knowhere::TensorPtr> tensors{tensor};
    std::vector<knowhere::FieldPtr> tensor_fields{knowhere::ConstructFloatField("data")};
    auto tensor_schema = std::make_shared<knowhere::Schema>(tensor_fields);

    auto id_array = knowhere::ConstructInt64Array(reinterpret_cast<uint8_t*>(ids), nb * sizeof(int64_t));
    std::vector<knowhere::ArrayPtr> arrays{id_array};
    std::vector<knowhere::FieldPtr> array_fields{knowhere::ConstructInt64Field("id")};
    // The id array is described by its own "id" field. This was previously built from
    // tensor_fields, which mislabelled the ids as a float "data" column and left
    // array_fields unused.
    auto array_schema = std::make_shared<knowhere::Schema>(array_fields);

    auto dataset =
        std::make_shared<knowhere::Dataset>(std::move(arrays), array_schema, std::move(tensors), tensor_schema);
    return dataset;
}

S
starlord 已提交
140
knowhere::DatasetPtr
generate_query_dataset(int64_t nb, int64_t dim, float* xb) {
    // Wraps an nb x dim float buffer (here nb is the number of query vectors) as a
    // tensor-only dataset with a single "data" field; no id array is attached.
    std::vector<int64_t> query_shape{nb, dim};
    auto query_tensor = knowhere::ConstructFloatTensor((uint8_t*)xb, nb * dim * sizeof(float), query_shape);
    std::vector<knowhere::TensorPtr> query_tensors{query_tensor};

    std::vector<knowhere::FieldPtr> query_fields{knowhere::ConstructFloatField("data")};
    auto query_schema = std::make_shared<knowhere::Schema>(query_fields);

    return std::make_shared<knowhere::Dataset>(std::move(query_tensors), query_schema);
}
X
xiaojun.lin 已提交
151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177

void
AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k) {
    // Checks that, for every query q, the first of its k result ids (slot q * k of the
    // flattened id array) equals q. This holds when query q is a copy of base vector q,
    // which is how GenAll/GenQuery build queries.
    auto id_array = result->array()[0];
    int query = 0;
    while (query < nq) {
        auto top_hit = id_array->data()->GetValues<int64_t>(1, query * k);
        EXPECT_EQ(query, *top_hit);
        ++query;
    }
}

void
PrintResult(const knowhere::DatasetPtr& result, const int& nq, const int& k) {
    auto ids = result->array()[0];
    auto dists = result->array()[1];

    std::stringstream ss_id;
    std::stringstream ss_dist;
    for (auto i = 0; i < 10; i++) {
        for (auto j = 0; j < k; ++j) {
            ss_id << *(ids->data()->GetValues<int64_t>(1, i * k + j)) << " ";
            ss_dist << *(dists->data()->GetValues<float>(1, i * k + j)) << " ";
        }
        ss_id << std::endl;
        ss_dist << std::endl;
    }
    std::cout << "id\n" << ss_id.str() << std::endl;
    std::cout << "dist\n" << ss_dist.str() << std::endl;
}