RequestTask.cpp 26.3 KB
Newer Older
G
groot 已提交
1 2 3 4 5
/*******************************************************************************
 * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited.
 * Proprietary and confidential.
 ******************************************************************************/
G
groot 已提交
6
#include "RequestTask.h"
G
groot 已提交
7 8 9 10
#include "ServerConfig.h"
#include "utils/CommonUtil.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
J
jinhai 已提交
11
#include "utils/ValidationUtil.h"
G
groot 已提交
12
#include "DBWrapper.h"
G
groot 已提交
13
#include "version.h"
G
groot 已提交
14 15

namespace zilliz {
J
jinhai 已提交
16
namespace milvus {
G
groot 已提交
17 18
namespace server {

G
groot 已提交
19 20
using namespace ::milvus;

G
groot 已提交
21 22
static const std::string DQL_TASK_GROUP = "dql";
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
G
groot 已提交
23
static const std::string PING_TASK_GROUP = "ping";
G
groot 已提交
24

J
jinhai 已提交
25 26
using DB_META = zilliz::milvus::engine::meta::Meta;
using DB_DATE = zilliz::milvus::engine::meta::DateT;
G
groot 已提交
27 28

namespace {
G
groot 已提交
29 30 31 32 33
    engine::EngineType EngineType(int type) {
        static std::map<int, engine::EngineType> map_type = {
                {0, engine::EngineType::INVALID},
                {1, engine::EngineType::FAISS_IDMAP},
                {2, engine::EngineType::FAISS_IVFFLAT},
G
groot 已提交
34
                {3, engine::EngineType::FAISS_IVFSQ8},
G
groot 已提交
35 36 37 38 39 40 41
        };

        if(map_type.find(type) == map_type.end()) {
            return engine::EngineType::INVALID;
        }

        return map_type[type];
G
groot 已提交
42
    }
G
groot 已提交
43

G
groot 已提交
44 45 46 47 48
    int IndexType(engine::EngineType type) {
        static std::map<engine::EngineType, int> map_type = {
                {engine::EngineType::INVALID, 0},
                {engine::EngineType::FAISS_IDMAP, 1},
                {engine::EngineType::FAISS_IVFFLAT, 2},
G
groot 已提交
49
                {engine::EngineType::FAISS_IVFSQ8, 3},
G
groot 已提交
50 51 52 53 54 55 56 57 58
        };

        if(map_type.find(type) == map_type.end()) {
            return 0;
        }

        return map_type[type];
    }

G
groot 已提交
59
    void
G
groot 已提交
60 61
    ConvertRowRecordToFloatArray(const std::vector<thrift::RowRecord>& record_array,
                                 uint64_t dimension,
G
groot 已提交
62 63 64
                                 std::vector<float>& float_array,
                                 ServerError& error_code,
                                 std::string& error_msg) {
G
groot 已提交
65 66 67 68 69
        uint64_t vec_count = record_array.size();
        float_array.resize(vec_count*dimension);//allocate enough memory
        for(uint64_t i = 0; i < vec_count; i++) {
            const auto& record = record_array[i];
            if(record.vector_data.empty()) {
G
groot 已提交
70 71 72
                error_code = SERVER_INVALID_ROWRECORD;
                error_msg = "Rowrecord float array is empty";
                return;
G
groot 已提交
73 74 75 76
            }
            uint64_t vec_dim = record.vector_data.size()/sizeof(double);//how many double value?
            if(vec_dim != dimension) {
                error_code = SERVER_INVALID_VECTOR_DIMENSION;
G
groot 已提交
77 78 79
                error_msg = "Invalid rowrecord dimension: " + std::to_string(vec_dim)
                                 + " vs. table dimension:" + std::to_string(dimension);
                return;
G
groot 已提交
80 81 82 83 84 85 86 87 88 89 90 91
            }

            //convert double array to float array(thrift has no float type)
            const double* d_p = reinterpret_cast<const double*>(record.vector_data.data());
            for(uint64_t d = 0; d < vec_dim; d++) {
                float_array[i*vec_dim + d] = (float)(d_p[d]);
            }
        }
    }

    static constexpr long DAY_SECONDS = 86400;

G
groot 已提交
92
    void
G
groot 已提交
93
    ConvertTimeRangeToDBDates(const std::vector<thrift::Range> &range_array,
G
groot 已提交
94 95 96
                              std::vector<DB_DATE>& dates,
                              ServerError& error_code,
                              std::string& error_msg) {
G
groot 已提交
97 98 99 100 101 102
        dates.clear();
        for(auto& range : range_array) {
            time_t tt_start, tt_end;
            tm tm_start, tm_end;
            if(!CommonUtil::TimeStrToTime(range.start_value, tt_start, tm_start)){
                error_code = SERVER_INVALID_TIME_RANGE;
G
groot 已提交
103 104
                error_msg = "Invalid time range: " + range.start_value;
                return;
G
groot 已提交
105 106 107 108
            }

            if(!CommonUtil::TimeStrToTime(range.end_value, tt_end, tm_end)){
                error_code = SERVER_INVALID_TIME_RANGE;
G
groot 已提交
109 110
                error_msg = "Invalid time range: " + range.start_value;
                return;
G
groot 已提交
111 112 113
            }

            long days = (tt_end > tt_start) ? (tt_end - tt_start)/DAY_SECONDS : (tt_start - tt_end)/DAY_SECONDS;
G
groot 已提交
114 115 116 117 118 119 120
            if(days == 0) {
                error_code = SERVER_INVALID_TIME_RANGE;
                error_msg = "Invalid time range: " + range.start_value + " to " + range.end_value;
                return ;
            }

            for(long i = 0; i < days; i++) {
G
groot 已提交
121 122 123 124 125 126 127 128 129
                time_t tt_day = tt_start + DAY_SECONDS*i;
                tm tm_day;
                CommonUtil::ConvertTime(tt_day, tm_day);

                long date = tm_day.tm_year*10000 + tm_day.tm_mon*100 + tm_day.tm_mday;//according to db logic
                dates.push_back(date);
            }
        }
    }
G
groot 已提交
130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
CreateTableTask::CreateTableTask(const thrift::TableSchema& schema)
: BaseTask(DDL_DML_TASK_GROUP),
  schema_(schema) {

}

BaseTaskPtr CreateTableTask::Create(const thrift::TableSchema& schema) {
    return std::shared_ptr<BaseTask>(new CreateTableTask(schema));
}

ServerError CreateTableTask::OnExecute() {
    TimeRecorder rc("CreateTableTask");
P
peng.xu 已提交
145

G
groot 已提交
146
    try {
G
groot 已提交
147
        //step 1: check arguments
J
jinhai 已提交
148 149 150
        ServerError res = SERVER_SUCCESS;
        res = ValidateTableName(schema_.table_name);
        if(res != SERVER_SUCCESS) {
151
            return SetError(res, "Invalid table name: " + schema_.table_name);
G
groot 已提交
152
        }
J
jinhai 已提交
153 154 155

        res = ValidateTableDimension(schema_.dimension);
        if(res != SERVER_SUCCESS) {
156
            return SetError(res, "Invalid table dimension: " + std::to_string(schema_.dimension));
G
groot 已提交
157 158
        }

J
jinhai 已提交
159 160
        res = ValidateTableIndexType(schema_.index_type);
        if(res != SERVER_SUCCESS) {
161
            return SetError(res, "Invalid index type: " + std::to_string(schema_.index_type));
G
groot 已提交
162 163 164
        }

        //step 2: construct table schema
G
groot 已提交
165
        engine::meta::TableSchema table_info;
G
groot 已提交
166 167 168 169 170
        table_info.dimension_ = (uint16_t)schema_.dimension;
        table_info.table_id_ = schema_.table_name;
        table_info.engine_type_ = (int)EngineType(schema_.index_type);
        table_info.store_raw_data_ = schema_.store_raw_vector;

G
groot 已提交
171
        //step 3: create table
G
groot 已提交
172
        engine::Status stat = DBWrapper::DB()->CreateTable(table_info);
G
groot 已提交
173
        if(!stat.ok()) {//table could exist
G
groot 已提交
174
            return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
G
groot 已提交
175 176 177
        }

    } catch (std::exception& ex) {
G
groot 已提交
178
        return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
G
groot 已提交
179 180 181 182 183 184 185 186 187
    }

    rc.Record("done");

    return SERVER_SUCCESS;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DescribeTableTask::DescribeTableTask(const std::string &table_name, thrift::TableSchema &schema)
G
groot 已提交
188
    : BaseTask(DDL_DML_TASK_GROUP),
G
groot 已提交
189 190 191 192 193 194 195 196 197 198 199 200 201
      table_name_(table_name),
      schema_(schema) {
    schema_.table_name = table_name_;
}

BaseTaskPtr DescribeTableTask::Create(const std::string& table_name, thrift::TableSchema& schema) {
    return std::shared_ptr<BaseTask>(new DescribeTableTask(table_name, schema));
}

ServerError DescribeTableTask::OnExecute() {
    TimeRecorder rc("DescribeTableTask");

    try {
G
groot 已提交
202
        //step 1: check arguments
J
jinhai 已提交
203 204 205
        ServerError res = SERVER_SUCCESS;
        res = ValidateTableName(table_name_);
        if(res != SERVER_SUCCESS) {
206
            return SetError(res, "Invalid table name: " + table_name_);
G
groot 已提交
207 208 209
        }

        //step 2: get table info
G
groot 已提交
210
        engine::meta::TableSchema table_info;
G
groot 已提交
211
        table_info.table_id_ = table_name_;
G
groot 已提交
212
        engine::Status stat = DBWrapper::DB()->DescribeTable(table_info);
G
groot 已提交
213
        if(!stat.ok()) {
G
groot 已提交
214
            return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
G
groot 已提交
215 216
        }

G
groot 已提交
217 218 219 220 221
        schema_.table_name = table_info.table_id_;
        schema_.index_type = IndexType((engine::EngineType)table_info.engine_type_);
        schema_.dimension = table_info.dimension_;
        schema_.store_raw_vector = table_info.store_raw_data_;

G
groot 已提交
222
    } catch (std::exception& ex) {
G
groot 已提交
223
        return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
G
groot 已提交
224 225 226 227 228 229 230
    }

    rc.Record("done");

    return SERVER_SUCCESS;
}

P
peng.xu 已提交
231 232 233 234 235 236 237 238 239 240 241 242 243 244 245
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
BuildIndexTask::BuildIndexTask(const std::string& table_name)
    : BaseTask(DDL_DML_TASK_GROUP),
      table_name_(table_name) {
}

BaseTaskPtr BuildIndexTask::Create(const std::string& table_name) {
    return std::shared_ptr<BaseTask>(new BuildIndexTask(table_name));
}

ServerError BuildIndexTask::OnExecute() {
    try {
        TimeRecorder rc("BuildIndexTask");

        //step 1: check arguments
246 247 248 249 250 251 252 253 254 255
        ServerError res = SERVER_SUCCESS;
        res = ValidateTableName(table_name_);
        if(res != SERVER_SUCCESS) {
            return SetError(res, "Invalid table name: " + table_name_);
        }

        bool has_table = false;
        engine::Status stat = DBWrapper::DB()->HasTable(table_name_, has_table);
        if(!has_table) {
            return SetError(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists");
P
peng.xu 已提交
256 257 258
        }

        //step 2: check table existence
259
        stat = DBWrapper::DB()->BuildIndex(table_name_);
P
peng.xu 已提交
260 261 262 263 264 265 266 267 268 269 270 271
        if(!stat.ok()) {
            return SetError(SERVER_BUILD_INDEX_ERROR, "Engine failed: " + stat.ToString());
        }

        rc.Elapse("totally cost");
    } catch (std::exception& ex) {
        return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
    }

    return SERVER_SUCCESS;
}

G
groot 已提交
272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
HasTableTask::HasTableTask(const std::string& table_name, bool& has_table)
    : BaseTask(DDL_DML_TASK_GROUP),
      table_name_(table_name),
      has_table_(has_table) {

}

BaseTaskPtr HasTableTask::Create(const std::string& table_name, bool& has_table) {
    return std::shared_ptr<BaseTask>(new HasTableTask(table_name, has_table));
}

ServerError HasTableTask::OnExecute() {
    try {
        TimeRecorder rc("HasTableTask");

        //step 1: check arguments
J
jinhai 已提交
289 290 291
        ServerError res = SERVER_SUCCESS;
        res = ValidateTableName(table_name_);
        if(res != SERVER_SUCCESS) {
292
            return SetError(res, "Invalid table name: " + table_name_);
G
groot 已提交
293
        }
294

G
groot 已提交
295 296
        //step 2: check table existence
        engine::Status stat = DBWrapper::DB()->HasTable(table_name_, has_table_);
G
groot 已提交
297 298 299
        if(!stat.ok()) {
            return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
        }
G
groot 已提交
300 301 302

        rc.Elapse("totally cost");
    } catch (std::exception& ex) {
G
groot 已提交
303
        return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
G
groot 已提交
304 305 306 307 308
    }

    return SERVER_SUCCESS;
}

G
groot 已提交
309 310 311 312 313 314 315
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DeleteTableTask::DeleteTableTask(const std::string& table_name)
    : BaseTask(DDL_DML_TASK_GROUP),
      table_name_(table_name) {

}

G
groot 已提交
316 317
BaseTaskPtr DeleteTableTask::Create(const std::string& table_name) {
    return std::shared_ptr<BaseTask>(new DeleteTableTask(table_name));
G
groot 已提交
318 319 320
}

ServerError DeleteTableTask::OnExecute() {
G
groot 已提交
321 322 323
    try {
        TimeRecorder rc("DeleteTableTask");

G
groot 已提交
324
        //step 1: check arguments
J
jinhai 已提交
325 326 327
        ServerError res = SERVER_SUCCESS;
        res = ValidateTableName(table_name_);
        if(res != SERVER_SUCCESS) {
328
            return SetError(res, "Invalid table name: " + table_name_);
G
groot 已提交
329 330 331 332 333
        }

        //step 2: check table existence
        engine::meta::TableSchema table_info;
        table_info.table_id_ = table_name_;
G
groot 已提交
334
        engine::Status stat = DBWrapper::DB()->DescribeTable(table_info);
G
groot 已提交
335
        if(!stat.ok()) {
G
groot 已提交
336 337 338 339 340
            if(stat.IsNotFound()) {
                return SetError(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists");
            } else {
                return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
            }
G
groot 已提交
341 342 343
        }

        rc.Record("check validation");
G
groot 已提交
344

G
groot 已提交
345 346
        //step 3: delete table
        std::vector<DB_DATE> dates;
G
groot 已提交
347
        stat = DBWrapper::DB()->DeleteTable(table_name_, dates);
G
groot 已提交
348
        if(!stat.ok()) {
G
groot 已提交
349
            return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
G
groot 已提交
350 351 352
        }

        rc.Record("deleta table");
Z
fix  
zhiru 已提交
353
        rc.Elapse("total cost");
G
groot 已提交
354
    } catch (std::exception& ex) {
G
groot 已提交
355
        return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
G
groot 已提交
356 357 358
    }

    return SERVER_SUCCESS;
G
groot 已提交
359 360
}

G
groot 已提交
361 362
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
ShowTablesTask::ShowTablesTask(std::vector<std::string>& tables)
G
groot 已提交
363
    : BaseTask(DDL_DML_TASK_GROUP),
G
groot 已提交
364 365 366 367 368 369 370 371 372
      tables_(tables) {

}

BaseTaskPtr ShowTablesTask::Create(std::vector<std::string>& tables) {
    return std::shared_ptr<BaseTask>(new ShowTablesTask(tables));
}

ServerError ShowTablesTask::OnExecute() {
G
groot 已提交
373
    std::vector<engine::meta::TableSchema> schema_array;
G
groot 已提交
374
    engine::Status stat = DBWrapper::DB()->AllTables(schema_array);
G
groot 已提交
375
    if(!stat.ok()) {
G
groot 已提交
376
        return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
G
groot 已提交
377 378 379 380 381 382
    }

    tables_.clear();
    for(auto& schema : schema_array) {
        tables_.push_back(schema.table_id_);
    }
G
groot 已提交
383 384 385

    return SERVER_SUCCESS;
}
G
groot 已提交
386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
AddVectorTask::AddVectorTask(const std::string& table_name,
                                       const std::vector<thrift::RowRecord>& record_array,
                                       std::vector<int64_t>& record_ids)
    : BaseTask(DDL_DML_TASK_GROUP),
      table_name_(table_name),
      record_array_(record_array),
      record_ids_(record_ids) {
    record_ids_.clear();
}

BaseTaskPtr AddVectorTask::Create(const std::string& table_name,
                                       const std::vector<thrift::RowRecord>& record_array,
                                       std::vector<int64_t>& record_ids) {
    return std::shared_ptr<BaseTask>(new AddVectorTask(table_name, record_array, record_ids));
}

ServerError AddVectorTask::OnExecute() {
    try {
        TimeRecorder rc("AddVectorTask");

G
groot 已提交
408
        //step 1: check arguments
J
jinhai 已提交
409 410 411
        ServerError res = SERVER_SUCCESS;
        res = ValidateTableName(table_name_);
        if(res != SERVER_SUCCESS) {
412
            return SetError(res, "Invalid table name: " + table_name_);
G
groot 已提交
413 414
        }

G
groot 已提交
415
        if(record_array_.empty()) {
G
groot 已提交
416
            return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty");
G
groot 已提交
417 418
        }

G
groot 已提交
419
        //step 2: check table existence
G
groot 已提交
420
        engine::meta::TableSchema table_info;
G
groot 已提交
421
        table_info.table_id_ = table_name_;
G
groot 已提交
422
        engine::Status stat = DBWrapper::DB()->DescribeTable(table_info);
G
groot 已提交
423
        if(!stat.ok()) {
G
groot 已提交
424 425 426 427 428
            if(stat.IsNotFound()) {
                return SetError(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists");
            } else {
                return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
            }
G
groot 已提交
429 430
        }

G
groot 已提交
431
        rc.Record("check validation");
G
groot 已提交
432

G
groot 已提交
433
        //step 3: prepare float data
G
groot 已提交
434
        std::vector<float> vec_f;
G
groot 已提交
435 436 437 438 439
        ServerError error_code = SERVER_SUCCESS;
        std::string error_msg;
        ConvertRowRecordToFloatArray(record_array_, table_info.dimension_, vec_f, error_code, error_msg);
        if(error_code != SERVER_SUCCESS) {
            return SetError(error_code, error_msg);
G
groot 已提交
440 441 442 443
        }

        rc.Record("prepare vectors data");

G
groot 已提交
444
        //step 4: insert vectors
G
groot 已提交
445
        uint64_t vec_count = (uint64_t)record_array_.size();
G
groot 已提交
446
        stat = DBWrapper::DB()->InsertVectors(table_name_, vec_count, vec_f.data(), record_ids_);
G
groot 已提交
447 448
        rc.Record("add vectors to engine");
        if(!stat.ok()) {
G
groot 已提交
449
            return SetError(SERVER_CACHE_ERROR, "Cache error: " + stat.ToString());
G
groot 已提交
450 451
        }

G
groot 已提交
452
        if(record_ids_.size() != vec_count) {
G
groot 已提交
453 454 455
            std::string msg = "Add " + std::to_string(vec_count) + " vectors but only return "
                    + std::to_string(record_ids_.size()) + " id";
            return SetError(SERVER_ILLEGAL_VECTOR_ID, msg);
G
groot 已提交
456 457
        }

G
groot 已提交
458
        rc.Record("do insert");
Z
fix  
zhiru 已提交
459
        rc.Elapse("total cost");
G
groot 已提交
460 461

    } catch (std::exception& ex) {
G
groot 已提交
462
        return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
G
groot 已提交
463 464 465 466 467 468
    }

    return SERVER_SUCCESS;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
469 470 471 472 473
SearchVectorTaskBase::SearchVectorTaskBase(const std::string &table_name,
        const std::vector<std::string>& file_id_array,
        const std::vector<thrift::RowRecord> &query_record_array,
        const std::vector<thrift::Range> &query_range_array,
        const int64_t top_k)
G
groot 已提交
474 475 476 477 478
    : BaseTask(DQL_TASK_GROUP),
      table_name_(table_name),
      file_id_array_(file_id_array),
      record_array_(query_record_array),
      range_array_(query_range_array),
479
      top_k_(top_k) {
G
groot 已提交
480 481 482

}

483
ServerError SearchVectorTaskBase::OnExecute() {
G
groot 已提交
484 485 486
    try {
        TimeRecorder rc("SearchVectorTask");

G
groot 已提交
487
        //step 1: check arguments
J
jinhai 已提交
488 489 490
        ServerError res = SERVER_SUCCESS;
        res = ValidateTableName(table_name_);
        if(res != SERVER_SUCCESS) {
491
            return SetError(res, "Invalid table name: " + table_name_);
G
groot 已提交
492 493
        }

G
groot 已提交
494 495 496 497 498
        if(top_k_ <= 0) {
            return SetError(SERVER_INVALID_TOPK, "Invalid topk: " + std::to_string(top_k_));
        }
        if(record_array_.empty()) {
            return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty");
G
groot 已提交
499 500
        }

G
groot 已提交
501
        //step 2: check table existence
G
groot 已提交
502
        engine::meta::TableSchema table_info;
G
groot 已提交
503
        table_info.table_id_ = table_name_;
G
groot 已提交
504
        engine::Status stat = DBWrapper::DB()->DescribeTable(table_info);
G
groot 已提交
505
        if(!stat.ok()) {
G
groot 已提交
506 507 508 509 510
            if(stat.IsNotFound()) {
                return SetError(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists");
            } else {
                return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
            }
G
groot 已提交
511 512
        }

G
groot 已提交
513 514
        //step 3: check date range, and convert to db dates
        std::vector<DB_DATE> dates;
G
groot 已提交
515 516 517 518 519
        ServerError error_code = SERVER_SUCCESS;
        std::string error_msg;
        ConvertTimeRangeToDBDates(range_array_, dates, error_code, error_msg);
        if(error_code != SERVER_SUCCESS) {
            return SetError(error_code, error_msg);
G
groot 已提交
520 521
        }

G
groot 已提交
522 523 524
        rc.Record("check validation");

        //step 3: prepare float data
G
groot 已提交
525
        std::vector<float> vec_f;
G
groot 已提交
526 527 528
        ConvertRowRecordToFloatArray(record_array_, table_info.dimension_, vec_f, error_code, error_msg);
        if(error_code != SERVER_SUCCESS) {
            return SetError(error_code, error_msg);
G
groot 已提交
529 530 531 532
        }

        rc.Record("prepare vector data");

G
groot 已提交
533
        //step 4: search vectors
G
groot 已提交
534
        engine::QueryResults results;
G
groot 已提交
535
        uint64_t record_count = (uint64_t)record_array_.size();
536 537

        if(file_id_array_.empty()) {
G
groot 已提交
538
            stat = DBWrapper::DB()->Query(table_name_, (size_t) top_k_, record_count, vec_f.data(), dates, results);
539
        } else {
G
groot 已提交
540
            stat = DBWrapper::DB()->Query(table_name_, file_id_array_, (size_t) top_k_, record_count, vec_f.data(), dates, results);
541 542
        }

G
groot 已提交
543
        rc.Record("search vectors from engine");
G
groot 已提交
544
        if(!stat.ok()) {
G
groot 已提交
545 546 547 548 549
            return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
        }

        if(results.empty()) {
            return SERVER_SUCCESS; //empty table
G
groot 已提交
550 551 552
        }

        if(results.size() != record_count) {
G
groot 已提交
553 554 555
            std::string msg = "Search " + std::to_string(record_count) + " vectors but only return "
                              + std::to_string(results.size()) + " results";
            return SetError(SERVER_ILLEGAL_SEARCH_RESULT, msg);
G
groot 已提交
556 557
        }

G
groot 已提交
558 559 560
        rc.Record("do search");

        //step 5: construct result array
561 562 563
        ConstructResult(results);
        rc.Record("construct result");
        rc.Elapse("total cost");
G
groot 已提交
564

565 566 567
    } catch (std::exception& ex) {
        return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
    }
G
groot 已提交
568

569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606
    return SERVER_SUCCESS;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
SearchVectorTask1::SearchVectorTask1(const std::string &table_name,
                                     const std::vector<std::string>& file_id_array,
                                     const std::vector<thrift::RowRecord> &query_record_array,
                                     const std::vector<thrift::Range> &query_range_array,
                                     const int64_t top_k,
                                     std::vector<thrift::TopKQueryResult> &result_array)
        : SearchVectorTaskBase(table_name, file_id_array, query_record_array, query_range_array, top_k),
          result_array_(result_array) {

}

BaseTaskPtr SearchVectorTask1::Create(const std::string& table_name,
                                      const std::vector<std::string>& file_id_array,
                                      const std::vector<thrift::RowRecord> & query_record_array,
                                      const std::vector<thrift::Range> & query_range_array,
                                      const int64_t top_k,
                                      std::vector<thrift::TopKQueryResult>& result_array) {
    return std::shared_ptr<BaseTask>(new SearchVectorTask1(table_name, file_id_array,
                                                           query_record_array, query_range_array, top_k, result_array));
}

ServerError SearchVectorTask1::ConstructResult(engine::QueryResults& results) {
    for(uint64_t i = 0; i < results.size(); i++) {
        auto& result = results[i];
        const auto& record = record_array_[i];

        thrift::TopKQueryResult thrift_topk_result;
        for(auto& pair : result) {
            thrift::QueryResult thrift_result;
            thrift_result.__set_id(pair.first);
            thrift_result.__set_distance(pair.second);

            thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
        }
G
groot 已提交
607

608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641
        result_array_.emplace_back(thrift_topk_result);
    }

    return SERVER_SUCCESS;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
SearchVectorTask2::SearchVectorTask2(const std::string &table_name,
                                     const std::vector<std::string>& file_id_array,
                                     const std::vector<thrift::RowRecord> &query_record_array,
                                     const std::vector<thrift::Range> &query_range_array,
                                     const int64_t top_k,
                                     std::vector<thrift::TopKQueryBinResult> &result_array)
    : SearchVectorTaskBase(table_name, file_id_array, query_record_array, query_range_array, top_k),
      result_array_(result_array) {

}

BaseTaskPtr SearchVectorTask2::Create(const std::string& table_name,
                                     const std::vector<std::string>& file_id_array,
                                     const std::vector<thrift::RowRecord> & query_record_array,
                                     const std::vector<thrift::Range> & query_range_array,
                                     const int64_t top_k,
                                     std::vector<thrift::TopKQueryBinResult>& result_array) {
    return std::shared_ptr<BaseTask>(new SearchVectorTask2(table_name, file_id_array,
            query_record_array, query_range_array, top_k, result_array));
}

ServerError SearchVectorTask2::ConstructResult(engine::QueryResults& results) {
    for(size_t i = 0; i < results.size(); i++) {
        auto& result = results[i];

        thrift::TopKQueryBinResult thrift_topk_result;
        if(result.empty()) {
G
groot 已提交
642
            result_array_.emplace_back(thrift_topk_result);
643
            continue;
G
groot 已提交
644
        }
645

646 647 648 649 650 651 652 653 654 655 656 657 658 659 660
        std::string str_ids, str_distances;
        str_ids.resize(sizeof(engine::IDNumber)*result.size());
        str_distances.resize(sizeof(double)*result.size());

        engine::IDNumber* ids_ptr = (engine::IDNumber*)str_ids.data();
        double* distance_ptr = (double*)str_distances.data();
        for(size_t k = 0; k < results.size(); k++) {
            auto& pair = result[k];
            ids_ptr[k] = pair.first;
            distance_ptr[k] = pair.second;
        }

        thrift_topk_result.__set_id_array(str_ids);
        thrift_topk_result.__set_distance_array(str_distances);
        result_array_.emplace_back(thrift_topk_result);
G
groot 已提交
661 662 663 664 665
    }

    return SERVER_SUCCESS;
}

G
groot 已提交
666 667
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
GetTableRowCountTask::GetTableRowCountTask(const std::string& table_name, int64_t& row_count)
G
groot 已提交
668
: BaseTask(DDL_DML_TASK_GROUP),
G
groot 已提交
669 670 671 672 673 674 675 676 677 678
  table_name_(table_name),
  row_count_(row_count) {

}

BaseTaskPtr GetTableRowCountTask::Create(const std::string& table_name, int64_t& row_count) {
    return std::shared_ptr<BaseTask>(new GetTableRowCountTask(table_name, row_count));
}

ServerError GetTableRowCountTask::OnExecute() {
G
groot 已提交
679 680 681
    try {
        TimeRecorder rc("GetTableRowCountTask");

G
groot 已提交
682
        //step 1: check arguments
J
jinhai 已提交
683 684 685
        ServerError res = SERVER_SUCCESS;
        res = ValidateTableName(table_name_);
        if(res != SERVER_SUCCESS) {
686
            return SetError(res, "Invalid table name: " + table_name_);
G
groot 已提交
687 688 689 690
        }

        //step 2: get row count
        uint64_t row_count = 0;
G
groot 已提交
691
        engine::Status stat = DBWrapper::DB()->GetTableRowCount(table_name_, row_count);
G
groot 已提交
692
        if (!stat.ok()) {
G
groot 已提交
693
            return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
G
groot 已提交
694 695 696 697
        }

        row_count_ = (int64_t) row_count;

Z
fix  
zhiru 已提交
698
        rc.Elapse("total cost");
G
groot 已提交
699 700

    } catch (std::exception& ex) {
G
groot 已提交
701
        return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
G
groot 已提交
702 703
    }

G
groot 已提交
704
    return SERVER_SUCCESS;
G
groot 已提交
705 706
}

G
groot 已提交
707 708 709 710 711 712 713 714 715 716 717 718 719 720
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
PingTask::PingTask(const std::string& cmd, std::string& result)
    : BaseTask(PING_TASK_GROUP),
      cmd_(cmd),
      result_(result) {

}

BaseTaskPtr PingTask::Create(const std::string& cmd, std::string& result) {
    return std::shared_ptr<BaseTask>(new PingTask(cmd, result));
}

ServerError PingTask::OnExecute() {
    if(cmd_ == "version") {
G
groot 已提交
721
        result_ = MILVUS_VERSION;
G
groot 已提交
722 723 724 725 726
    }

    return SERVER_SUCCESS;
}

G
groot 已提交
727 728 729
}
}
}