VecServiceTask.cpp 24.5 KB
Newer Older
G
groot 已提交
1 2 3 4 5 6 7 8 9 10
/*******************************************************************************
 * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited.
 * Proprietary and confidential.
 ******************************************************************************/
#include "VecServiceTask.h"
#include "ServerConfig.h"
#include "VecIdMapper.h"
#include "utils/CommonUtil.h"
#include "utils/Log.h"
G
groot 已提交
11
#include "utils/TimeRecorder.h"
G
groot 已提交
12
#include "utils/ThreadPool.h"
G
groot 已提交
13 14
#include "db/DB.h"
#include "db/Env.h"
G
groot 已提交
15
#include "db/Meta.h"
G
groot 已提交
16

G
groot 已提交
17

G
groot 已提交
18 19 20 21
namespace zilliz {
namespace vecwise {
namespace server {

G
groot 已提交
22 23
// Task-group name handed to BaseTask by the search task.
static const std::string DQL_TASK_GROUP = "dql";
// Task-group name handed to BaseTask by every other task in this file.
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";

// Attribute-map key under which a vector's returned id is stored.
static const std::string VECTOR_UID = "uid";
// Batch size from which id mapping is handed to the thread pool; also the
// number of vectors given to each pool task.
static const uint64_t USE_MT = 5000;
G
groot 已提交
27

G
groot 已提交
28 29 30
// Short aliases for the engine meta types used by MakeDbDate() and by the
// date filter built in SearchVectorTask::OnExecute().
using DB_META = zilliz::vecwise::engine::meta::Meta;
using DB_DATE = zilliz::vecwise::engine::meta::DateT;

G
groot 已提交
31 32 33 34 35
namespace {
    class DBWrapper {
    public:
        DBWrapper() {
            zilliz::vecwise::engine::Options opt;
G
groot 已提交
36 37 38 39
            ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_DB);
            opt.meta.backend_uri = config.GetValue(CONFIG_DB_URL);
            std::string db_path = config.GetValue(CONFIG_DB_PATH);
            opt.memory_sync_interval = (uint16_t)config.GetInt32Value(CONFIG_DB_FLUSH_INTERVAL, 10);
G
groot 已提交
40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
            opt.meta.path = db_path + "/db";

            CommonUtil::CreateDirectory(opt.meta.path);

            zilliz::vecwise::engine::DB::Open(opt, &db_);
            if(db_ == nullptr) {
                SERVER_LOG_ERROR << "Failed to open db";
                throw ServerException(SERVER_NULL_POINTER, "Failed to open db");
            }
        }

        zilliz::vecwise::engine::DB* DB() { return db_; }

    private:
        zilliz::vecwise::engine::DB* db_ = nullptr;
    };

    // Returns the engine handle shared by every task in this file. The
    // wrapper is a function-local static, so it is built on the first call.
    zilliz::vecwise::engine::DB* DB() {
        static DBWrapper wrapper_instance;
        zilliz::vecwise::engine::DB* handle = wrapper_instance.DB();
        return handle;
    }
G
groot 已提交
61 62 63 64 65 66

    // Converts a request date-time into the engine's date value.
    // dt: calendar fields (year .. second) taken from the request.
    // Returns whatever DB_META::GetDate() yields for the converted time_t.
    DB_DATE MakeDbDate(const VecDateTime& dt) {
        // Start from a defined value so a converter that leaves its output
        // untouched cannot hand indeterminate data to GetDate().
        time_t t_t = 0;
        CommonUtil::ConvertTime(dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second, t_t);
        return DB_META::GetDate(t_t);
    }
G
groot 已提交
67 68 69 70 71

    // Returns the pool shared by the batch id-mapping work. It is built on
    // first use with the argument 6 (presumably the worker count -- confirm
    // against ThreadPool).
    ThreadPool& GetThreadPool() {
        static ThreadPool shared_pool(6);
        return shared_pool;
    }
G
groot 已提交
72 73 74 75 76
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds a task that registers a new vector group.
// dimension: dimension recorded in the group schema.
// group_id:  identifier of the group to create.
AddGroupTask::AddGroupTask(int32_t dimension, const std::string& group_id)
    : BaseTask(DDL_DML_TASK_GROUP), dimension_(dimension), group_id_(group_id) {
}

// Factory: allocates an AddGroupTask and returns it as a shared BaseTask pointer.
BaseTaskPtr AddGroupTask::Create(int32_t dimension,
                                 const std::string& group_id) {
    std::shared_ptr<BaseTask> task(new AddGroupTask(dimension, group_id));
    return task;
}

// Registers the group with the id mapper and then with the engine.
// Returns SERVER_SUCCESS, including when the engine rejects the group (it may
// already exist), and SERVER_UNEXPECTED_ERROR when an exception is caught.
ServerError AddGroupTask::OnExecute() {
    try {
        IVecIdMapper::GetInstance()->AddGroup(group_id_);

        engine::meta::GroupSchema group_info;
        group_info.dimension = (size_t)dimension_;
        group_info.group_id = group_id_;
        engine::Status stat = DB()->add_group(group_info);
        if(!stat.ok()) {
            // Deliberately not a failure: the group could already exist.
            // Only the engine status is logged; error_msg_ is not populated
            // on this path, so there is nothing else worth printing.
            SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
            return SERVER_SUCCESS;
        }

    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return SERVER_UNEXPECTED_ERROR;
    }

    return SERVER_SUCCESS;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds a task that looks up a group's dimension.
// group_id:  group to query.
// dimension: caller-owned output; OnExecute() writes the result through it.
GetGroupTask::GetGroupTask(const std::string& group_id, int32_t& dimension)
    : BaseTask(DDL_DML_TASK_GROUP), group_id_(group_id), dimension_(dimension) {
}

// Factory: allocates a GetGroupTask and returns it as a shared BaseTask pointer.
BaseTaskPtr GetGroupTask::Create(const std::string& group_id, int32_t& dimension) {
    std::shared_ptr<BaseTask> task(new GetGroupTask(group_id, dimension));
    return task;
}

// Fetches the group schema from the engine and writes its dimension to the
// caller's output (reset to 0 first).
// Returns SERVER_SUCCESS, SERVER_GROUP_NOT_EXIST when the engine lookup fails,
// or SERVER_UNEXPECTED_ERROR when an exception is caught.
ServerError GetGroupTask::OnExecute() {
    try {
        dimension_ = 0;

        engine::meta::GroupSchema schema;
        schema.group_id = group_id_;
        engine::Status lookup = DB()->get_group(schema);
        if(lookup.ok()) {
            dimension_ = (int32_t)schema.dimension;
        } else {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + lookup.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

    } catch (std::exception& failure) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = failure.what();
        SERVER_LOG_ERROR << error_msg_;
        return SERVER_UNEXPECTED_ERROR;
    }

    return SERVER_SUCCESS;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds a task for deleting the given group (execution is not implemented yet).
DeleteGroupTask::DeleteGroupTask(const std::string& group_id)
    : BaseTask(DDL_DML_TASK_GROUP), group_id_(group_id) {
}

// Factory: allocates a DeleteGroupTask and returns it as a shared BaseTask pointer.
BaseTaskPtr DeleteGroupTask::Create(const std::string& group_id) {
    std::shared_ptr<BaseTask> task(new DeleteGroupTask(group_id));
    return task;
}

// Group deletion is not supported yet: records and logs the error and
// returns SERVER_NOT_IMPLEMENT.
ServerError DeleteGroupTask::OnExecute() {
    // TODO: once supported, also call IVecIdMapper::GetInstance()->DeleteGroup(group_id_).
    error_msg_ = "delete group not implemented";
    error_code_ = SERVER_NOT_IMPLEMENT;
    SERVER_LOG_ERROR << error_msg_;

    return error_code_;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
G
groot 已提交
171
// Builds a task that inserts one vector supplied as a VecTensor.
// group_id: target group.
// tensor:   vector to insert; the binary-tensor slot is left null.
// id:       caller-owned output that receives the id of the stored vector.
AddVectorTask::AddVectorTask(const std::string& group_id, const VecTensor* tensor, std::string& id)
    : BaseTask(DDL_DML_TASK_GROUP),
      group_id_(group_id),
      tensor_(tensor),
      bin_tensor_(nullptr),
      tensor_id_(id) {
}

G
groot 已提交
182
// Factory for the VecTensor flavour of AddVectorTask.
BaseTaskPtr AddVectorTask::Create(const std::string& group_id,
                                  const VecTensor* tensor,
                                  std::string& id) {
    std::shared_ptr<BaseTask> task(new AddVectorTask(group_id, tensor, id));
    return task;
}

G
groot 已提交
188
// Builds a task that inserts one vector supplied as a VecBinaryTensor.
// group_id: target group.
// tensor:   vector to insert; the plain-tensor slot is left null.
// id:       caller-owned output that receives the id of the stored vector.
AddVectorTask::AddVectorTask(const std::string& group_id, const VecBinaryTensor* tensor, std::string& id)
    : BaseTask(DDL_DML_TASK_GROUP),
      group_id_(group_id),
      tensor_(nullptr),
      bin_tensor_(tensor),
      tensor_id_(id) {
}

// Factory for the VecBinaryTensor flavour of AddVectorTask.
BaseTaskPtr AddVectorTask::Create(const std::string& group_id,
                                  const VecBinaryTensor* tensor,
                                  std::string& id) {
    std::shared_ptr<BaseTask> task(new AddVectorTask(group_id, tensor, id));
    return task;
}


// Number of elements in the vector to insert; 0 when no tensor is set.
// For a binary tensor the size is divided by 8 (presumably
// sizeof(double) -- confirm against the VecBinaryTensor definition).
uint64_t AddVectorTask::GetVecDimension() const {
    if(tensor_) {
        return (uint64_t) tensor_->tensor.size();
    }
    if(bin_tensor_) {
        return (uint64_t) bin_tensor_->tensor.size()/8;
    }
    return 0;
}

// Pointer to the vector's values viewed as doubles; nullptr when no tensor
// is set. The binary tensor's buffer is reinterpreted, not converted.
const double* AddVectorTask::GetVecData() const {
    if(tensor_) {
        return (const double*)(tensor_->tensor.data());
    }
    if(bin_tensor_) {
        return (const double*)(bin_tensor_->tensor.data());
    }
    return nullptr;
}

// The uid carried by whichever tensor is set; an empty string when none is.
std::string AddVectorTask::GetVecID() const {
    if(tensor_) {
        return tensor_->uid;
    }
    if(bin_tensor_) {
        return bin_tensor_->uid;
    }
    return "";
}

G
groot 已提交
237 238 239 240 241 242 243 244
// Attribute map of whichever tensor is set.
// When neither tensor is set, a reference to a shared empty map is returned
// instead of dereferencing a null pointer, matching the other accessors'
// handling of the "no tensor" case.
const AttribMap& AddVectorTask::GetVecAttrib() const {
    if(tensor_) {
        return tensor_->attrib;
    }
    if(bin_tensor_) {
        return bin_tensor_->attrib;
    }

    static const AttribMap empty_attrib;
    return empty_attrib;
}

G
groot 已提交
245
// Inserts the single vector into the engine and stores its attributes in the
// id mapper under "<group_id>_<engine id>".
// On success the caller's id output holds the tensor's uid, or the engine id
// as a string when the uid is empty.
// Returns SERVER_SUCCESS, or SERVER_UNEXPECTED_ERROR when the group is unknown
// to the id mapper, the engine fails or returns no id, or an exception is
// caught. Every failure path sets error_code_ and error_msg_.
ServerError AddVectorTask::OnExecute() {
    try {
        if(!IVecIdMapper::GetInstance()->IsGroupExist(group_id_)) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "group not exist";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        // The engine is fed floats, so the double input is narrowed element by element.
        uint64_t vec_dim = GetVecDimension();
        std::vector<float> vec_f;
        vec_f.resize(vec_dim);
        const double* d_p = GetVecData();
        for(uint64_t d = 0; d < vec_dim; d++) {
            vec_f[d] = (float)(d_p[d]);
        }

        engine::IDNumbers vector_ids;
        engine::Status stat = DB()->add_vectors(group_id_, 1, vec_f.data(), vector_ids);
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        if(vector_ids.empty()) {
            // The engine reported success but produced no id. Its status is
            // OK here, so report the missing id rather than an engine failure.
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Vector ID not returned";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        std::string uid = GetVecID();
        std::string num_id = std::to_string(vector_ids[0]);
        if(uid.empty()) {
            tensor_id_ = num_id;
        } else {
            tensor_id_ = uid;
        }

        // Copy the attributes so the returned id can be added without
        // touching the request object.
        std::string nid = group_id_ + "_" + num_id;
        AttribMap attrib = GetVecAttrib();
        attrib[VECTOR_UID] = tensor_id_;
        std::string attrib_str;
        AttributeSerializer::Encode(attrib, attrib_str);
        IVecIdMapper::GetInstance()->Put(nid, attrib_str, group_id_);

    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}


////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds a task that inserts a batch of vectors supplied as a VecTensorList.
// group_id:    target group.
// tensor_list: vectors to insert; may be null, in which case the task sees
//              an empty batch.
// ids:         caller-owned output; cleared here and sized to one slot per
//              input vector.
AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id,
                                       const VecTensorList* tensor_list,
                                       std::vector<std::string>& ids)
    : BaseTask(DDL_DML_TASK_GROUP),
      group_id_(group_id),
      tensor_list_(tensor_list),
      bin_tensor_list_(nullptr),
      tensor_ids_(ids) {
    tensor_ids_.clear();
    // Every accessor treats the list pointer as nullable, so do the same here.
    if(tensor_list != nullptr) {
        tensor_ids_.resize(tensor_list->tensor_list.size());
    }
}

// Factory for the VecTensorList flavour of AddBatchVectorTask.
BaseTaskPtr AddBatchVectorTask::Create(const std::string& group_id,
                                       const VecTensorList* tensor_list,
                                       std::vector<std::string>& ids) {
    std::shared_ptr<BaseTask> task(new AddBatchVectorTask(group_id, tensor_list, ids));
    return task;
}

// Builds a task that inserts a batch of vectors supplied as a VecBinaryTensorList.
// group_id:    target group.
// tensor_list: vectors to insert; the plain-tensor slot is left null.
// ids:         caller-owned output; cleared here and filled by OnExecute().
AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id,
                                       const VecBinaryTensorList* tensor_list,
                                       std::vector<std::string>& ids)
    : BaseTask(DDL_DML_TASK_GROUP),
      group_id_(group_id),
      tensor_list_(nullptr),
      bin_tensor_list_(tensor_list),
      tensor_ids_(ids) {
    tensor_ids_.clear();
}

G
groot 已提交
334
// Factory for the VecBinaryTensorList flavour of AddBatchVectorTask.
BaseTaskPtr AddBatchVectorTask::Create(const std::string& group_id,
                                       const VecBinaryTensorList* tensor_list,
                                       std::vector<std::string>& ids) {
    std::shared_ptr<BaseTask> task(new AddBatchVectorTask(group_id, tensor_list, ids));
    return task;
}

G
groot 已提交
340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359
// Number of vectors in whichever list is set; 0 when neither is.
uint64_t AddBatchVectorTask::GetVecListCount() const {
    if(tensor_list_) {
        return (uint64_t) tensor_list_->tensor_list.size();
    }
    if(bin_tensor_list_) {
        return (uint64_t) bin_tensor_list_->tensor_list.size();
    }
    return 0;
}

// Element count of the vector at `index`; 0 when the index is out of range
// or no list is set. For a binary tensor the size is divided by 8
// (presumably sizeof(double) -- confirm against VecBinaryTensor).
uint64_t AddBatchVectorTask::GetVecDimension(uint64_t index) const {
    if(tensor_list_) {
        const bool in_range = index < tensor_list_->tensor_list.size();
        return in_range ? (uint64_t) tensor_list_->tensor_list[index].tensor.size() : 0;
    }
    if(bin_tensor_list_) {
        const bool in_range = index < bin_tensor_list_->tensor_list.size();
        return in_range ? (uint64_t) bin_tensor_list_->tensor_list[index].tensor.size()/8 : 0;
    }
    return 0;
}

// Pointer to the values of the vector at `index`, viewed as doubles; nullptr
// when the index is out of range or no list is set. A binary tensor's buffer
// is reinterpreted, not converted.
const double* AddBatchVectorTask::GetVecData(uint64_t index) const {
    if(tensor_list_) {
        if(index < tensor_list_->tensor_list.size()) {
            return tensor_list_->tensor_list[index].tensor.data();
        }
        return nullptr;
    }
    if(bin_tensor_list_) {
        if(index < bin_tensor_list_->tensor_list.size()) {
            return (const double*)bin_tensor_list_->tensor_list[index].tensor.data();
        }
        return nullptr;
    }
    return nullptr;
}

// The uid of the vector at `index`; an empty string when the index is out of
// range or no list is set.
std::string AddBatchVectorTask::GetVecID(uint64_t index) const {
    if(tensor_list_) {
        if(index >= tensor_list_->tensor_list.size()){
            // Must be an empty string: the literal 0 would be taken as a null
            // const char*, and building a std::string from that is undefined.
            return "";
        }
        return tensor_list_->tensor_list[index].uid;
    } else if(bin_tensor_list_) {
        if(index >= bin_tensor_list_->tensor_list.size()){
            return "";
        }
        return bin_tensor_list_->tensor_list[index].uid;
    } else {
        return "";
    }
}

G
groot 已提交
398 399 400 401 402 403 404 405
// Attribute map of the vector at `index`.
// When no list is set or the index is out of range, a reference to a shared
// empty map is returned, in line with the bounds handling of the other
// per-index accessors.
const AttribMap& AddBatchVectorTask::GetVecAttrib(uint64_t index) const {
    static const AttribMap empty_attrib;

    if(tensor_list_) {
        if(index >= tensor_list_->tensor_list.size()) {
            return empty_attrib;
        }
        return tensor_list_->tensor_list[index].attrib;
    }
    if(bin_tensor_list_) {
        if(index >= bin_tensor_list_->tensor_list.size()) {
            return empty_attrib;
        }
        return bin_tensor_list_->tensor_list[index].attrib;
    }
    return empty_attrib;
}

G
groot 已提交
406 407 408
// Publishes the returned id and stores the attributes for every vector in
// the half-open index range [from, to).
// vector_ids: engine ids, indexed like the input list.
// from, to:   range of input indices handled by this call.
// tensor_ids: not used; results are written to the tensor_ids_ member.
//             NOTE(review): the thread-pool caller may pass this argument by
//             copy, so writing through it could lose results -- confirm
//             before switching to it.
void AddBatchVectorTask::ProcessIdMapping(engine::IDNumbers& vector_ids,
                                          uint64_t from, uint64_t to,
                                          std::vector<std::string>& tensor_ids) {
    const std::string key_prefix = group_id_ + "_";

    for(size_t pos = from; pos < to; pos++) {
        // A vector without its own uid is reported under its engine id.
        std::string returned_id = GetVecID(pos);
        const std::string engine_id = std::to_string(vector_ids[pos]);
        if(returned_id.empty()) {
            returned_id = engine_id;
        }
        tensor_ids_[pos] = returned_id;

        // Attributes are copied so the returned id can be added to them.
        const std::string mapper_key = key_prefix + engine_id;
        AttribMap attributes = GetVecAttrib(pos);
        attributes[VECTOR_UID] = returned_id;

        std::string encoded;
        AttributeSerializer::Encode(attributes, encoded);
        IVecIdMapper::GetInstance()->Put(mapper_key, encoded, group_id_);
    }
}

G
groot 已提交
427
// Inserts the whole batch into the engine and then records the id mapping
// and attributes for every vector.
//   1. An empty batch succeeds immediately.
//   2. The group must exist and every vector must match its dimension.
//   3. All vectors are narrowed to float into one contiguous buffer and
//      handed to the engine in a single call.
//   4. Id mapping runs inline below USE_MT vectors, otherwise in chunks of
//      USE_MT on the shared thread pool.
// Returns SERVER_SUCCESS, SERVER_GROUP_NOT_EXIST,
// SERVER_INVALID_VECTOR_DIMENSION or SERVER_UNEXPECTED_ERROR. Every failure
// path sets error_code_ and error_msg_.
ServerError AddBatchVectorTask::OnExecute() {
    try {
        TimeRecorder rc("AddBatchVectorTask");

        uint64_t vec_count = GetVecListCount();
        if(vec_count == 0) {
            return SERVER_SUCCESS;
        }

        engine::meta::GroupSchema group_info;
        group_info.group_id = group_id_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("check group dimension");

        uint64_t group_dim = group_info.dimension;
        std::vector<float> vec_f;
        vec_f.resize(vec_count*group_dim);//allocate enough memory
        for(uint64_t i = 0; i < vec_count; i++) {
            uint64_t vec_dim = GetVecDimension(i);
            if(vec_dim != group_dim) {
                // The engine status is OK at this point, so describe the
                // actual problem instead of echoing the status.
                error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
                error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                             + " vs. group dimension:" + std::to_string(group_dim);
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            const double* d_p = GetVecData(i);
            for(uint64_t d = 0; d < vec_dim; d++) {
                vec_f[i*vec_dim + d] = (float)(d_p[d]);
            }
        }

        rc.Record("prepare vectors data");

        engine::IDNumbers vector_ids;
        stat = DB()->add_vectors(group_id_, vec_count, vec_f.data(), vector_ids);
        rc.Record("add vectors to engine");
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        if(vector_ids.size() < vec_count) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Vector ID not returned";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        tensor_ids_.resize(vector_ids.size());
        if(vec_count < USE_MT) {
            ProcessIdMapping(vector_ids, 0, vec_count, tensor_ids_);
            rc.Record("built id mapping");
        } else {
            std::vector<std::future<void>> pending;

            // Walk the batch in USE_MT-sized chunks. Looping on the chunk
            // start, not the chunk end, makes sure the final (possibly
            // shorter) chunk is enqueued too, including the case where
            // vec_count is exactly USE_MT.
            for(uint64_t begin_index = 0; begin_index < vec_count; begin_index += USE_MT) {
                uint64_t end_index = begin_index + USE_MT;
                if(end_index > vec_count) {
                    end_index = vec_count;
                }
                pending.push_back(
                        GetThreadPool().enqueue(&AddBatchVectorTask::ProcessIdMapping,
                                           this, vector_ids, begin_index, end_index, tensor_ids_));
            }

            // Do not return before every chunk has finished writing its results.
            for(std::future<void>& chunk : pending) {
                chunk.wait();
            }

            rc.Record("built id mapping by multi-threads:" + std::to_string(pending.size()));
        }

    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
G
groot 已提交
521
// Builds a search task whose query vectors come from a VecTensorList.
// group_id:    group to search.
// top_k:       value forwarded to the engine search as its top-k argument.
// tensor_list: query vectors; the binary list slot is left null.
// filter:      time ranges and the attribute names to return.
// result:      caller-owned output that receives one result list per query.
SearchVectorTask::SearchVectorTask(const std::string& group_id,
                                   const int64_t top_k,
                                   const VecTensorList* tensor_list,
                                   const VecSearchFilter& filter,
                                   VecSearchResultList& result)
    : BaseTask(DQL_TASK_GROUP),
      group_id_(group_id),
      top_k_(top_k),
      tensor_list_(tensor_list),
      bin_tensor_list_(nullptr),
      filter_(filter),
      result_(result) {
}

// Builds a search task whose query vectors come from a VecBinaryTensorList.
// group_id:        group to search.
// top_k:           value forwarded to the engine search as its top-k argument.
// bin_tensor_list: query vectors; the plain list slot is left null.
// filter:          time ranges and the attribute names to return.
// result:          caller-owned output that receives one result list per query.
SearchVectorTask::SearchVectorTask(const std::string& group_id,
                                   const int64_t top_k,
                                   const VecBinaryTensorList* bin_tensor_list,
                                   const VecSearchFilter& filter,
                                   VecSearchResultList& result)
    : BaseTask(DQL_TASK_GROUP),
      group_id_(group_id),
      top_k_(top_k),
      tensor_list_(nullptr),
      bin_tensor_list_(bin_tensor_list),
      filter_(filter),
      result_(result) {
}

G
groot 已提交
551
// Factory for the VecTensorList flavour of SearchVectorTask.
BaseTaskPtr SearchVectorTask::Create(const std::string& group_id,
                                     const int64_t top_k,
                                     const VecTensorList* tensor_list,
                                     const VecSearchFilter& filter,
                                     VecSearchResultList& result) {
    std::shared_ptr<BaseTask> task(
            new SearchVectorTask(group_id, top_k, tensor_list, filter, result));
    return task;
}

G
groot 已提交
559 560 561
// Factory for the VecBinaryTensorList flavour of SearchVectorTask.
BaseTaskPtr SearchVectorTask::Create(const std::string& group_id,
                                     const int64_t top_k,
                                     const VecBinaryTensorList* bin_tensor_list,
                                     const VecSearchFilter& filter,
                                     VecSearchResultList& result) {
    std::shared_ptr<BaseTask> task(
            new SearchVectorTask(group_id, top_k, bin_tensor_list, filter, result));
    return task;
}


// Flattens all query vectors into `data` as floats, one vector after another.
// The first vector fixes the dimension; any later vector of a different size
// is logged and aborts with SERVER_INVALID_ARGUMENT (`data` is already
// resized by then).
// With no list, or an empty one, `data` is left untouched and
// SERVER_SUCCESS is returned.
ServerError SearchVectorTask::GetTargetData(std::vector<float>& data) const {
    if(tensor_list_ && !tensor_list_->tensor_list.empty()) {
        const auto& targets = tensor_list_->tensor_list;
        uint64_t target_count = targets.size();
        uint64_t dim = targets[0].tensor.size();
        data.resize(target_count*dim);

        for(size_t row = 0; row < target_count; row++) {
            if(targets[row].tensor.size() != dim) {
                SERVER_LOG_ERROR << "Invalid vector dimension: " << targets[row].tensor.size();
                return SERVER_INVALID_ARGUMENT;
            }

            const double* src = targets[row].tensor.data();
            for(uint64_t col = 0; col < dim; col++) {
                data[row*dim + col] = (float)(src[col]);
            }
        }
        return SERVER_SUCCESS;
    }

    if(bin_tensor_list_ && !bin_tensor_list_->tensor_list.empty()) {
        const auto& targets = bin_tensor_list_->tensor_list;
        uint64_t target_count = targets.size();
        // Size divided by 8: presumably sizeof(double) -- confirm against VecBinaryTensor.
        uint64_t dim = targets[0].tensor.size()/8;
        data.resize(target_count*dim);

        for(size_t row = 0; row < target_count; row++) {
            if(targets[row].tensor.size()/8 != dim) {
                SERVER_LOG_ERROR << "Invalid vector dimension: " << targets[row].tensor.size()/8;
                return SERVER_INVALID_ARGUMENT;
            }

            // The raw buffer is reinterpreted as doubles, not converted.
            const double* src = (const double*)(targets[row].tensor.data());
            for(uint64_t col = 0; col < dim; col++) {
                data[row*dim + col] = (float)(src[col]);
            }
        }
    }

    return SERVER_SUCCESS;
}

// Dimension of the first query vector; 0 when there is no list or it is
// empty. For a binary list the size is divided by 8 (presumably
// sizeof(double) -- confirm against VecBinaryTensor).
uint64_t SearchVectorTask::GetTargetDimension() const {
    const bool has_plain = tensor_list_ && !tensor_list_->tensor_list.empty();
    if(has_plain) {
        return tensor_list_->tensor_list[0].tensor.size();
    }

    const bool has_binary = bin_tensor_list_ && !bin_tensor_list_->tensor_list.empty();
    if(has_binary) {
        return bin_tensor_list_->tensor_list[0].tensor.size()/8;
    }

    return 0;
}

// Number of query vectors in whichever list is set; 0 when neither is.
uint64_t SearchVectorTask::GetTargetCount() const {
    if(tensor_list_) {
        return tensor_list_->tensor_list.size();
    } else if(bin_tensor_list_) {
        return bin_tensor_list_->tensor_list.size();
    }

    // Explicit result for the "no list" case: flowing off the end of a
    // non-void function is undefined behaviour.
    return 0;
}

G
groot 已提交
620 621
// Runs the search and fills the caller's result list.
//   1. The group must exist and the query dimension must match it.
//   2. Query vectors are flattened to float; each filter time range
//      contributes its begin and end date.
//   3. For every id the engine returns, the stored attributes are fetched
//      from the id mapper, decoded, and either returned whole or restricted
//      to filter_.return_attribs.
// Returns SERVER_SUCCESS, SERVER_GROUP_NOT_EXIST,
// SERVER_INVALID_VECTOR_DIMENSION, the error from GetTargetData(), or
// SERVER_UNEXPECTED_ERROR. Every failure path sets error_code_ and error_msg_.
ServerError SearchVectorTask::OnExecute() {
    try {
        TimeRecorder rc("SearchVectorTask");

        engine::meta::GroupSchema group_info;
        group_info.group_id = group_id_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        uint64_t vec_dim = GetTargetDimension();
        if(vec_dim != group_info.dimension) {
            // The engine status is OK at this point, so describe the actual
            // problem instead of echoing the status.
            error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
            error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                         + " vs. group dimension:" + std::to_string((uint64_t)group_info.dimension);
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("check group dimension");

        std::vector<float> vec_f;
        ServerError err = GetTargetData(vec_f);
        if(err != SERVER_SUCCESS) {
            // GetTargetData() has already logged the offending dimension.
            error_code_ = err;
            error_msg_ = "Invalid target vector dimension";
            return error_code_;
        }

        uint64_t vec_count = GetTargetCount();

        std::vector<DB_DATE> dates;
        for(const VecTimeRange& tr : filter_.time_ranges) {
            dates.push_back(MakeDbDate(tr.time_begin));
            dates.push_back(MakeDbDate(tr.time_end));
        }

        rc.Record("prepare input data");

        engine::QueryResults results;
        stat = DB()->search(group_id_, (size_t)top_k_, vec_count, vec_f.data(), dates, results);
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("do search");

        // Same key layout the add tasks use when storing attributes.
        const std::string nid_prefix = group_id_ + "_";
        for(engine::QueryResult& res : results){
            VecSearchResult v_res;
            for(auto id : res) {
                std::string attrib_str;
                std::string nid = nid_prefix + std::to_string(id);
                IVecIdMapper::GetInstance()->Get(nid, attrib_str, group_id_);

                AttribMap attrib_map;
                AttributeSerializer::Decode(attrib_str, attrib_map);

                AttribMap attrib_return;
                VecSearchResultItem item;
                item.uid = attrib_map[VECTOR_UID];

                if(filter_.return_attribs.empty()) {//return all attributes
                    attrib_return.swap(attrib_map);
                } else {//filter attributes
                    for(auto& name : filter_.return_attribs) {
                        if(attrib_map.count(name) == 0)
                            continue;

                        attrib_return[name] = attrib_map[name];
                    }
                }
                item.__set_attrib(attrib_return);
                item.distance = 0.0;////TODO: return distance
                v_res.result_list.emplace_back(item);
            }

            result_.result_list.push_back(v_res);
        }
        rc.Record("construct result");

    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}

}
}
}