utils.cpp 5.7 KB
Newer Older
J
jinhai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18
#include "unittest/utils.h"
X
xj.lin 已提交
19

X
xiaojun.lin 已提交
20
#include <gtest/gtest.h>
S
starlord 已提交
21 22 23
#include <memory>
#include <string>
#include <utility>
X
xj.lin 已提交
24

H
Heisenberg 已提交
25 26
// easylogging++ initialization macro; InitLog() in this file reconfigures the "default" logger.
// NOTE(review): easylogging++ presumably requires this macro in exactly one translation unit of the
// test binary — confirm no other TU expands it.
INITIALIZE_EASYLOGGINGPP

S
starlord 已提交
27 28
void
InitLog() {
H
Heisenberg 已提交
29 30
    el::Configurations defaultConf;
    defaultConf.setToDefault();
S
starlord 已提交
31
    defaultConf.set(el::Level::Debug, el::ConfigurationType::Format, "[%thread-%datetime-%level]: %msg (%fbase:%line)");
H
Heisenberg 已提交
32 33
    el::Loggers::reconfigureLogger("default", defaultConf);
}
X
xj.lin 已提交
34

S
starlord 已提交
35 36
void
DataGen::Init_with_default() {
X
xj.lin 已提交
37 38 39
    Generate(dim, nb, nq);
}

S
starlord 已提交
40 41
void
DataGen::Generate(const int& dim, const int& nb, const int& nq) {
X
xj.lin 已提交
42 43 44 45 46
    this->nb = nb;
    this->nq = nq;
    this->dim = dim;

    GenAll(dim, nb, xb, ids, nq, xq);
S
starlord 已提交
47 48
    assert(xb.size() == (size_t)dim * nb);
    assert(xq.size() == (size_t)dim * nq);
X
xj.lin 已提交
49 50 51 52

    base_dataset = generate_dataset(nb, dim, xb.data(), ids.data());
    query_dataset = generate_query_dataset(nq, dim, xq.data());
}
S
starlord 已提交
53

S
starlord 已提交
54
// Builds a query dataset of nq vectors copied from the start of the base data xb, so query i
// is identical to base row i (AssertAnns in this file checks that search returns id i first).
// When more query values are requested than xb holds, the copy wraps around to the start of
// xb instead of reading past its end; with an empty xb the queries are zero-filled.
//
// NOTE(review): this resizes the member xq. If query_dataset (built in Generate from
// xq.data()) still references that buffer, the resize may invalidate it — confirm whether
// ConstructFloatTensor copies its input.
knowhere::DatasetPtr
DataGen::GenQuery(const int& nq) {
    // Count in size_t: `nq * dim` evaluated as int can overflow for large inputs.
    const size_t query_len = static_cast<size_t>(nq) * static_cast<size_t>(dim);
    const size_t base_len = xb.size();

    xq.resize(query_len);
    for (size_t i = 0; i < query_len; ++i) {
        // i % base_len == i whenever nq <= nb, preserving the original row-for-row copy.
        xq[i] = base_len > 0 ? xb[i % base_len] : 0.0f;
    }
    return generate_query_dataset(nq, dim, xq.data());
}

S
starlord 已提交
63 64 65
void
GenAll(const int64_t dim, const int64_t& nb, std::vector<float>& xb, std::vector<int64_t>& ids, const int64_t& nq,
       std::vector<float>& xq) {
X
xj.lin 已提交
66 67 68 69 70 71
    xb.resize(nb * dim);
    xq.resize(nq * dim);
    ids.resize(nb);
    GenAll(dim, nb, xb.data(), ids.data(), nq, xq.data());
}

S
starlord 已提交
72 73
void
GenAll(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids, const int64_t& nq, float* xq) {
X
xj.lin 已提交
74
    GenBase(dim, nb, xb, ids);
S
starlord 已提交
75
    for (int64_t i = 0; i < nq * dim; ++i) {
X
xj.lin 已提交
76 77 78 79
        xq[i] = xb[i];
    }
}

S
starlord 已提交
80 81
// Fills nb row-major vectors of width dim into xb with drand48() values, and sets ids[i] = i.
// The first component of row i is then offset by i / 1000 so rows are distinguishable.
// xb must hold nb*dim floats and ids nb entries.
void
GenBase(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids) {
    // int64_t indices: `auto i = 0` deduced int and could overflow for nb > INT_MAX.
    for (int64_t i = 0; i < nb; ++i) {
        float* row = xb + i * dim;
        for (int64_t j = 0; j < dim; ++j) {
            row[j] = static_cast<float>(drand48());
        }
        // With dim == 0 the buffer has no row storage; row[0] would be out of bounds.
        if (dim > 0) {
            row[0] = static_cast<float>(row[0] + static_cast<double>(i) / 1000.0);
        }
        ids[i] = i;
    }
}

S
starlord 已提交
92
// Remembers the path and opens it for binary input.
FileIOReader::FileIOReader(const std::string& fname) {
    name = fname;
    fs.open(name, std::ios::in | std::ios::binary);
}

// Explicitly closes the underlying file stream.
FileIOReader::~FileIOReader() {
    this->fs.close();
}

S
starlord 已提交
101 102 103
// Reads up to `size` bytes from the file into `ptr`.
// Returns the number of bytes actually read, which is less than `size` at end of file or on a
// stream error. Previously the requested size was echoed back, hiding short reads.
size_t
FileIOReader::operator()(void* ptr, size_t size) {
    fs.read(static_cast<char*>(ptr), static_cast<std::streamsize>(size));
    return static_cast<size_t>(fs.gcount());
}

S
starlord 已提交
107
// Remembers the path and opens it for binary output.
FileIOWriter::FileIOWriter(const std::string& fname) {
    name = fname;
    fs.open(name, std::ios::out | std::ios::binary);
}

// Explicitly closes the underlying file stream (flushing any buffered output).
FileIOWriter::~FileIOWriter() {
    this->fs.close();
}

S
starlord 已提交
116 117 118
// Writes `size` bytes from `ptr` to the file.
// Returns `size` when the stream is still in a good state afterwards, or 0 if the write
// failed. Previously `size` was returned unconditionally, hiding write errors.
size_t
FileIOWriter::operator()(void* ptr, size_t size) {
    fs.write(static_cast<const char*>(ptr), static_cast<std::streamsize>(size));
    return fs.fail() ? 0 : size;
}

S
starlord 已提交
122
// Wraps nb row-major float vectors of width dim (at xb) and their nb int64 ids (at ids) in a
// knowhere Dataset: one float tensor described by a "data" field, and one int64 array
// described by an "id" field. The raw buffers are passed to the knowhere constructors as bytes.
knowhere::DatasetPtr
generate_dataset(int64_t nb, int64_t dim, float* xb, int64_t* ids) {
    std::vector<int64_t> shape{nb, dim};
    auto tensor = knowhere::ConstructFloatTensor(reinterpret_cast<uint8_t*>(xb), nb * dim * sizeof(float), shape);
    std::vector<knowhere::TensorPtr> tensors{tensor};
    std::vector<knowhere::FieldPtr> tensor_fields{knowhere::ConstructFloatField("data")};
    auto tensor_schema = std::make_shared<knowhere::Schema>(tensor_fields);

    auto id_array = knowhere::ConstructInt64Array(reinterpret_cast<uint8_t*>(ids), nb * sizeof(int64_t));
    std::vector<knowhere::ArrayPtr> arrays{id_array};
    std::vector<knowhere::FieldPtr> array_fields{knowhere::ConstructInt64Field("id")};
    // The id array's schema must describe the int64 "id" field. It was previously built from
    // tensor_fields (the float "data" field), leaving array_fields unused.
    auto array_schema = std::make_shared<knowhere::Schema>(array_fields);

    return std::make_shared<knowhere::Dataset>(std::move(arrays), array_schema, std::move(tensors), tensor_schema);
}

S
starlord 已提交
140
// Wraps nb row-major float query vectors of width dim (at xb) in a tensor-only knowhere
// Dataset whose single field is named "data". The buffer is passed to the tensor as raw bytes.
knowhere::DatasetPtr
generate_query_dataset(int64_t nb, int64_t dim, float* xb) {
    std::vector<int64_t> shape{nb, dim};
    auto* raw_bytes = reinterpret_cast<uint8_t*>(xb);
    std::vector<knowhere::TensorPtr> tensors{knowhere::ConstructFloatTensor(raw_bytes, nb * dim * sizeof(float), shape)};

    std::vector<knowhere::FieldPtr> fields{knowhere::ConstructFloatField("data")};
    auto schema = std::make_shared<knowhere::Schema>(fields);

    return std::make_shared<knowhere::Dataset>(std::move(tensors), schema);
}
X
xiaojun.lin 已提交
151 152 153

// For each of the nq queries, expects (gtest EXPECT_EQ) that the first of its k result ids is
// the query's own index. result->ids() is read as a row-major nq x k block of int64 ids.
void
AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k) {
    auto id_data = reinterpret_cast<const int64_t*>(result->ids());
    for (int q = 0; q < nq; ++q) {
        EXPECT_EQ(q, id_data[q * k]);
    }
}

// Prints the nq x k result ids and distances to stdout, one query per line, under "id" and
// "dist" headings. ids() is read as int64 values and dist() as float values, both row-major.
void
PrintResult(const knowhere::DatasetPtr& result, const int& nq, const int& k) {
    auto id_data = reinterpret_cast<const int64_t*>(result->ids());
    auto dist_data = reinterpret_cast<const float*>(result->dist());

    std::stringstream id_lines;
    std::stringstream dist_lines;
    for (int q = 0; q < nq; ++q) {
        for (int j = 0; j < k; ++j) {
            const int offset = q * k + j;
            id_lines << id_data[offset] << " ";
            dist_lines << dist_data[offset] << " ";
        }
        // '\n' instead of std::endl: flushing a stringstream has no observable effect.
        id_lines << '\n';
        dist_lines << '\n';
    }
    std::cout << "id\n" << id_lines.str() << std::endl;
    std::cout << "dist\n" << dist_lines.str() << std::endl;
}