Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
milvus
提交
36918628
milvus
项目概览
BaiXuePrincess
/
milvus
与 Fork 源项目一致
从无法访问的项目Fork
通知
7
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
36918628
编写于
9月 11, 2019
作者:
Y
Yu Kun
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'upstream/branch-0.4.0' into branch-0.4.0
Former-commit-id: 1655e5fd449246482b0fc742a1177aa058561c58
上级
3ee1c75e
5d6bd172
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
8 addition
and
8 deletion
+8
-8
cpp/src/core/src/knowhere/index/vector_index/gpu_ivf.cpp
cpp/src/core/src/knowhere/index/vector_index/gpu_ivf.cpp
+8
-8
未找到文件。
cpp/src/core/src/knowhere/index/vector_index/gpu_ivf.cpp
浏览文件 @
36918628
...
...
@@ -26,17 +26,17 @@ namespace knowhere {
IndexModelPtr
GPUIVF
::
Train
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
{
auto
nlist
=
config
[
"nlist"
].
as
<
size_t
>
();
auto
gpu_device
=
config
.
get_with_default
(
"gpu_id"
,
gpu_id_
);
gpu_id_
=
config
.
get_with_default
(
"gpu_id"
,
gpu_id_
);
auto
metric_type
=
config
[
"metric_type"
].
as_string
()
==
"L2"
?
faiss
::
METRIC_L2
:
faiss
::
METRIC_INNER_PRODUCT
;
GETTENSOR
(
dataset
)
auto
temp_resource
=
FaissGpuResourceMgr
::
GetInstance
().
GetRes
(
gpu_
device
);
auto
temp_resource
=
FaissGpuResourceMgr
::
GetInstance
().
GetRes
(
gpu_
id_
);
if
(
temp_resource
!=
nullptr
)
{
ResScope
rs
(
gpu_
device
,
temp_resource
);
ResScope
rs
(
gpu_
id_
,
temp_resource
);
faiss
::
gpu
::
GpuIndexIVFFlatConfig
idx_config
;
idx_config
.
device
=
gpu_
device
;
idx_config
.
device
=
gpu_
id_
;
faiss
::
gpu
::
GpuIndexIVFFlat
device_index
(
temp_resource
->
faiss_res
.
get
(),
dim
,
nlist
,
metric_type
,
idx_config
);
device_index
.
train
(
rows
,
(
float
*
)
p_data
);
...
...
@@ -204,7 +204,7 @@ VectorIndexPtr GPUIVFPQ::CopyGpuToCpu(const Config &config) {
IndexModelPtr
GPUIVFSQ
::
Train
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
{
auto
nlist
=
config
[
"nlist"
].
as
<
size_t
>
();
auto
nbits
=
config
[
"nbits"
].
as
<
size_t
>
();
// TODO(linxj): gpu only support SQ4 SQ8 SQ16
auto
gpu_num
=
config
.
get_with_default
(
"gpu_id"
,
gpu_id_
);
gpu_id_
=
config
.
get_with_default
(
"gpu_id"
,
gpu_id_
);
auto
metric_type
=
config
[
"metric_type"
].
as_string
()
==
"L2"
?
faiss
::
METRIC_L2
:
faiss
::
METRIC_INNER_PRODUCT
;
...
...
@@ -214,10 +214,10 @@ IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
index_type
<<
"IVF"
<<
nlist
<<
","
<<
"SQ"
<<
nbits
;
auto
build_index
=
faiss
::
index_factory
(
dim
,
index_type
.
str
().
c_str
(),
metric_type
);
auto
temp_resource
=
FaissGpuResourceMgr
::
GetInstance
().
GetRes
(
gpu_
num
);
auto
temp_resource
=
FaissGpuResourceMgr
::
GetInstance
().
GetRes
(
gpu_
id_
);
if
(
temp_resource
!=
nullptr
)
{
ResScope
rs
(
gpu_
num
,
temp_resource
);
auto
device_index
=
faiss
::
gpu
::
index_cpu_to_gpu
(
temp_resource
->
faiss_res
.
get
(),
gpu_
num
,
build_index
);
ResScope
rs
(
gpu_
id_
,
temp_resource
);
auto
device_index
=
faiss
::
gpu
::
index_cpu_to_gpu
(
temp_resource
->
faiss_res
.
get
(),
gpu_
id_
,
build_index
);
device_index
->
train
(
rows
,
(
float
*
)
p_data
);
std
::
shared_ptr
<
faiss
::
Index
>
host_index
=
nullptr
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录