提交 b43b9bac 编写于 作者: Y yudong.cai

#89 Call normalize() on the data set for the IP (inner product) test


Former-commit-id: 11be41ee43e3dd0e2ecbcd50ba70bab2df58688d
上级 2f762900
......@@ -54,6 +54,20 @@ elapsed() {
return tv.tv_sec + tv.tv_usec * 1e-6;
}
// L2-normalizes, in place, each of the `nq` consecutive `dim`-element rows
// stored in `arr` (row i occupies arr[i * dim] .. arr[i * dim + dim - 1]).
//
// @param arr  buffer of nq * dim floats; overwritten with the scaled values
// @param nq   number of rows (vectors) in `arr`
// @param dim  number of components per row
//
// The squared sum and the division are carried out in double precision and
// the result is narrowed back to float. A row whose norm is exactly zero is
// left untouched: dividing by zero would otherwise fill it with NaN.
void normalize(float* arr, size_t nq, size_t dim) {
    for (size_t i = 0; i < nq; i++) {
        float* row = arr + i * dim;

        double sq_sum = 0.0;
        for (size_t j = 0; j < dim; j++) {
            const double val = row[j];
            sq_sum += val * val;
        }

        const double vec_len = std::sqrt(sq_sum);
        if (vec_len == 0.0) {
            // All-zero row: there is no direction to normalize to, keep it as is.
            continue;
        }

        for (size_t j = 0; j < dim; j++) {
            row[j] = static_cast<float>(row[j] / vec_len);
        }
    }
}
void*
hdf5_read(const char* file_name, const char* dataset_name, H5T_class_t dataset_class, size_t& d_out, size_t& n_out) {
hid_t file, dataset, datatype, dataspace, memspace;
......@@ -237,6 +251,11 @@ test_ann_hdf5(const std::string& ann_test_name, const std::string& index_key, in
float* xb = (float*)hdf5_read(ann_file_name.c_str(), "train", H5T_FLOAT, d, nb);
assert(d == dim || !"dataset does not have correct dimension");
if (metric_type == faiss::METRIC_INNER_PRODUCT) {
printf("[%.3f s] Normalizing data set \n", elapsed() - t0);
normalize(xb, nb, d);
}
printf("[%.3f s] Preparing index \"%s\" d=%ld\n", elapsed() - t0, index_key.c_str(), d);
index = faiss::index_factory(d, index_key.c_str(), metric_type);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册