MegasearchTask.cpp 16.8 KB
Newer Older
G
groot 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
/*******************************************************************************
 * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited.
 * Proprietary and confidential.
 ******************************************************************************/
#include "MegasearchTask.h"
#include "ServerConfig.h"
#include "VecIdMapper.h"
#include "utils/CommonUtil.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include "utils/ThreadPool.h"
#include "db/DB.h"
#include "db/Env.h"
#include "db/Meta.h"


namespace zilliz {
namespace vecwise {
namespace server {

// Task-group names handed to the BaseTask constructor by the task classes in this file.
// "dql" is used by SearchVectorTask.
static const std::string DQL_TASK_GROUP = "dql";
// "ddl_dml" is used by the create/delete table, create/delete partition and add-vector tasks.
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
G
groot 已提交
24
// Task-group name used by DescribeTableTask, ShowTablesTask and PingTask.
static const std::string PING_TASK_GROUP = "ping";
G
groot 已提交
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51

// NOTE(review): VECTOR_UID and USE_MT are not referenced by the code visible in this file;
// confirm whether they are still needed.
static const std::string VECTOR_UID = "uid";
static const uint64_t USE_MT = 5000;

// Short aliases for engine meta types. DB_DATE is the element type of the date list
// passed to DB()->Query() by SearchVectorTask; DB_META is not referenced in the visible code.
using DB_META = zilliz::vecwise::engine::meta::Meta;
using DB_DATE = zilliz::vecwise::engine::meta::DateT;

namespace {
    // Opens the engine DB from the server configuration and owns the resulting handle.
    // The handle is released with delete in the destructor, so the wrapper must never be
    // copied (a copy would delete the same pointer twice); copy operations are deleted.
    class DBWrapper {
    public:
        // Reads the CONFIG_DB section, creates "<db_path>/db" and opens the engine DB.
        // Throws ServerException(SERVER_NULL_POINTER) when DB::Open leaves the handle null.
        DBWrapper() {
            zilliz::vecwise::engine::Options opt;
            ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_DB);
            opt.meta.backend_uri = config.GetValue(CONFIG_DB_URL);
            const std::string db_path = config.GetValue(CONFIG_DB_PATH);
            // 10 is the fallback passed to GetInt32Value for the flush interval setting.
            opt.memory_sync_interval = static_cast<uint16_t>(config.GetInt32Value(CONFIG_DB_FLUSH_INTERVAL, 10));
            opt.meta.path = db_path + "/db";

            CommonUtil::CreateDirectory(opt.meta.path);

            zilliz::vecwise::engine::DB::Open(opt, &db_);
            if(db_ == nullptr) {
                SERVER_LOG_ERROR << "Failed to open db";
                throw ServerException(SERVER_NULL_POINTER, "Failed to open db");
            }
        }

        ~DBWrapper() {
            delete db_;
        }

        // Sole owner of db_: forbid copies to rule out a double delete.
        DBWrapper(const DBWrapper&) = delete;
        DBWrapper& operator=(const DBWrapper&) = delete;

        // Non-owning access to the opened DB.
        zilliz::vecwise::engine::DB* DB() { return db_; }

    private:
        zilliz::vecwise::engine::DB* db_ = nullptr;
    };

    // Returns the DB handle held by a function-local static DBWrapper,
    // which is constructed on the first call.
    zilliz::vecwise::engine::DB* DB() {
        static DBWrapper db_wrapper;
        return db_wrapper.DB();
    }

    // Returns a function-local static ThreadPool constructed with the argument 6.
    ThreadPool& GetThreadPool() {
        static ThreadPool pool(6);
        return pool;
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Stores the schema of the table to create and passes DDL_DML_TASK_GROUP to BaseTask.
CreateTableTask::CreateTableTask(const thrift::TableSchema& schema)
    : BaseTask(DDL_DML_TASK_GROUP), schema_(schema) {
}

// Factory: wraps a new CreateTableTask in a shared BaseTask pointer.
BaseTaskPtr CreateTableTask::Create(const thrift::TableSchema& schema) {
    std::shared_ptr<BaseTask> task(new CreateTableTask(schema));
    return task;
}

// Registers the table name with the ID mapper and creates the table in the engine,
// using the dimension of the first vector column.
//
// Returns SERVER_INVALID_ARGUMENT when the schema has no vector column,
// SERVER_UNEXPECTED_ERROR when an exception is caught, SERVER_SUCCESS otherwise.
// An engine failure from CreateTable is logged and recorded in error_msg_ but still
// reported as success (the table could already exist).
ServerError CreateTableTask::OnExecute() {
    TimeRecorder rc("CreateTableTask");

    try {
        if(schema_.vector_column_array.empty()) {
            // Record and log the failure like the other error paths in this file do.
            error_code_ = SERVER_INVALID_ARGUMENT;
            error_msg_ = "Table schema has no vector column";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        IVecIdMapper::GetInstance()->AddGroup(schema_.table_name);

        engine::meta::TableSchema table_info;
        table_info.dimension = static_cast<uint16_t>(schema_.vector_column_array[0].dimension);
        table_info.table_id = schema_.table_name;
        engine::Status stat = DB()->CreateTable(table_info);
        if(!stat.ok()) {//could exist
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return SERVER_SUCCESS;
        }

    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return SERVER_UNEXPECTED_ERROR;
    }

    rc.Record("done");

    return SERVER_SUCCESS;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Keeps the table name and the caller's schema object, and writes the table name
// into that schema right away. Passes PING_TASK_GROUP to BaseTask.
DescribeTableTask::DescribeTableTask(const std::string &table_name, thrift::TableSchema &schema)
    : BaseTask(PING_TASK_GROUP), table_name_(table_name), schema_(schema) {
    schema_.table_name = table_name_;
}

// Factory: wraps a new DescribeTableTask in a shared BaseTask pointer.
BaseTaskPtr DescribeTableTask::Create(const std::string& table_name, thrift::TableSchema& schema) {
    std::shared_ptr<BaseTask> task(new DescribeTableTask(table_name, schema));
    return task;
}

// Asks the engine to describe the table named table_name_.
// Returns SERVER_GROUP_NOT_EXIST when the engine reports failure,
// SERVER_UNEXPECTED_ERROR when an exception is caught, SERVER_SUCCESS otherwise.
// NOTE(review): on success nothing from table_info is copied into schema_ — confirm
// whether the caller expects more than the table name set in the constructor.
ServerError DescribeTableTask::OnExecute() {
    TimeRecorder rc("DescribeTableTask");

    try {
        engine::meta::TableSchema table_info;
        table_info.table_id = table_name_;

        const engine::Status stat = DB()->DescribeTable(table_info);
        const bool described = stat.ok();
        if(!described) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }
    } catch (std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return SERVER_UNEXPECTED_ERROR;
    }

    rc.Record("done");
    return SERVER_SUCCESS;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Stores the name of the table to delete and passes DDL_DML_TASK_GROUP to BaseTask.
DeleteTableTask::DeleteTableTask(const std::string& table_name)
    : BaseTask(DDL_DML_TASK_GROUP), table_name_(table_name) {
}

G
groot 已提交
162 163
// Factory: wraps a new DeleteTableTask in a shared BaseTask pointer.
BaseTaskPtr DeleteTableTask::Create(const std::string& group_id) {
    std::shared_ptr<BaseTask> task(new DeleteTableTask(group_id));
    return task;
}

// Reports that table deletion is not implemented, then removes the table's group
// from the ID mapper. Always returns SERVER_NOT_IMPLEMENT.
// NOTE(review): the mapper group is deleted even though the task reports
// "not implemented" — confirm this is intended.
ServerError DeleteTableTask::OnExecute() {
    error_msg_ = "delete table not implemented";
    error_code_ = SERVER_NOT_IMPLEMENT;
    SERVER_LOG_ERROR << error_msg_;

    auto&& mapper = IVecIdMapper::GetInstance();
    mapper->DeleteGroup(table_name_);

    return SERVER_NOT_IMPLEMENT;
}

G
groot 已提交
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Stores the partition-creation parameters and passes DDL_DML_TASK_GROUP to BaseTask.
CreateTablePartitionTask::CreateTablePartitionTask(const thrift::CreateTablePartitionParam &param)
    : BaseTask(DDL_DML_TASK_GROUP), param_(param) {
}

// Factory: wraps a new CreateTablePartitionTask in a shared BaseTask pointer.
BaseTaskPtr CreateTablePartitionTask::Create(const thrift::CreateTablePartitionParam &param) {
    std::shared_ptr<BaseTask> task(new CreateTablePartitionTask(param));
    return task;
}

// Partition creation is not implemented: records and logs that fact and
// always returns SERVER_NOT_IMPLEMENT.
ServerError CreateTablePartitionTask::OnExecute() {
    error_msg_ = "create table partition not implemented";
    error_code_ = SERVER_NOT_IMPLEMENT;
    SERVER_LOG_ERROR << error_msg_;
    return SERVER_NOT_IMPLEMENT;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Stores the partition-deletion parameters and passes DDL_DML_TASK_GROUP to BaseTask.
DeleteTablePartitionTask::DeleteTablePartitionTask(const thrift::DeleteTablePartitionParam &param)
    : BaseTask(DDL_DML_TASK_GROUP), param_(param) {
}

// Factory: wraps a new DeleteTablePartitionTask in a shared BaseTask pointer.
BaseTaskPtr DeleteTablePartitionTask::Create(const thrift::DeleteTablePartitionParam &param) {
    std::shared_ptr<BaseTask> task(new DeleteTablePartitionTask(param));
    return task;
}

// Partition deletion is not implemented: records and logs that fact and
// always returns SERVER_NOT_IMPLEMENT.
ServerError DeleteTablePartitionTask::OnExecute() {
    error_msg_ = "delete table partition not implemented";
    error_code_ = SERVER_NOT_IMPLEMENT;
    SERVER_LOG_ERROR << error_msg_;
    return SERVER_NOT_IMPLEMENT;
}

G
groot 已提交
214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Keeps the caller's output list of table names and passes PING_TASK_GROUP to BaseTask.
ShowTablesTask::ShowTablesTask(std::vector<std::string>& tables)
    : BaseTask(PING_TASK_GROUP), tables_(tables) {
}

// Factory: wraps a new ShowTablesTask in a shared BaseTask pointer.
BaseTaskPtr ShowTablesTask::Create(std::vector<std::string>& tables) {
    std::shared_ptr<BaseTask> task(new ShowTablesTask(tables));
    return task;
}

// Asks the ID mapper for all of its groups, storing them in tables_.
// Always returns SERVER_SUCCESS.
ServerError ShowTablesTask::OnExecute() {
    auto&& mapper = IVecIdMapper::GetInstance();
    mapper->AllGroups(tables_);
    return SERVER_SUCCESS;
}
G
groot 已提交
230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Stores the target table name, the records to insert and the caller's output ID list.
// The output ID list is emptied here, before the task runs.
// Passes DDL_DML_TASK_GROUP to BaseTask.
AddVectorTask::AddVectorTask(const std::string& table_name,
                             const std::vector<thrift::RowRecord>& record_array,
                             std::vector<int64_t>& record_ids)
    : BaseTask(DDL_DML_TASK_GROUP), table_name_(table_name), record_array_(record_array), record_ids_(record_ids) {
    record_ids_.clear();
}

// Factory: wraps a new AddVectorTask in a shared BaseTask pointer.
BaseTaskPtr AddVectorTask::Create(const std::string& table_name,
                                  const std::vector<thrift::RowRecord>& record_array,
                                  std::vector<int64_t>& record_ids) {
    std::shared_ptr<BaseTask> task(new AddVectorTask(table_name, record_array, record_ids));
    return task;
}

// Inserts the vectors of record_array_ into the table named table_name_ and stores
// each record's non-empty attribute map in the ID mapper under the returned vector ID.
//
// Each record's first vector_map entry is read as a packed array of doubles and
// narrowed to float before being handed to the engine.
//
// Returns:
//   SERVER_SUCCESS                  - nothing to insert, or everything was inserted
//   SERVER_GROUP_NOT_EXIST          - the engine could not describe the table
//   SERVER_INVALID_ARGUMENT         - a record carries no vector
//   SERVER_INVALID_VECTOR_DIMENSION - a vector's dimension differs from the table's
//   SERVER_UNEXPECTED_ERROR         - engine insert failed, the ID count is wrong,
//                                     or an exception was caught
// Every failure path sets error_code_ and error_msg_ and logs the message.
ServerError AddVectorTask::OnExecute() {
    try {
        TimeRecorder rc("AddVectorTask");

        if(record_array_.empty()) {
            return SERVER_SUCCESS;
        }

        engine::meta::TableSchema table_info;
        table_info.table_id = table_name_;
        engine::Status stat = DB()->DescribeTable(table_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("get group info");

        const uint64_t vec_count = static_cast<uint64_t>(record_array_.size());
        const uint64_t group_dim = table_info.dimension;
        std::vector<float> vec_f;
        vec_f.resize(vec_count*group_dim);//allocate enough memory
        for(uint64_t i = 0; i < vec_count; i++) {
            const auto& record = record_array_[i];
            if(record.vector_map.empty()) {
                error_code_ = SERVER_INVALID_ARGUMENT;
                error_msg_ = "No vector provided in record";
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            const auto& raw_vector = record.vector_map.begin()->second;
            const uint64_t vec_dim = raw_vector.size()/sizeof(double);//how many double value?
            if(vec_dim != group_dim) {
                // The engine status is OK at this point, so describe the actual
                // problem rather than echoing the status text.
                error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
                error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                             + " vs. group dimension:" + std::to_string(group_dim);
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            //convert double array to float array(thrift has no float type)
            const double* d_p = reinterpret_cast<const double*>(raw_vector.data());
            for(uint64_t d = 0; d < vec_dim; d++) {
                vec_f[i*group_dim + d] = static_cast<float>(d_p[d]);
            }
        }

        rc.Record("prepare vectors data");

        stat = DB()->InsertVectors(table_name_, vec_count, vec_f.data(), record_ids_);
        rc.Record("add vectors to engine");
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        if(record_ids_.size() != vec_count) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Vector ID not returned";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        //persist attributes
        for(uint64_t i = 0; i < vec_count; i++) {
            const auto &record = record_array_[i];

            //any attributes?
            if(record.attribute_map.empty()) {
                continue;
            }

            const std::string nid = std::to_string(record_ids_[i]);
            std::string attrib_str;
            AttributeSerializer::Encode(record.attribute_map, attrib_str);
            IVecIdMapper::GetInstance()->Put(nid, attrib_str, table_name_);
        }

        rc.Record("persist vector attributes");

    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Stores the table name, the top-k value, the query records and the caller's
// output result list. Passes DQL_TASK_GROUP to BaseTask.
SearchVectorTask::SearchVectorTask(const std::string& table_name,
                                   const int64_t top_k,
                                   const std::vector<thrift::QueryRecord>& record_array,
                                   std::vector<thrift::TopKQueryResult>& result_array)
    : BaseTask(DQL_TASK_GROUP), table_name_(table_name), top_k_(top_k), record_array_(record_array), result_array_(result_array) {
}

// Factory: wraps a new SearchVectorTask in a shared BaseTask pointer.
// Note the argument order differs from the constructor: record_array comes before top_k here.
BaseTaskPtr SearchVectorTask::Create(const std::string& table_name,
                                     const std::vector<thrift::QueryRecord>& record_array,
                                     const int64_t top_k,
                                     std::vector<thrift::TopKQueryResult>& result_array) {
    std::shared_ptr<BaseTask> task(new SearchVectorTask(table_name, top_k, record_array, result_array));
    return task;
}

// Runs a top-k query for every record in record_array_ against the table named
// table_name_ and appends one thrift::TopKQueryResult per query record to result_array_.
//
// Each record's first vector_map entry is read as a packed array of doubles and
// narrowed to float before being handed to the engine. When a record lists selected
// columns, the attributes stored in the ID mapper for each returned ID are decoded
// and the selected ones are copied into the result's column_map.
//
// Returns:
//   SERVER_SUCCESS                  - all queries answered
//   SERVER_INVALID_ARGUMENT         - top_k_ <= 0, no query records, or a record has no vector
//   SERVER_GROUP_NOT_EXIST          - the engine could not describe the table
//   SERVER_INVALID_VECTOR_DIMENSION - a vector's dimension differs from the table's
//   SERVER_UNEXPECTED_ERROR         - engine query failed, the result count is wrong,
//                                     or an exception was caught
// Every failure path sets error_code_ and error_msg_ and logs the message.
ServerError SearchVectorTask::OnExecute() {
    try {
        TimeRecorder rc("SearchVectorTask");

        if(top_k_ <= 0 || record_array_.empty()) {
            error_code_ = SERVER_INVALID_ARGUMENT;
            error_msg_ = "Invalid topk value, or query record array is empty";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        engine::meta::TableSchema table_info;
        table_info.table_id = table_name_;
        engine::Status stat = DB()->DescribeTable(table_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        const uint64_t record_count = static_cast<uint64_t>(record_array_.size());
        const uint64_t table_dim = table_info.dimension;
        std::vector<float> vec_f;
        vec_f.resize(record_count*table_dim);

        for(uint64_t i = 0; i < record_count; i++) {
            const auto& record = record_array_[i];
            if (record.vector_map.empty()) {
                error_code_ = SERVER_INVALID_ARGUMENT;
                error_msg_ = "Query record has no vector";
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            const auto& raw_vector = record.vector_map.begin()->second;
            const uint64_t vec_dim = raw_vector.size() / sizeof(double);//how many double value?
            if (vec_dim != table_dim) {
                // The engine status is OK at this point, so describe the actual
                // problem rather than echoing the status text.
                error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
                error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                             + " vs. group dimension:" + std::to_string(table_dim);
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            //convert double array to float array(thrift has no float type)
            const double* d_p = reinterpret_cast<const double*>(raw_vector.data());
            for(uint64_t d = 0; d < vec_dim; d++) {
                vec_f[i*table_dim + d] = static_cast<float>(d_p[d]);
            }
        }

        rc.Record("prepare vector data");

        std::vector<DB_DATE> dates;
        engine::QueryResults results;
        stat = DB()->Query(table_name_, static_cast<size_t>(top_k_), record_count, vec_f.data(), dates, results);
        rc.Record("search vectors from engine");
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        if(results.size() != record_count) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Search result not returned";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        //construct result array
        for(uint64_t i = 0; i < record_count; i++) {
            auto& result = results[i];
            const auto& record = record_array_[i];

            thrift::TopKQueryResult thrift_topk_result;
            for(auto id : result) {
                thrift::QueryResult thrift_result;
                thrift_result.__set_id(id);

                //need get attributes?
                if(record.selected_column_array.empty()) {
                    thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
                    continue;
                }

                const std::string nid = std::to_string(id);
                std::string attrib_str;
                IVecIdMapper::GetInstance()->Get(nid, attrib_str, table_name_);

                AttribMap attrib_map;
                AttributeSerializer::Decode(attrib_str, attrib_map);

                for(auto& attribute : record.selected_column_array) {
                    thrift_result.column_map[attribute] = attrib_map[attribute];
                }

                thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
            }

            result_array_.emplace_back(thrift_topk_result);
        }
        rc.Record("construct result");
        rc.Elapse("totally cost");
    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}

G
groot 已提交
469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Stores the command string and the caller's output string.
// Passes PING_TASK_GROUP to BaseTask.
PingTask::PingTask(const std::string& cmd, std::string& result)
    : BaseTask(PING_TASK_GROUP), cmd_(cmd), result_(result) {
}

// Factory: wraps a new PingTask in a shared BaseTask pointer.
BaseTaskPtr PingTask::Create(const std::string& cmd, std::string& result) {
    std::shared_ptr<BaseTask> task(new PingTask(cmd, result));
    return task;
}

// Answers the "version" command with a fixed version string; any other command
// leaves result_ untouched. Always returns SERVER_SUCCESS.
ServerError PingTask::OnExecute() {
    const bool wants_version = (cmd_ == "version");
    if(wants_version) {
        result_ = "v1.2.0";//currently hardcode
    }
    return SERVER_SUCCESS;
}

G
groot 已提交
489 490 491
}
}
}