ClientProxy.cpp 9.9 KB
Newer Older
G
groot 已提交
1 2 3 4 5 6 7
/*******************************************************************************
 * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited.
 * Proprietary and confidential.
 ******************************************************************************/
#include "ClientProxy.h"

#include <cstdlib>
#include <cstring>
#include <utility>

G
groot 已提交
8
namespace milvus {
G
groot 已提交
9 10 11 12 13 14 15 16 17

// Returns the Thrift client held by this proxy, creating it on first access.
// The assignment inside a const member function implies client_ptr is declared
// mutable in the header. The returned reference aliases that member.
std::shared_ptr<ThriftClient>&
ClientProxy::ClientPtr() const {
    if(!client_ptr) {
        // Lazy construction: no client exists until something asks for one.
        client_ptr = std::make_shared<ThriftClient>();
    }
    return client_ptr;
}

G
groot 已提交
18 19 20 21
// True only when a client object exists AND the last Connect() succeeded
// without a subsequent Disconnect().
bool ClientProxy::IsConnected() const {
    if(client_ptr == nullptr) {
        return false;
    }
    return connected_;
}

G
groot 已提交
22 23 24 25 26
// Opens a connection to param.ip_address:param.port using the binary Thrift
// protocol. Any existing connection is closed first.
//
// Returns Status::Invalid("Invalid port") when param.port is not a decimal
// integer in [1, 65535]. Otherwise it returns whatever ThriftClient::Connect
// reports. connected_ is set only on success.
Status
ClientProxy::Connect(const ConnectParam &param) {
    Disconnect();

    // Parse the port strictly. atoi() silently maps malformed input ("", "abc",
    // "80x") to 0 or to a numeric prefix, which would only surface later as a
    // confusing connection failure. strtol() saturates on overflow, so the
    // range check below also rejects out-of-range values without needing errno.
    const char *port_begin = param.port.c_str();
    char *port_end = nullptr;
    const long port = std::strtol(port_begin, &port_end, 10);
    if(port_end == port_begin || *port_end != '\0' || port <= 0 || port > 65535) {
        return Status::Invalid("Invalid port");
    }

    Status status = ClientPtr()->Connect(param.ip_address, static_cast<int32_t>(port), THRIFT_PROTOCOL_BINARY);
    if(status.ok()) {
        connected_ = true;
    }

    return status;
}

// Connects using a "host:port" string. The string is split at the first ':'
// and forwarded to Connect(const ConnectParam&). Any existing connection is
// closed first, even when the uri turns out to be invalid.
//
// Returns Status::Invalid("Invalid uri") when there is no ':' or when either
// the host or the port part is empty.
Status
ClientProxy::Connect(const std::string &uri) {
    Disconnect();

    const size_t index = uri.find(':');
    // Reject a missing separator, an empty host (":19530") and an empty port ("host:").
    if(index == std::string::npos || index == 0 || index + 1 == uri.size()) {
        return Status::Invalid("Invalid uri");
    }

    ConnectParam param;
    param.ip_address = uri.substr(0, index);
    param.port = uri.substr(index + 1);

    return Connect(param);
}

// Verifies the connection is alive by sending an empty Ping to the server.
// Returns NotConnected if no connection was established, NotConnected with the
// exception text if the Ping throws, and OK otherwise.
Status
ClientProxy::Connected() const {
    if(!IsConnected()) {
        return Status(StatusCode::NotConnected, "not connected to server");
    }

    std::string reply;
    try {
        ClientPtr()->interface()->Ping(reply, "");
    } catch (std::exception &err) {
        return Status(StatusCode::NotConnected, "connection lost: " + std::string(err.what()));
    }

    return Status::OK();
}

// Closes the current connection. connected_ is cleared before the underlying
// client is asked to disconnect, and that client's status is returned.
// Returns NotConnected when there is nothing to close.
Status
ClientProxy::Disconnect() {
    if(IsConnected()) {
        connected_ = false;
        return ClientPtr()->Disconnect();
    }
    return Status(StatusCode::NotConnected, "not connected to server");
}

// Client SDK version string. Currently always empty.
std::string
ClientProxy::ClientVersion() const {
    return std::string();
}

// Creates a table on the server from the given schema.
// Returns NotConnected without a connection, UnknownError carrying the
// exception text when the RPC (or the conversion) throws, and OK otherwise.
Status
ClientProxy::CreateTable(const TableSchema &param) {
    if(!IsConnected()) {
        return Status(StatusCode::NotConnected, "not connected to server");
    }

    try {
        // Copy every public schema field into its thrift counterpart.
        thrift::TableSchema thrift_schema;
        thrift_schema.__set_table_name(param.table_name);
        thrift_schema.__set_index_type(static_cast<int>(param.index_type));
        thrift_schema.__set_dimension(param.dimension);
        thrift_schema.__set_store_raw_vector(param.store_raw_vector);

        ClientPtr()->interface()->CreateTable(thrift_schema);
    } catch (std::exception &err) {
        return Status(StatusCode::UnknownError, "failed to create table: " + std::string(err.what()));
    }

    return Status::OK();
}

G
groot 已提交
105 106 107 108 109 110 111 112 113
// Asks the server whether a table named table_name exists.
// Returns false when not connected or when the RPC throws. Unlike the original,
// a transport/Thrift exception no longer escapes to the caller, which matches
// every other RPC wrapper in this file.
bool
ClientProxy::HasTable(const std::string &table_name) {
    if(!IsConnected()) {
        return false;
    }

    try {
        return ClientPtr()->interface()->HasTable(table_name);
    } catch (std::exception &) {
        // The bool-returning API has no error channel; "unknown" is reported as "absent".
        return false;
    }
}

G
groot 已提交
114
// Deletes the named table on the server via the DeleteTable RPC.
// Returns NotConnected without a connection, UnknownError with the exception
// text if the RPC throws, and OK otherwise.
Status
ClientProxy::DropTable(const std::string &table_name) {
    if(IsConnected()) {
        try {
            ClientPtr()->interface()->DeleteTable(table_name);
        } catch (std::exception &err) {
            return Status(StatusCode::UnknownError, "failed to delete table: " + std::string(err.what()));
        }
        return Status::OK();
    }
    return Status(StatusCode::NotConnected, "not connected to server");
}

K
kun yu 已提交
130
// Requests an index build for index_param.table_name. Only the table name is
// forwarded; no other IndexParam field is sent by this client.
// Returns NotConnected without a connection, UnknownError with the exception
// text if the RPC throws, and OK otherwise.
Status
ClientProxy::CreateIndex(const IndexParam &index_param) {
    if(IsConnected()) {
        try {
            ClientPtr()->interface()->BuildIndex(index_param.table_name);
        } catch (std::exception &err) {
            return Status(StatusCode::UnknownError, "failed to build index: " + std::string(err.what()));
        }
        return Status::OK();
    }
    return Status(StatusCode::NotConnected, "not connected to server");
}

G
groot 已提交
146
// Inserts record_array into table_name. Each record's components are widened to
// double and packed back to back into thrift_record.vector_data, with
// sizeof(double) bytes per component in host byte order. id_array receives the
// ids returned by the AddVector RPC.
// Returns NotConnected without a connection, UnknownError with the exception
// text if conversion or the RPC throws, and OK otherwise.
Status
ClientProxy::Insert(const std::string &table_name,
                    const std::vector<RowRecord> &record_array,
                    std::vector<int64_t> &id_array) {
    if(!IsConnected()) {
        return Status(StatusCode::NotConnected, "not connected to server");
    }

    try {
        std::vector<thrift::RowRecord> thrift_records;
        thrift_records.reserve(record_array.size());
        for(const auto &record : record_array) {
            thrift::RowRecord thrift_record;

            // Write each double with memcpy through operator[]. Modifying the buffer
            // returned by std::string::data() is not allowed before C++17, and
            // storing through a double* into char storage violates strict aliasing.
            thrift_record.vector_data.resize(record.data.size() * sizeof(double));
            for (size_t i = 0; i < record.data.size(); i++) {
                const double value = static_cast<double>(record.data[i]);
                std::memcpy(&thrift_record.vector_data[i * sizeof(double)], &value, sizeof(double));
            }

            // Move rather than copy the (potentially large) encoded vector.
            thrift_records.emplace_back(std::move(thrift_record));
        }
        ClientPtr()->interface()->AddVector(id_array, table_name, thrift_records);

    } catch (std::exception &ex) {
        return Status(StatusCode::UnknownError, "failed to add vector: " + std::string(ex.what()));
    }

    return Status::OK();
}

K
kun yu 已提交
176
// Searches table_name for the topk nearest neighbours of each query record,
// optionally restricted by query_range_array. One TopKQueryResult per returned
// thrift result is appended to topk_query_result_array; existing contents are
// kept.
//
// Wire format: query vectors are sent as packed doubles (host byte order).
// Results arrive as two byte strings: packed int64 ids and packed double
// distances. If their element counts differ, UnknownError("illegal result") is
// returned; any results converted before that point remain appended.
Status
ClientProxy::Search(const std::string &table_name,
                    const std::vector<RowRecord> &query_record_array,
                    const std::vector<Range> &query_range_array,
                    int64_t topk,
                    std::vector<TopKQueryResult> &topk_query_result_array) {
    if(!IsConnected()) {
        return Status(StatusCode::NotConnected, "not connected to server");
    }

    try {
        //step 1: convert vectors data
        // memcpy through operator[] avoids writing via std::string::data()
        // (read-only before C++17) and avoids double* punning of char storage.
        std::vector<thrift::RowRecord> thrift_records;
        thrift_records.reserve(query_record_array.size());
        for(const auto &record : query_record_array) {
            thrift::RowRecord thrift_record;

            thrift_record.vector_data.resize(record.data.size() * sizeof(double));
            for (size_t i = 0; i < record.data.size(); i++) {
                const double value = static_cast<double>(record.data[i]);
                std::memcpy(&thrift_record.vector_data[i * sizeof(double)], &value, sizeof(double));
            }

            thrift_records.emplace_back(std::move(thrift_record));
        }

        //step 2: convert range array
        std::vector<thrift::Range> thrift_ranges;
        thrift_ranges.reserve(query_range_array.size());
        for(const auto &range : query_range_array) {
            thrift::Range thrift_range;
            thrift_range.__set_start_value(range.start_value);
            thrift_range.__set_end_value(range.end_value);

            thrift_ranges.emplace_back(std::move(thrift_range));
        }

        //step 3: search vectors
        std::vector<thrift::TopKQueryBinResult> result_array;
        ClientPtr()->interface()->SearchVector2(result_array, table_name, thrift_records, thrift_ranges, topk);

        //step 4: convert result array
        for(const auto &thrift_topk_result : result_array) {
            const auto &id_bytes = thrift_topk_result.id_array;
            const auto &dist_bytes = thrift_topk_result.distance_array;

            const size_t id_count = id_bytes.size() / sizeof(int64_t);
            const size_t dist_count = dist_bytes.size() / sizeof(double);
            if(id_count != dist_count) {
                return Status(StatusCode::UnknownError, "illegal result");
            }

            TopKQueryResult result;
            for(size_t i = 0; i < id_count; i++) {
                // Decode with memcpy: the byte buffers are not guaranteed to be
                // suitably aligned, and casting them to int64_t*/double* breaks
                // strict aliasing.
                int64_t id = 0;
                double distance = 0.0;
                std::memcpy(&id, id_bytes.data() + i * sizeof(int64_t), sizeof(int64_t));
                std::memcpy(&distance, dist_bytes.data() + i * sizeof(double), sizeof(double));

                QueryResult query_result;
                query_result.id = id;
                query_result.distance = distance;
                result.query_result_arrays.emplace_back(query_result);
            }

            topk_query_result_array.emplace_back(std::move(result));
        }

    } catch (std::exception &ex) {
        return Status(StatusCode::UnknownError, "failed to search vectors: " + std::string(ex.what()));
    }

    return Status::OK();
}

// Fetches the server-side schema of table_name into table_schema.
// Fields are copied one by one from the thrift reply. Returns NotConnected
// without a connection, UnknownError with the exception text if the RPC
// throws, and OK otherwise.
Status
ClientProxy::DescribeTable(const std::string &table_name, TableSchema &table_schema) {
    if(!IsConnected()) {
        return Status(StatusCode::NotConnected, "not connected to server");
    }

    try {
        thrift::TableSchema reply;
        ClientPtr()->interface()->DescribeTable(reply, table_name);

        table_schema.table_name = reply.table_name;
        table_schema.index_type = static_cast<IndexType>(reply.index_type);
        table_schema.dimension = reply.dimension;
        table_schema.store_raw_vector = reply.store_raw_vector;
    } catch (std::exception &err) {
        return Status(StatusCode::UnknownError, "failed to describe table: " + std::string(err.what()));
    }

    return Status::OK();
}

// Stores the number of rows in table_name (from the GetTableRowCount RPC)
// into row_count. row_count is left untouched on failure.
// Returns NotConnected without a connection, UnknownError with the exception
// text if the RPC throws, and OK otherwise.
Status
ClientProxy::CountTable(const std::string &table_name, int64_t &row_count) {
    if(!IsConnected()) {
        return Status(StatusCode::NotConnected, "not connected to server");
    }

    try {
        row_count = ClientPtr()->interface()->GetTableRowCount(table_name);

    } catch (std::exception &ex) {
        // Previously reported "failed to show tables" (copy-paste from ShowTables).
        return Status(StatusCode::UnknownError, "failed to count table: " + std::string(ex.what()));
    }

    return Status::OK();
}

// Fills table_array with the table names reported by the ShowTables RPC.
// Returns NotConnected without a connection, UnknownError with the exception
// text if the RPC throws, and OK otherwise.
Status
ClientProxy::ShowTables(std::vector<std::string> &table_array) {
    if(IsConnected()) {
        try {
            ClientPtr()->interface()->ShowTables(table_array);
        } catch (std::exception &err) {
            return Status(StatusCode::UnknownError, "failed to show tables: " + std::string(err.what()));
        }
        return Status::OK();
    }
    return Status(StatusCode::NotConnected, "not connected to server");
}

std::string
ClientProxy::ServerVersion() const {
G
groot 已提交
301
    if(!IsConnected()) {
G
groot 已提交
302 303 304 305 306 307 308 309 310 311 312 313 314 315
        return "";
    }

    try {
        std::string version;
        ClientPtr()->interface()->Ping(version, "version");
        return version;
    }  catch ( std::exception& ex) {
        return "";
    }
}

std::string
ClientProxy::ServerStatus() const {
G
groot 已提交
316 317
    if(!IsConnected()) {
        return "not connected to server";
G
groot 已提交
318 319 320 321 322 323 324 325 326 327
    }

    try {
        std::string dummy;
        ClientPtr()->interface()->Ping(dummy, "");
        return "server alive";
    }  catch ( std::exception& ex) {
        return "connection lost";
    }
}
S
starlord 已提交
328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345

// Stub: sends nothing to the server and always returns OK.
// NOTE(review): callers cannot tell that no deletion happened; consider
// returning an error status once this is implemented or a suitable code exists.
Status ClientProxy::DeleteByRange(Range &range, const std::string &table_name) {
    (void)range;
    (void)table_name;
    return Status::OK();
}

// Stub: sends nothing to the server and always returns OK.
// NOTE(review): no preload actually takes place.
Status ClientProxy::PreloadTable(const std::string &table_name) const {
    (void)table_name;
    return Status::OK();
}

// Stub: does not query the server. Returns a default-constructed IndexParam
// whose only populated field is table_name.
IndexParam ClientProxy::DescribeIndex(const std::string &table_name) const {
    IndexParam described;
    described.table_name = table_name;
    return described;
}

// Stub: sends nothing to the server and always returns OK.
// NOTE(review): no index is actually dropped.
Status ClientProxy::DropIndex(const std::string &table_name) const {
    (void)table_name;
    return Status::OK();
}
G
groot 已提交
346 347
    
}