utils.cpp 4.6 KB
Newer Older
J
jinhai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18
#include "unittest/utils.h"
X
xj.lin 已提交
19

S
starlord 已提交
20 21 22
#include <memory>
#include <string>
#include <utility>
X
xj.lin 已提交
23

H
Heisenberg 已提交
24 25
// Expands easylogging++'s global storage definitions for this binary.
// NOTE(review): the easylogging++ docs say this macro should be expanded
// exactly once per executable — presumably this file is that place for the
// unit-test binary; confirm no other TU linked into it also expands it.
INITIALIZE_EASYLOGGINGPP

S
starlord 已提交
26 27
namespace {

S
starlord 已提交
28
namespace kn = knowhere;
S
starlord 已提交
29 30 31 32 33

}  // namespace

void
InitLog() {
H
Heisenberg 已提交
34 35
    el::Configurations defaultConf;
    defaultConf.setToDefault();
S
starlord 已提交
36
    defaultConf.set(el::Level::Debug, el::ConfigurationType::Format, "[%thread-%datetime-%level]: %msg (%fbase:%line)");
H
Heisenberg 已提交
37 38
    el::Loggers::reconfigureLogger("default", defaultConf);
}
X
xj.lin 已提交
39

S
starlord 已提交
40 41
void
DataGen::Init_with_default() {
X
xj.lin 已提交
42 43 44
    Generate(dim, nb, nq);
}

S
starlord 已提交
45 46
void
DataGen::Generate(const int& dim, const int& nb, const int& nq) {
X
xj.lin 已提交
47 48 49 50 51
    this->nb = nb;
    this->nq = nq;
    this->dim = dim;

    GenAll(dim, nb, xb, ids, nq, xq);
S
starlord 已提交
52 53
    assert(xb.size() == (size_t)dim * nb);
    assert(xq.size() == (size_t)dim * nq);
X
xj.lin 已提交
54 55 56 57

    base_dataset = generate_dataset(nb, dim, xb.data(), ids.data());
    query_dataset = generate_query_dataset(nq, dim, xq.data());
}
S
starlord 已提交
58

S
starlord 已提交
59
/// Builds a query dataset of @p nq vectors by copying the first @p nq base
/// vectors (the first nq * dim floats of xb) into xq.
///
/// @param nq  number of query vectors; must not exceed the number of base
///            vectors currently held in xb.
/// @return a tensor-only dataset wrapping xq.
knowhere::DatasetPtr
DataGen::GenQuery(const int& nq) {
    // Compute the element count in size_t: nq * dim in int arithmetic can overflow.
    const size_t count = static_cast<size_t>(nq) * dim;
    // Queries are copied out of xb; reading beyond xb.size() would be undefined behavior.
    assert(xb.size() >= count);

    xq.resize(count);
    for (size_t i = 0; i < count; ++i) {
        xq[i] = xb[i];
    }
    return generate_query_dataset(nq, dim, xq.data());
}

S
starlord 已提交
68 69 70
void
GenAll(const int64_t dim, const int64_t& nb, std::vector<float>& xb, std::vector<int64_t>& ids, const int64_t& nq,
       std::vector<float>& xq) {
X
xj.lin 已提交
71 72 73 74 75 76
    xb.resize(nb * dim);
    xq.resize(nq * dim);
    ids.resize(nb);
    GenAll(dim, nb, xb.data(), ids.data(), nq, xq.data());
}

S
starlord 已提交
77 78
void
GenAll(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids, const int64_t& nq, float* xq) {
X
xj.lin 已提交
79
    GenBase(dim, nb, xb, ids);
S
starlord 已提交
80
    for (int64_t i = 0; i < nq * dim; ++i) {
X
xj.lin 已提交
81 82 83 84
        xq[i] = xb[i];
    }
}

S
starlord 已提交
85 86
/// Fills @p xb with @p nb row-major vectors of @p dim floats drawn from
/// drand48(), and sets ids[i] = i.
///
/// The first component of vector i is additionally shifted by i / 1000
/// (presumably to make vectors distinguishable — confirm intent).
///
/// @param dim  number of floats per vector
/// @param nb   number of vectors
/// @param xb   output buffer of at least nb * dim floats
/// @param ids  output buffer of at least nb ids
void
GenBase(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids) {
    // 64-bit counters: the bounds are int64_t, and int counters would overflow for huge nb.
    for (int64_t i = 0; i < nb; ++i) {
        float* const row = xb + i * dim;
        for (int64_t j = 0; j < dim; ++j) {
            row[j] = static_cast<float>(drand48());
        }
        // With dim == 0 the row is empty; touching row[0] would write out of bounds.
        if (dim > 0) {
            row[0] += i / 1000.;
        }
        ids[i] = i;
    }
}

S
starlord 已提交
97
/// Remembers @p fname and opens it for binary input.
/// No error is raised here if the file cannot be opened; the stream is simply
/// left in a failed state.
FileIOReader::FileIOReader(const std::string& fname) {
    name = fname;
    fs.open(name, std::ios::in | std::ios::binary);
}

/// Closes the input file stream if it is still open.
FileIOReader::~FileIOReader() {
    if (fs.is_open()) {
        fs.close();
    }
}

S
starlord 已提交
106 107 108
/// Reads up to @p size bytes from the file into @p ptr.
///
/// @return the number of bytes actually read. This is less than @p size on
///         end-of-file or a read error, so callers can detect short reads
///         (returning @p size unconditionally would hide them).
size_t
FileIOReader::operator()(void* ptr, size_t size) {
    fs.read(static_cast<char*>(ptr), static_cast<std::streamsize>(size));
    return static_cast<size_t>(fs.gcount());
}

S
starlord 已提交
112
/// Remembers @p fname and opens it for binary output.
/// No error is raised here if the file cannot be opened; the stream is simply
/// left in a failed state.
FileIOWriter::FileIOWriter(const std::string& fname) {
    name = fname;
    fs.open(name, std::ios::out | std::ios::binary);
}

/// Closes the output file stream if it is still open.
FileIOWriter::~FileIOWriter() {
    if (fs.is_open()) {
        fs.close();
    }
}

S
starlord 已提交
121 122 123
/// Writes @p size bytes from @p ptr to the file.
///
/// @return @p size if the stream is still in a good state after the write,
///         0 otherwise, so callers can detect failed writes instead of
///         always seeing success.
size_t
FileIOWriter::operator()(void* ptr, size_t size) {
    fs.write(static_cast<const char*>(ptr), static_cast<std::streamsize>(size));
    return fs ? size : 0;
}

S
starlord 已提交
127 128
/// Wraps @p nb base vectors of dimension @p dim and their ids into a knowhere
/// Dataset.
///
/// The tensor part has shape {nb, dim} and a single float field "data"; the
/// array part holds the nb int64 ids under a single field "id".
/// NOTE(review): raw pointers are handed to kn::ConstructFloatTensor /
/// kn::ConstructInt64Array; whether they copy or borrow is not visible here —
/// presumably they borrow, so xb/ids must outlive the dataset (confirm).
kn::DatasetPtr
generate_dataset(int64_t nb, int64_t dim, float* xb, int64_t* ids) {
    // Tensor part: the nb x dim float matrix.
    std::vector<int64_t> shape{nb, dim};
    auto tensor = kn::ConstructFloatTensor(reinterpret_cast<uint8_t*>(xb), nb * dim * sizeof(float), shape);
    std::vector<kn::TensorPtr> tensors{tensor};
    std::vector<kn::FieldPtr> tensor_fields{kn::ConstructFloatField("data")};
    auto tensor_schema = std::make_shared<kn::Schema>(tensor_fields);

    // Array part: the nb ids.
    auto id_array = kn::ConstructInt64Array(reinterpret_cast<uint8_t*>(ids), nb * sizeof(int64_t));
    std::vector<kn::ArrayPtr> arrays{id_array};
    std::vector<kn::FieldPtr> array_fields{kn::ConstructInt64Field("id")};
    // The array schema must describe the id array ("id"), not the float tensor ("data").
    auto array_schema = std::make_shared<kn::Schema>(array_fields);

    auto dataset = std::make_shared<kn::Dataset>(std::move(arrays), array_schema, std::move(tensors), tensor_schema);
    return dataset;
}

S
starlord 已提交
144 145
/// Wraps @p nb query vectors of dimension @p dim (a contiguous float buffer
/// at @p xb) into a tensor-only knowhere Dataset of shape {nb, dim} with a
/// single float field named "data".
kn::DatasetPtr
generate_query_dataset(int64_t nb, int64_t dim, float* xb) {
    const int64_t elem_count = nb * dim;
    auto* const raw_bytes = reinterpret_cast<uint8_t*>(xb);

    std::vector<int64_t> shape{nb, dim};
    auto tensor = kn::ConstructFloatTensor(raw_bytes, elem_count * sizeof(float), shape);
    std::vector<kn::TensorPtr> tensors{tensor};

    std::vector<kn::FieldPtr> fields{kn::ConstructFloatField("data")};
    auto schema = std::make_shared<kn::Schema>(fields);

    return std::make_shared<kn::Dataset>(std::move(tensors), schema);
}