提交 fea6c144 编写于 作者: J JLY2015 提交者: ob-robot

[vector index] fix hybird index not utf8 error

上级 d74c069e
......@@ -504,6 +504,7 @@ int ObHybridVectorRefreshTask::prepare_for_embedding(ObPluginVectorIndexAdaptor
storage::ObTableScanParam *&table_scan_param = task_ctx->table_scan_param_;
schema::ObTableParam *&table_param = task_ctx->table_param_;
storage::ObValueRowIterator &delta_delete_iter = task_ctx->delta_delete_iter_;
ObCollationType col_type = CS_TYPE_INVALID;
int64_t dim = 0;
int64_t loop_cnt = 0;
uint64_t timeout_us = ObTimeUtility::current_time() + ObInsertLobColumnHelper::LOB_TX_TIMEOUT;
......@@ -549,6 +550,8 @@ int ObHybridVectorRefreshTask::prepare_for_embedding(ObPluginVectorIndexAdaptor
LOG_WARN("failed to get index id table column ids", K(ret), K(adaptor));
} else if (task_ctx->embedded_table_column_ids_.empty() && OB_FAIL(get_embedded_table_column_ids(adaptor))) {
LOG_WARN("failed to get embedded table column ids", K(ret), K(adaptor));
} else if (OB_FAIL(ObVectorIndexUtil::get_index_column_collation_type(tenant_id_, adaptor.get_embedded_table_id(), col_type))) {
LOG_WARN("failed to get chunc column col_type", K(ret), K(adaptor));
}
int cur_row_count = 0;
......@@ -635,6 +638,7 @@ int ObHybridVectorRefreshTask::prepare_for_embedding(ObPluginVectorIndexAdaptor
const ObAiModelEndpointInfo *endpoint = task_ctx->endpoint_; // endpoint should not be null after init.
task_ctx->embedding_task_ = new(task_buf)ObEmbeddingTask(task_ctx->allocator_);
ObPluginVectorIndexService *service = MTL(ObPluginVectorIndexService *);
if (OB_ISNULL(service)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("unexpected null ptr", K(ret), KPC(service));
......@@ -643,7 +647,7 @@ int ObHybridVectorRefreshTask::prepare_for_embedding(ObPluginVectorIndexAdaptor
} else if (OB_FAIL(ob_write_string(task_ctx->allocator_, endpoint->get_url(), url, true))) {
LOG_WARN("fail to write string", K(ret));
} else if (OB_FAIL(task_ctx->embedding_task_->init(url, endpoint->get_request_model_name(),
endpoint->get_provider(), access_key, chunk_array, dim, timeout_us))) {
endpoint->get_provider(), access_key, chunk_array, col_type, dim, timeout_us))) {
LOG_WARN("failed to init embedding task", K(ret), KPC(endpoint));
} else {
ObEmbeddingTaskHandler *embedding_handler = nullptr;
......
......@@ -191,7 +191,8 @@ ObEmbeddingTask::ObEmbeddingTask() : local_allocator_("EmbeddingTask", OB_MALLOC
successful_requests_count_(0),
task_cond_(),
callback_done_(false),
ref_cnt_(0) {}
ref_cnt_(0),
col_type_(CS_TYPE_INVALID) {}
ObEmbeddingTask::ObEmbeddingTask(ObArenaAllocator &allocator) : local_allocator_(), allocator_(allocator),
model_url_(),
model_name_(),
......@@ -238,7 +239,8 @@ ObEmbeddingTask::ObEmbeddingTask(ObArenaAllocator &allocator) : local_allocator_
successful_requests_count_(0),
task_cond_(),
callback_done_(false),
ref_cnt_(0) {}
ref_cnt_(0),
col_type_(CS_TYPE_INVALID) {}
ObEmbeddingTask::~ObEmbeddingTask() {
reset();
}
......@@ -248,6 +250,7 @@ int ObEmbeddingTask::init(const ObString &model_url,
const ObString &provider,
const ObString &user_key,
const ObIArray<ObString> &input_chunks,
const ObCollationType col_type,
int64_t dimension,
int64_t http_timeout_us,
storage::ObEmbeddingIOCallbackHandle *cb_handle)
......@@ -304,7 +307,7 @@ int ObEmbeddingTask::init(const ObString &model_url,
batch_size_adjusted_ = false;
current_batch_size_ = batch_size_;
successful_requests_count_ = 0;
col_type_ = col_type;
LOG_DEBUG("task initialized successfully", K(user_key_), K(task_id_), K(dimension_));
}
......@@ -513,11 +516,21 @@ int ObEmbeddingTask::start_async_work()
} else if (OB_FAIL(json_builder.add_array_field(root, INPUT_NAME, input_array))) {
LOG_WARN("failed to add input array field", K(ret));
} else {
ObString new_utf8_text;
for (int64_t i = start_idx; i < end_idx && OB_SUCC(ret); i++) {
const ObString &text = input_chunks_.at(i);
LOG_DEBUG("Adding text to input array", K(i), K(text));
if (OB_FAIL(json_builder.array_add_string(input_array, text))) {
LOG_WARN("failed to add text to input array", K(ret), K(i));
if (col_type_ != CS_TYPE_UTF8MB4_BIN && col_type_ != CS_TYPE_UTF8MB4_GENERAL_CI) {
if (col_type_ == CS_TYPE_INVALID) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("unexpected cs_type", K(ret), K(col_type_));
} else if (OB_FAIL(ObCharset::charset_convert(allocator_, text, col_type_, CS_TYPE_UTF8MB4_GENERAL_CI, new_utf8_text))) {
LOG_WARN("charset convertion failed", K(ret), K(text));
}
}
if (OB_FAIL(ret)) {
} else if (OB_FAIL(json_builder.array_add_string(input_array, new_utf8_text))) {
LOG_WARN("failed to add new_utf8_text to input array", K(ret), K(i));
} else {
total_text_length += text.length();
}
......
......@@ -165,6 +165,7 @@ class ObEmbeddingTask
const ObString &provider,
const ObString &user_key,
const ObIArray<ObString> &input_chunks,
const ObCollationType col_type,
int64_t dimension,
int64_t http_timeout_us,
storage::ObEmbeddingIOCallbackHandle *cb_handle = nullptr);
......@@ -185,7 +186,8 @@ class ObEmbeddingTask
K_(batch_size),
K_(processed_chunks),
K_(total_chunks),
K_(process_callback_offset));
K_(process_callback_offset),
K_(col_type));
bool is_completed();
void retain_if_managed();
void release_if_managed();
......@@ -370,6 +372,7 @@ private:
// TODO(fanfangyao.ffy): use taskhandle to manage task reference count
// ref_cnt_ is only used to track the reference count of the post create embedding task
int64_t ref_cnt_;
ObCollationType col_type_;
private:
DISALLOW_COPY_AND_ASSIGN(ObEmbeddingTask);
......
......@@ -3727,6 +3727,70 @@ int ObVectorIndexUtil::check_index_param(
return ret;
}
// index_table_id must be table which has vector column
int ObVectorIndexUtil::get_index_column_collation_type(
const int64_t tenant_id,
const uint64_t index_table_id,
ObCollationType &col_type)
{
int ret = OB_SUCCESS;
const ObTableSchema *data_table_schema = nullptr;
const ObTableSchema *table_schema = nullptr;
int64_t main_table_id = OB_INVALID_ID;
ObArray<uint64_t> tmp_column_ids;
col_type = CS_TYPE_INVALID;
ObSchemaGetterGuard schema_guard;
if (!is_valid_tenant_id(tenant_id) || OB_INVALID_ID == index_table_id) {
ret = OB_INVALID_ARGUMENT;
LOG_WARN("invalid argument", K(ret), K(index_table_id));
} else if (OB_FAIL(ObMultiVersionSchemaService::get_instance().get_tenant_schema_guard(tenant_id, schema_guard))) {
LOG_WARN("fail to get tenant schema guard", K(ret), K(MTL_ID()));
} else if (OB_FAIL(schema_guard.get_table_schema(tenant_id, index_table_id, table_schema))) {
LOG_WARN("fail to get table scheam", K(ret), K(tenant_id), K(index_table_id));
} else if (OB_ISNULL(table_schema)) {
ret = OB_TABLE_NOT_EXIST;
LOG_INFO("table not exit", K(ret), K(tenant_id), K(index_table_id));
} else if (OB_FALSE_IT(main_table_id = table_schema->get_data_table_id())) {
} else if (OB_INVALID_ID == main_table_id) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("unexpected invalid id", K(ret), K(main_table_id));
} else if (OB_FAIL(schema_guard.get_table_schema(tenant_id, main_table_id, data_table_schema))) {
LOG_WARN("fail to get table scheam", K(ret), K(tenant_id), K(index_table_id));
} else if (OB_ISNULL(data_table_schema)) {
ret = OB_TABLE_NOT_EXIST;
LOG_INFO("table not exit", K(ret), K(tenant_id), K(main_table_id));
} else if (OB_FAIL(table_schema->get_column_ids(tmp_column_ids))) {
LOG_WARN("fail to get index table all column ids", K(ret), K(data_table_schema));
} else {
for (int64_t i = 0; OB_SUCC(ret) && i < tmp_column_ids.count() && col_type == CS_TYPE_INVALID; ++i) {
const ObColumnSchemaV2 *col_schema = data_table_schema->get_column_schema(tmp_column_ids[i]);
if (OB_ISNULL(col_schema)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("unexpected null column schema ptr", K(ret));
} else if (!col_schema->is_vec_hnsw_vector_column()) {
// only need vector column
} else {
ObArray<uint64_t> cascaded_column_ids;
if (OB_FAIL(col_schema->get_cascaded_column_ids(cascaded_column_ids))) {
LOG_WARN("failed to get cascaded column ids", K(ret));
} else {
for (int64_t j = 0; OB_SUCC(ret) && j < cascaded_column_ids.count() && col_type == CS_TYPE_INVALID; ++j) {
const ObColumnSchemaV2 *cascaded_column = NULL;
if (OB_ISNULL(cascaded_column = data_table_schema->get_column_schema(cascaded_column_ids.at(j)))) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("unexpected cascaded column", K(ret));
} else {
col_type = cascaded_column->get_collation_type();
LOG_DEBUG("get vector index collation type", K(col_type));
}
}
}
}
}
}
return ret;
}
int ObVectorIndexUtil::get_vector_index_type(
sql::ObRawExpr *&raw_expr,
......
......@@ -847,6 +847,9 @@ public:
const ObTableSchema &index_table_schema,
bool &need_embedding_when_rebuild);
static bool is_sindi_index(const ObTableSchema *vec_index_schema);
static int get_index_column_collation_type(const int64_t tenant_id, const uint64_t index_table_id, ObCollationType &col_type);
private:
static void save_column_schema(
const ObColumnSchemaV2 *&old_column,
......
......@@ -90,7 +90,8 @@ ObVectorIndexTabletContext::ObVectorIndexTabletContext()
lob_inrow_threshold_(0), rowkey_cnt_(0), column_cnt_(0), snapshot_version_(0), index_type_(share::VIAT_MAX), helper_(nullptr),
allocator_("VecIndexCtx", OB_MALLOC_NORMAL_BLOCK_SIZE, MTL_ID()),
memory_context_(MTL(ObPluginVectorIndexService *)->get_memory_context()),
all_vsag_use_mem_(MTL(ObPluginVectorIndexService *)->get_all_vsag_use_mem())
all_vsag_use_mem_(MTL(ObPluginVectorIndexService *)->get_all_vsag_use_mem()),
table_id_(0)
{
}
......@@ -119,6 +120,7 @@ int ObVectorIndexTabletContext::init(
column_cnt_ = ddl_table_schema.column_items_.count();
snapshot_version_ = snapshot_version;
ddl_task_id_ = ddl_task_id;
table_id_ = ddl_table_schema.table_id_;
if (schema::is_vec_index_snapshot_data_type(index_type)) {
if (OB_FAIL(init_hnsw_index(ddl_table_schema))) {
......@@ -1774,13 +1776,18 @@ int ObHNSWEmbeddingOperator::init(const ObTabletID &tablet_id)
ret = OB_ERR_UNEXPECTED;
LOG_WARN("error unexpected, vector index ctx is null", K(ret));
} else {
const uint64_t table_id = vector_index_ctx->table_id_;
vec_dim_ = vector_index_ctx->vec_dim_;
rowkey_cnt_ = vector_index_ctx->rowkey_cnt_;
text_col_idx_ = vector_index_ctx->vector_chunk_col_idx_;
extra_column_idxs_.reset();
ObVectorIndexParam index_param;
ObSchemaGetterGuard schema_guard;
ObCollationType col_type = CS_TYPE_INVALID;
if (OB_FAIL(vector_index_ctx->build_extra_column_idxs(static_cast<int32_t>(text_col_idx_), extra_column_idxs_))) {
if (OB_FAIL(ObVectorIndexUtil::get_index_column_collation_type(MTL_ID(), table_id, col_type))) {
LOG_WARN("fail to get vector column collation type", K(ret), K(text_col_idx_), K(table_id));
} else if (OB_FAIL(vector_index_ctx->build_extra_column_idxs(static_cast<int32_t>(text_col_idx_), extra_column_idxs_))) {
LOG_WARN("build_extra_column_idxs failed", K(ret), K(text_col_idx_));
} else if (OB_FAIL(ObVectorIndexUtil::parser_params_from_string(vector_index_ctx->vec_idx_param_, ObVectorIndexType::VIT_HNSW_INDEX, index_param, false))) {
LOG_WARN("failed to parser params from string", K(ret));
......@@ -1797,7 +1804,7 @@ int ObHNSWEmbeddingOperator::init(const ObTabletID &tablet_id)
}
if (OB_SUCC(ret)) {
if (OB_FAIL(embedmgr_->init(model_id_, http_timeout_us_))) {
if (OB_FAIL(embedmgr_->init(model_id_, http_timeout_us_, col_type))) {
embedmgr_->~ObEmbeddingTaskMgr();
op_allocator_.free(embedmgr_);
embedmgr_ = nullptr;
......
......@@ -171,6 +171,7 @@ public:
common::ObArenaAllocator allocator_;
lib::MemoryContext &memory_context_;
uint64_t *all_vsag_use_mem_;
uint64_t table_id_;
};
class ObVectorIndexRowIterator
......
......@@ -556,7 +556,7 @@ ObEmbeddingTaskMgr::~ObEmbeddingTaskMgr()
}
}
int ObEmbeddingTaskMgr::init(const ObString &model_id, const int64_t http_timeout_us)
int ObEmbeddingTaskMgr::init(const ObString &model_id, const int64_t http_timeout_us, const ObCollationType col_type)
{
int ret = OB_SUCCESS;
if (OB_UNLIKELY(is_inited_)) {
......@@ -577,6 +577,7 @@ int ObEmbeddingTaskMgr::init(const ObString &model_id, const int64_t http_timeou
if (OB_SUCC(ret)) {
// TODO(fanfangyao.ffy): 待调参
http_timeout_us_ = http_timeout_us;
cs_type_ = col_type;
const int64_t reserve_slots = ring_capacity_ > 0 ? ring_capacity_ : 5;
if (OB_FAIL(slot_ring_.init(reserve_slots))) {
LOG_WARN("init slot ring failed", K(ret), K(reserve_slots));
......@@ -648,7 +649,7 @@ int ObEmbeddingTaskMgr::submit_batch_info(ObTaskBatchInfo *&batch_info)
task = new (task_mem) share::ObEmbeddingTask();
const int64_t vec_dim = results.at(0)->get_vector_dim();
if (OB_FAIL(task->init(cfg_.model_url_, cfg_.model_name_, cfg_.provider_,
cfg_.user_key_, texts, vec_dim, http_timeout_us_, cb_handle))) {
cfg_.user_key_, texts, cs_type_, vec_dim, http_timeout_us_, cb_handle))) {
LOG_WARN("failed to initialize EmbeddingTask", K(ret));
}
}
......
......@@ -239,9 +239,9 @@ class ObEmbeddingTaskMgr
public:
ObEmbeddingTaskMgr() : allocator_("EmbedTaskMgr", OB_MALLOC_NORMAL_BLOCK_SIZE, MTL_ID()),
embedding_handler_(nullptr), slot_ring_(), ring_capacity_(9),
cfg_(), is_inited_(false), is_failed_(false), http_timeout_us_(0) {}
cfg_(), is_inited_(false), is_failed_(false), http_timeout_us_(0), cs_type_(CS_TYPE_INVALID) {}
~ObEmbeddingTaskMgr();
int init(const common::ObString &model_id, const int64_t http_timeout_us);
int init(const common::ObString &model_id, const int64_t http_timeout_us, const ObCollationType cs_type);
int submit_batch_info(ObTaskBatchInfo *&batch_info);
int get_ready_batch_info(ObTaskBatchInfo *&batch_info, int &error_ret_code);
int mark_task_ready(const int64_t slot_idx, const int ret_code);
......@@ -263,6 +263,7 @@ private:
bool is_inited_;
bool is_failed_;
int64_t http_timeout_us_;
ObCollationType cs_type_;
DISALLOW_COPY_AND_ASSIGN(ObEmbeddingTaskMgr);
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册