VecServiceTask.cpp 24.3 KB
Newer Older
G
groot 已提交
1 2 3 4 5 6 7 8 9 10
/*******************************************************************************
 * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited.
 * Proprietary and confidential.
 ******************************************************************************/
#include "VecServiceTask.h"
#include "ServerConfig.h"
#include "VecIdMapper.h"
#include "utils/CommonUtil.h"
#include "utils/Log.h"
G
groot 已提交
11
#include "utils/TimeRecorder.h"
G
groot 已提交
12
#include "utils/ThreadPool.h"
G
groot 已提交
13 14
#include "db/DB.h"
#include "db/Env.h"
G
groot 已提交
15
#include "db/Meta.h"
G
groot 已提交
16

G
groot 已提交
17

G
groot 已提交
18 19 20 21
namespace zilliz {
namespace vecwise {
namespace server {

G
groot 已提交
22 23
// Task-group names handed to BaseTask by the task constructors in this file:
// search tasks use the "dql" group, all other tasks use the "ddl_dml" group.
static const std::string DQL_TASK_GROUP = "dql";
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";

// Attribute-map key under which the user-facing vector id is stored.
static const std::string VECTOR_UID = "uid";

// Batch size threshold: batches with at least this many vectors build their
// id mapping on the thread pool, in chunks of this many vectors each.
static const uint64_t USE_MT = 5000;
G
groot 已提交
27

G
groot 已提交
28 29 30
// Short aliases for the engine meta types used in this file:
// DB_META provides GetDate(), DB_DATE is the date value passed to DB::search().
using DB_META = zilliz::vecwise::engine::meta::Meta;
using DB_DATE = zilliz::vecwise::engine::meta::DateT;

G
groot 已提交
31 32 33 34 35
namespace {
    class DBWrapper {
    public:
        DBWrapper() {
            zilliz::vecwise::engine::Options opt;
G
groot 已提交
36 37 38 39
            ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_DB);
            opt.meta.backend_uri = config.GetValue(CONFIG_DB_URL);
            std::string db_path = config.GetValue(CONFIG_DB_PATH);
            opt.memory_sync_interval = (uint16_t)config.GetInt32Value(CONFIG_DB_FLUSH_INTERVAL, 10);
G
groot 已提交
40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
            opt.meta.path = db_path + "/db";

            CommonUtil::CreateDirectory(opt.meta.path);

            zilliz::vecwise::engine::DB::Open(opt, &db_);
            if(db_ == nullptr) {
                SERVER_LOG_ERROR << "Failed to open db";
                throw ServerException(SERVER_NULL_POINTER, "Failed to open db");
            }
        }

        zilliz::vecwise::engine::DB* DB() { return db_; }

    private:
        zilliz::vecwise::engine::DB* db_ = nullptr;
    };

    zilliz::vecwise::engine::DB* DB() {
        static DBWrapper db_wrapper;
        return db_wrapper.DB();
    }
G
groot 已提交
61 62 63 64 65 66

    DB_DATE MakeDbDate(const VecDateTime& dt) {
        time_t  t_t;
        CommonUtil::ConvertTime(dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second, t_t);
        return DB_META::GetDate(t_t);
    }
G
groot 已提交
67 68 69 70 71

    ThreadPool& GetThreadPool() {
        static ThreadPool pool(6);
        return pool;
    }
G
groot 已提交
72 73 74 75 76
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds a task that creates group `group_id` with the given vector dimension.
// The task is registered under DDL_DML_TASK_GROUP.
AddGroupTask::AddGroupTask(int32_t dimension,
                           const std::string& group_id)
: BaseTask(DDL_DML_TASK_GROUP),
  dimension_(dimension),
  group_id_(group_id) {

}

// Factory: allocates an AddGroupTask and returns it as a shared BaseTask pointer.
BaseTaskPtr AddGroupTask::Create(int32_t dimension,
                                 const std::string& group_id) {
    std::shared_ptr<BaseTask> task(new AddGroupTask(dimension, group_id));
    return task;
}

// Asks the engine to create the group.
// Returns SERVER_SUCCESS, or SERVER_UNEXPECTED_ERROR (with error_code_ and
// error_msg_ filled in) when the engine reports failure or an exception escapes.
ServerError AddGroupTask::OnExecute() {
    try {
        engine::meta::GroupSchema group_info;
        group_info.dimension = static_cast<size_t>(dimension_);
        group_info.group_id = group_id_;

        engine::Status stat = DB()->add_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return SERVER_UNEXPECTED_ERROR;
    }

    return SERVER_SUCCESS;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds a task that looks up group `group_id`; the group's dimension is
// reported through `dimension` when the task executes.
// The task is registered under DDL_DML_TASK_GROUP.
GetGroupTask::GetGroupTask(const std::string& group_id, int32_t&  dimension)
    : BaseTask(DDL_DML_TASK_GROUP),
      group_id_(group_id),
      dimension_(dimension) {

}

// Factory: allocates a GetGroupTask and returns it as a shared BaseTask pointer.
BaseTaskPtr GetGroupTask::Create(const std::string& group_id, int32_t&  dimension) {
    std::shared_ptr<BaseTask> task(new GetGroupTask(group_id, dimension));
    return task;
}

// Looks the group up in the engine and stores its dimension in dimension_.
// dimension_ is reset to 0 first, so it stays 0 on any failure.
// Returns SERVER_GROUP_NOT_EXIST when the engine lookup fails,
// SERVER_UNEXPECTED_ERROR when an exception escapes, SERVER_SUCCESS otherwise.
ServerError GetGroupTask::OnExecute() {
    try {
        dimension_ = 0;

        engine::meta::GroupSchema group_info;
        group_info.group_id = group_id_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        dimension_ = static_cast<int32_t>(group_info.dimension);

    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return SERVER_UNEXPECTED_ERROR;
    }

    return SERVER_SUCCESS;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds a task that deletes group `group_id`.
// The task is registered under DDL_DML_TASK_GROUP.
DeleteGroupTask::DeleteGroupTask(const std::string& group_id)
    : BaseTask(DDL_DML_TASK_GROUP),
      group_id_(group_id) {

}

// Factory: allocates a DeleteGroupTask and returns it as a shared BaseTask pointer.
BaseTaskPtr DeleteGroupTask::Create(const std::string& group_id) {
    std::shared_ptr<BaseTask> task(new DeleteGroupTask(group_id));
    return task;
}

// Group deletion is not supported yet: records and logs the error and
// always returns SERVER_NOT_IMPLEMENT.
ServerError DeleteGroupTask::OnExecute() {
    error_code_ = SERVER_NOT_IMPLEMENT;
    error_msg_ = "delete group not implemented";
    SERVER_LOG_ERROR << error_msg_;
    return SERVER_NOT_IMPLEMENT;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
G
groot 已提交
168
// Builds an add-vector task for a floating-point tensor; the binary tensor
// pointer is left null. The id assigned during execution is stored in
// tensor_id_, which is initialized from `id`.
// NOTE(review): tensor_id_ is presumably a reference member so the caller sees
// the result - confirm against the header.
AddVectorTask::AddVectorTask(const std::string& group_id,
                             const VecTensor* tensor,
                             std::string& id)
    : BaseTask(DDL_DML_TASK_GROUP),
      group_id_(group_id),
      tensor_(tensor),
      bin_tensor_(nullptr),
      tensor_id_(id) {

}

G
groot 已提交
179
// Factory for the floating-point tensor variant: allocates an AddVectorTask
// and returns it as a shared BaseTask pointer.
BaseTaskPtr AddVectorTask::Create(const std::string& group_id,
                                  const VecTensor* tensor,
                                  std::string& id) {
    return std::shared_ptr<BaseTask>(new AddVectorTask(group_id, tensor, id));
}

G
groot 已提交
185
// Builds an add-vector task for a binary-encoded tensor; the floating-point
// tensor pointer is left null. The id assigned during execution is stored in
// tensor_id_, which is initialized from `id`.
AddVectorTask::AddVectorTask(const std::string& group_id,
                             const VecBinaryTensor* tensor,
                             std::string& id)
        : BaseTask(DDL_DML_TASK_GROUP),
          group_id_(group_id),
          tensor_(nullptr),
          bin_tensor_(tensor),
          tensor_id_(id) {

}

// Factory for the binary tensor variant: allocates an AddVectorTask and
// returns it as a shared BaseTask pointer.
BaseTaskPtr AddVectorTask::Create(const std::string& group_id,
                                  const VecBinaryTensor* tensor,
                                  std::string& id) {
    return std::shared_ptr<BaseTask>(new AddVectorTask(group_id, tensor, id));
}


// Number of elements in the input vector: the element count of the
// floating-point tensor, or the byte length of the binary tensor divided by 8.
// Returns 0 when neither tensor is set.
uint64_t AddVectorTask::GetVecDimension() const {
    if(tensor_) {
        return (uint64_t) tensor_->tensor.size();
    }
    if(bin_tensor_) {
        return (uint64_t) bin_tensor_->tensor.size()/8;
    }
    return 0;
}

// Pointer to the input vector's storage viewed as doubles; the binary
// tensor's buffer is reinterpreted in place. Returns nullptr when neither
// tensor is set.
const double* AddVectorTask::GetVecData() const {
    if(tensor_) {
        return reinterpret_cast<const double*>(tensor_->tensor.data());
    }
    if(bin_tensor_) {
        return reinterpret_cast<const double*>(bin_tensor_->tensor.data());
    }
    return nullptr;
}

// User-supplied id of the input vector, or an empty string when neither
// tensor is set.
std::string AddVectorTask::GetVecID() const {
    if(tensor_) {
        return tensor_->uid;
    }
    if(bin_tensor_) {
        return bin_tensor_->uid;
    }
    return "";
}

G
groot 已提交
234 235 236 237 238 239 240 241
// Attribute map of the input vector. When neither tensor is set, a shared
// empty map is returned instead of dereferencing a null pointer.
const AttribMap& AddVectorTask::GetVecAttrib() const {
    static const AttribMap kEmptyAttrib{};
    if(tensor_) {
        return tensor_->attrib;
    }
    if(bin_tensor_) {
        return bin_tensor_->attrib;
    }
    return kEmptyAttrib;
}

G
groot 已提交
242
// Adds the single input vector to the group and records its id mapping.
//
// Steps: look up the group, check the vector dimension against the group
// dimension, narrow the doubles to floats, add the vector to the engine, then
// store the vector's attributes (plus its uid) in the id mapper under the key
// "<group_id>_<numeric id>". tensor_id_ receives the user-supplied uid, or the
// engine's numeric id when no uid was given.
//
// Returns SERVER_SUCCESS, SERVER_GROUP_NOT_EXIST, SERVER_INVALID_VECTOR_DIMENSION
// or SERVER_UNEXPECTED_ERROR; on every failure error_code_ and error_msg_ are set.
ServerError AddVectorTask::OnExecute() {
    try {
        engine::meta::GroupSchema group_info;
        group_info.group_id = group_id_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        uint64_t group_dim = group_info.dimension;
        uint64_t vec_dim = GetVecDimension();
        if(group_dim != vec_dim) {
            // The engine status is OK at this point, so report the actual
            // mismatch instead of an "Engine failed" message.
            error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
            error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                         + " vs. group dimension:" + std::to_string(group_dim);
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        // The engine takes floats; the request carries doubles.
        std::vector<float> vec_f;
        vec_f.resize(vec_dim);
        const double* d_p = GetVecData();
        for(uint64_t d = 0; d < vec_dim; d++) {
            vec_f[d] = static_cast<float>(d_p[d]);
        }

        engine::IDNumbers vector_ids;
        stat = DB()->add_vectors(group_id_, 1, vec_f.data(), vector_ids);
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        if(vector_ids.empty()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Vector ID not returned";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        std::string uid = GetVecID();
        std::string num_id = std::to_string(vector_ids[0]);
        if(uid.empty()) {
            tensor_id_ = num_id;
        } else {
            tensor_id_ = uid;
        }

        // Persist the attributes, including the uid, keyed by group and numeric id.
        std::string nid = group_id_ + "_" + num_id;
        AttribMap attrib = GetVecAttrib();
        attrib[VECTOR_UID] = tensor_id_;
        std::string attrib_str;
        AttributeSerializer::Encode(attrib, attrib_str);
        IVecIdMapper::GetInstance()->Put(nid, attrib_str);
        SERVER_LOG_TRACE << "nid = " << vector_ids[0] << ", uid = " << uid;

    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}


////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds a batch add task for floating-point tensors; the binary list pointer
// is left null. The output id list is emptied and then sized to the number of
// input tensors. A null `tensor_list` is tolerated (the id list stays empty),
// matching the null checks in the accessors.
AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id,
                                       const VecTensorList* tensor_list,
                                       std::vector<std::string>& ids)
    : BaseTask(DDL_DML_TASK_GROUP),
      group_id_(group_id),
      tensor_list_(tensor_list),
      bin_tensor_list_(nullptr),
      tensor_ids_(ids) {
    tensor_ids_.clear();
    if(tensor_list_ != nullptr) {
        tensor_ids_.resize(tensor_list_->tensor_list.size());
    }
}

// Factory for the floating-point list variant: allocates an
// AddBatchVectorTask and returns it as a shared BaseTask pointer.
BaseTaskPtr AddBatchVectorTask::Create(const std::string& group_id,
                                       const VecTensorList* tensor_list,
                                       std::vector<std::string>& ids) {
    return std::shared_ptr<BaseTask>(new AddBatchVectorTask(group_id, tensor_list, ids));
}

// Builds a batch add task for binary-encoded tensors; the floating-point list
// pointer is left null. The output id list is emptied here and sized later,
// when the task executes.
AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id,
                                       const VecBinaryTensorList* tensor_list,
                                       std::vector<std::string>& ids)
    : BaseTask(DDL_DML_TASK_GROUP),
      group_id_(group_id),
      tensor_list_(nullptr),
      bin_tensor_list_(tensor_list),
      tensor_ids_(ids) {
    tensor_ids_.clear();
}

G
groot 已提交
343
// Factory for the binary list variant: allocates an AddBatchVectorTask and
// returns it as a shared BaseTask pointer.
BaseTaskPtr AddBatchVectorTask::Create(const std::string& group_id,
                                       const VecBinaryTensorList* tensor_list,
                                       std::vector<std::string>& ids) {
    return std::shared_ptr<BaseTask>(new AddBatchVectorTask(group_id, tensor_list, ids));
}

G
groot 已提交
349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368
// Number of vectors in whichever input list is set; 0 when neither is.
uint64_t AddBatchVectorTask::GetVecListCount() const {
    if(tensor_list_) {
        return (uint64_t) tensor_list_->tensor_list.size();
    }
    if(bin_tensor_list_) {
        return (uint64_t) bin_tensor_list_->tensor_list.size();
    }
    return 0;
}

// Number of elements in the vector at `index`: the element count for a
// floating-point tensor, or the byte length divided by 8 for a binary tensor.
// Returns 0 when no list is set or `index` is out of range.
uint64_t AddBatchVectorTask::GetVecDimension(uint64_t index) const {
    if(tensor_list_) {
        if(index >= tensor_list_->tensor_list.size()){
            return 0;
        }
        return (uint64_t) tensor_list_->tensor_list[index].tensor.size();
    } else if(bin_tensor_list_) {
        if(index >= bin_tensor_list_->tensor_list.size()){
            return 0;
        }
        return (uint64_t) bin_tensor_list_->tensor_list[index].tensor.size()/8;
    } else {
        return 0;
    }
}

// Pointer to the storage of the vector at `index`, viewed as doubles; a
// binary tensor's buffer is reinterpreted in place. Returns nullptr when no
// list is set or `index` is out of range.
const double* AddBatchVectorTask::GetVecData(uint64_t index) const {
    if(tensor_list_) {
        if(index < tensor_list_->tensor_list.size()) {
            return tensor_list_->tensor_list[index].tensor.data();
        }
        return nullptr;
    }
    if(bin_tensor_list_) {
        if(index < bin_tensor_list_->tensor_list.size()) {
            return reinterpret_cast<const double*>(bin_tensor_list_->tensor_list[index].tensor.data());
        }
        return nullptr;
    }
    return nullptr;
}

// User-supplied id of the vector at `index`.
// Returns an empty string when no list is set or `index` is out of range.
// (The out-of-range paths used to `return 0;`, which constructs a std::string
// from a null pointer.)
std::string AddBatchVectorTask::GetVecID(uint64_t index) const {
    if(tensor_list_) {
        if(index >= tensor_list_->tensor_list.size()){
            return "";
        }
        return tensor_list_->tensor_list[index].uid;
    } else if(bin_tensor_list_) {
        if(index >= bin_tensor_list_->tensor_list.size()){
            return "";
        }
        return bin_tensor_list_->tensor_list[index].uid;
    } else {
        return "";
    }
}

G
groot 已提交
407 408 409 410 411 412 413 414
// Attribute map of the vector at `index`. When no list is set or `index` is
// out of range, a shared empty map is returned instead of dereferencing a
// null pointer or indexing past the end.
const AttribMap& AddBatchVectorTask::GetVecAttrib(uint64_t index) const {
    static const AttribMap kEmptyAttrib{};
    if(tensor_list_) {
        if(index >= tensor_list_->tensor_list.size()) {
            return kEmptyAttrib;
        }
        return tensor_list_->tensor_list[index].attrib;
    }
    if(bin_tensor_list_) {
        if(index >= bin_tensor_list_->tensor_list.size()) {
            return kEmptyAttrib;
        }
        return bin_tensor_list_->tensor_list[index].attrib;
    }
    return kEmptyAttrib;
}

G
groot 已提交
415 416 417
// Records the id mapping for the vectors in [from, to).
//
// For each index: the user-facing id is the tensor's uid, or the engine's
// numeric id when the uid is empty. That id is stored in tensor_ids_[i], and
// the tensor's attributes (plus the id under VECTOR_UID) are serialized and
// put into the id mapper under "<group_id>_<numeric id>".
//
// NOTE(review): results are written to the member tensor_ids_, not to the
// `tensor_ids` parameter, which is unused. OnExecute passes tensor_ids_ to the
// thread pool as a plain argument, so the parameter may refer to a copy there;
// do not switch to the parameter without checking that call site.
// Callers must ensure vector_ids and tensor_ids_ hold at least `to` elements.
void AddBatchVectorTask::ProcessIdMapping(engine::IDNumbers& vector_ids,
                                          uint64_t from, uint64_t to,
                                          std::vector<std::string>& tensor_ids) {
    std::string nid_prefix = group_id_ + "_";
    for(size_t i = from; i < to; i++) {
        std::string uid = GetVecID(i);
        std::string num_id = std::to_string(vector_ids[i]);
        if(uid.empty()) {
            uid = num_id;
        }
        tensor_ids_[i] = uid;

        std::string nid = nid_prefix + num_id;
        AttribMap attrib = GetVecAttrib(i);
        attrib[VECTOR_UID] = uid;
        std::string attrib_str;
        AttributeSerializer::Encode(attrib, attrib_str);
        IVecIdMapper::GetInstance()->Put(nid, attrib_str);
    }
}

G
groot 已提交
436
// Adds the whole input batch to the group and records the id mappings.
//
// Steps: return early for an empty batch; look up the group; check every
// vector's dimension against the group dimension while narrowing the doubles
// into one contiguous float buffer; add the buffer to the engine; then build
// the id mapping, inline for small batches and on the thread pool in chunks
// of USE_MT vectors for batches of USE_MT or more.
//
// Returns SERVER_SUCCESS, SERVER_GROUP_NOT_EXIST, SERVER_INVALID_VECTOR_DIMENSION
// or SERVER_UNEXPECTED_ERROR; on every failure error_code_ and error_msg_ are set.
ServerError AddBatchVectorTask::OnExecute() {
    try {
        TimeRecorder rc("AddBatchVectorTask");

        uint64_t vec_count = GetVecListCount();
        if(vec_count == 0) {
            return SERVER_SUCCESS;
        }

        engine::meta::GroupSchema group_info;
        group_info.group_id = group_id_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("check group dimension");

        uint64_t group_dim = group_info.dimension;
        std::vector<float> vec_f;
        vec_f.resize(vec_count*group_dim);//allocate enough memory
        for(uint64_t i = 0; i < vec_count; i ++) {
            uint64_t vec_dim = GetVecDimension(i);
            if(vec_dim != group_dim) {
                // The engine status is OK at this point, so report the actual
                // mismatch instead of an "Engine failed" message.
                error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
                error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                             + " vs. group dimension:" + std::to_string(group_dim);
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            // The engine takes floats; the request carries doubles.
            const double* d_p = GetVecData(i);
            for(uint64_t d = 0; d < vec_dim; d++) {
                vec_f[i*vec_dim + d] = static_cast<float>(d_p[d]);
            }
        }

        rc.Record("prepare vectors data");

        engine::IDNumbers vector_ids;
        stat = DB()->add_vectors(group_id_, vec_count, vec_f.data(), vector_ids);
        rc.Record("add vectors to engine");
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        if(vector_ids.size() < vec_count) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Vector ID not returned";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        tensor_ids_.resize(vector_ids.size());
        if(vec_count < USE_MT) {
            ProcessIdMapping(vector_ids, 0, vec_count, tensor_ids_);
            rc.Record("built id mapping");
        } else {
            std::list<std::future<void>> threads_list;

            // One task per chunk [begin_index, end_index). Iterating on
            // begin_index ensures the final, possibly shorter, chunk is
            // enqueued as well; the previous loop stopped before it.
            for(uint64_t begin_index = 0; begin_index < vec_count; begin_index += USE_MT) {
                uint64_t end_index = begin_index + USE_MT;
                if(end_index > vec_count) {
                    end_index = vec_count;
                }
                threads_list.push_back(
                        GetThreadPool().enqueue(&AddBatchVectorTask::ProcessIdMapping,
                                           this, vector_ids, begin_index, end_index, tensor_ids_));
            }

            // Block until every chunk has been mapped.
            for(auto& pending : threads_list) {
                pending.wait();
            }

            rc.Record("built id mapping by multi-threads:" + std::to_string(threads_list.size()));
        }

    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
G
groot 已提交
531
// Builds a search task whose query vectors come from a floating-point tensor
// list; the binary list pointer is left null. Results are delivered through
// `result` when the task executes. The task is registered under DQL_TASK_GROUP.
SearchVectorTask::SearchVectorTask(const std::string& group_id,
                                   const int64_t top_k,
                                   const VecTensorList* tensor_list,
                                   const VecSearchFilter& filter,
                                   VecSearchResultList& result)
    : BaseTask(DQL_TASK_GROUP),
      group_id_(group_id),
      top_k_(top_k),
      tensor_list_(tensor_list),
      bin_tensor_list_(nullptr),
      filter_(filter),
      result_(result) {

}

// Builds a search task whose query vectors come from a binary-encoded tensor
// list; the floating-point list pointer is left null. Results are delivered
// through `result` when the task executes. The task is registered under
// DQL_TASK_GROUP.
SearchVectorTask::SearchVectorTask(const std::string& group_id,
                                   const int64_t top_k,
                                   const VecBinaryTensorList* bin_tensor_list,
                                   const VecSearchFilter& filter,
                                   VecSearchResultList& result)
    : BaseTask(DQL_TASK_GROUP),
      group_id_(group_id),
      top_k_(top_k),
      tensor_list_(nullptr),
      bin_tensor_list_(bin_tensor_list),
      filter_(filter),
      result_(result) {

}

G
groot 已提交
561
// Factory for the floating-point list variant: allocates a SearchVectorTask
// and returns it as a shared BaseTask pointer.
BaseTaskPtr SearchVectorTask::Create(const std::string& group_id,
                                     const int64_t top_k,
                                     const VecTensorList* tensor_list,
                                     const VecSearchFilter& filter,
                                     VecSearchResultList& result) {
    return std::shared_ptr<BaseTask>(new SearchVectorTask(group_id, top_k, tensor_list, filter, result));
}

G
groot 已提交
569 570 571
// Factory for the binary list variant: allocates a SearchVectorTask and
// returns it as a shared BaseTask pointer.
BaseTaskPtr SearchVectorTask::Create(const std::string& group_id,
                                     const int64_t top_k,
                                     const VecBinaryTensorList* bin_tensor_list,
                                     const VecSearchFilter& filter,
                                     VecSearchResultList& result) {
    return std::shared_ptr<BaseTask>(new SearchVectorTask(group_id, top_k, bin_tensor_list, filter, result));
}


// Flattens the query vectors into `data` as floats, one vector after another.
// The first vector fixes the expected dimension; any vector of a different
// size is logged and rejected with SERVER_INVALID_ARGUMENT (`data` has already
// been resized at that point). When neither list holds any vector, `data` is
// left untouched and SERVER_SUCCESS is returned.
ServerError SearchVectorTask::GetTargetData(std::vector<float>& data) const {
    if(tensor_list_ && !tensor_list_->tensor_list.empty()) {
        const auto& tensors = tensor_list_->tensor_list;
        const uint64_t total = tensors.size();
        const uint64_t width = tensors[0].tensor.size();
        data.resize(total*width);
        for(uint64_t row = 0; row < total; ++row) {
            const auto& values = tensors[row].tensor;
            if(values.size() != width) {
                SERVER_LOG_ERROR << "Invalid vector dimension: " << values.size();
                return SERVER_INVALID_ARGUMENT;
            }
            const double* src = values.data();
            for(uint64_t col = 0; col < width; ++col) {
                data[row*width + col] = (float)(src[col]);
            }
        }
        return SERVER_SUCCESS;
    }

    if(bin_tensor_list_ && !bin_tensor_list_->tensor_list.empty()) {
        const auto& tensors = bin_tensor_list_->tensor_list;
        const uint64_t total = tensors.size();
        // Binary tensors carry raw bytes; 8 bytes per element.
        const uint64_t width = tensors[0].tensor.size()/8;
        data.resize(total*width);
        for(uint64_t row = 0; row < total; ++row) {
            const auto& bytes = tensors[row].tensor;
            if(bytes.size()/8 != width) {
                SERVER_LOG_ERROR << "Invalid vector dimension: " << bytes.size()/8;
                return SERVER_INVALID_ARGUMENT;
            }
            const double* src = reinterpret_cast<const double*>(bytes.data());
            for(uint64_t col = 0; col < width; ++col) {
                data[row*width + col] = (float)(src[col]);
            }
        }
    }

    return SERVER_SUCCESS;
}

// Dimension of the first query vector: its element count for a floating-point
// list, or its byte length divided by 8 for a binary list. Returns 0 when
// there is no query vector.
uint64_t SearchVectorTask::GetTargetDimension() const {
    const bool has_float_vectors = tensor_list_ && !tensor_list_->tensor_list.empty();
    if(has_float_vectors) {
        return tensor_list_->tensor_list[0].tensor.size();
    }

    const bool has_binary_vectors = bin_tensor_list_ && !bin_tensor_list_->tensor_list.empty();
    if(has_binary_vectors) {
        return bin_tensor_list_->tensor_list[0].tensor.size()/8;
    }

    return 0;
}

// Number of query vectors in whichever input list is set.
// Returns 0 when neither list is set; previously control fell off the end of
// the function in that case.
uint64_t SearchVectorTask::GetTargetCount() const {
    if(tensor_list_) {
        return tensor_list_->tensor_list.size();
    } else if(bin_tensor_list_) {
        return bin_tensor_list_->tensor_list.size();
    }

    return 0;
}

G
groot 已提交
630 631
// Runs a top-k search for the query vectors and fills result_.
//
// Steps: look up the group; check the query dimension against the group
// dimension; flatten the query vectors to floats; turn each filter time range
// into a begin/end pair of engine dates; run the engine search; then, for
// every returned id, load its attributes from the id mapper (key
// "<group_id>_<numeric id>") and append one result item per id.
//
// Returns SERVER_SUCCESS, SERVER_GROUP_NOT_EXIST, SERVER_INVALID_ARGUMENT or
// SERVER_UNEXPECTED_ERROR; on every failure error_code_ and error_msg_ are set.
ServerError SearchVectorTask::OnExecute() {
    try {
        TimeRecorder rc("SearchVectorTask");

        engine::meta::GroupSchema group_info;
        group_info.group_id = group_id_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        uint64_t group_dim = group_info.dimension;
        uint64_t vec_dim = GetTargetDimension();
        if(vec_dim != group_dim) {
            error_code_ = SERVER_INVALID_ARGUMENT;
            error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                         + " vs. group dimension:" + std::to_string(group_dim);
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("check group dimension");

        std::vector<float> vec_f;
        ServerError err = GetTargetData(vec_f);
        if(err != SERVER_SUCCESS) {
            // GetTargetData has already logged the offending dimension.
            error_code_ = err;
            error_msg_ = "Invalid target vector data";
            return err;
        }

        uint64_t vec_count = GetTargetCount();

        // Each time range contributes two dates: its begin and its end.
        std::vector<DB_DATE> dates;
        for(const VecTimeRange& tr : filter_.time_ranges) {
            dates.push_back(MakeDbDate(tr.time_begin));
            dates.push_back(MakeDbDate(tr.time_end));
        }

        rc.Record("prepare input data");

        engine::QueryResults results;
        stat = DB()->search(group_id_, (size_t)top_k_, vec_count, vec_f.data(), dates, results);
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("do search");

        const std::string nid_prefix = group_id_ + "_";
        for(engine::QueryResult& res : results){
            VecSearchResult v_res;
            for(auto id : res) {
                std::string attrib_str;
                std::string nid = nid_prefix + std::to_string(id);
                IVecIdMapper::GetInstance()->Get(nid, attrib_str);

                AttribMap attrib_map;
                AttributeSerializer::Decode(attrib_str, attrib_map);

                VecSearchResultItem item;
                item.__set_attrib(attrib_map);
                item.uid = item.attrib[VECTOR_UID];
                item.distance = 0.0;////TODO: return distance
                v_res.result_list.emplace_back(item);

                SERVER_LOG_TRACE << "nid = " << nid << ", uid = " << item.uid;
            }

            result_.result_list.push_back(v_res);
        }
        rc.Record("construct result");

    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}

}
}
}