Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
mania_yan
Triton Bert
比较版本
09545f01b08246d2d8079984a2a45aa7dee378cb...98ed6810962a61b5e3c614fc053a5a9d59441e61
T
Triton Bert
项目概览
mania_yan
/
Triton Bert
通知
4
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Triton Bert
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
源分支
98ed6810962a61b5e3c614fc053a5a9d59441e61
选择Git版本
...
目标分支
09545f01b08246d2d8079984a2a45aa7dee378cb
选择Git版本
比较
Commits (3)
https://gitcode.net/yyw794/triton_bert/-/commit/dce0756d220644e540b0b119fe743e0e5bb15e8b
version 0.0.2. __call__ support str type input
2021-12-07T16:42:46+08:00
yanyongwen712
yanyongwen712@pingan.com.cn
https://gitcode.net/yyw794/triton_bert/-/commit/c1b461a7cd7a30cfa871a0be98a5ffac12457f38
update examples; add code link
2021-12-07T16:53:34+08:00
yanyongwen712
yanyongwen712@pingan.com.cn
https://gitcode.net/yyw794/triton_bert/-/commit/98ed6810962a61b5e3c614fc053a5a9d59441e61
add necessary files
2021-12-07T16:56:08+08:00
yanyongwen712
yanyongwen712@pingan.com.cn
隐藏空白更改
内联
并排
Showing
7 changed file
with
40 addition
and
11 deletion
+40
-11
README.md
README.md
+6
-0
__init__.py
__init__.py
+0
-0
examples/biencoder.py
examples/biencoder.py
+3
-3
examples/chitchat.py
examples/chitchat.py
+21
-0
examples/crossencoder.py
examples/crossencoder.py
+3
-3
setup.py
setup.py
+4
-4
triton_bert.py
triton_bert.py
+3
-1
未找到文件。
README.md
浏览文件 @
98ed6810
It is easy to use bert in triton now.
It is easy to use bert in triton now.
Algorithm Engineer only need to focus to write proprocess function to make his model work.
Algorithm Engineer only need to focus to write proprocess function to make his model work.
pls see examples
[
code
](
https://codechina.csdn.net/yyw794/triton_bert
)
triton_bert/
__init__.py
→
__init__.py
浏览文件 @
98ed6810
文件已移动
examples/biencoder.py
浏览文件 @
98ed6810
from
triton_bert
.triton_bert
import
TritonBert
from
triton_bert
import
TritonBert
import
numpy
as
np
import
numpy
as
np
class
Biencoder
(
TritonBert
):
class
Biencoder
(
TritonBert
):
'''
'''
this is sentence sbert whose vector will be stored in milvus
this is sentence sbert whose vector will be stored in milvus
'''
'''
def
__init__
(
self
,
model
=
"sbert"
,
vocab
=
"./examples/config/ernie"
):
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
(
model
,
vocab
)
super
().
__init__
(
**
kwargs
)
self
.
normalize_vector
=
False
self
.
normalize_vector
=
False
def
proprocess
(
self
,
triton_output
):
def
proprocess
(
self
,
triton_output
):
...
...
examples/chitchat.py
0 → 100644
浏览文件 @
98ed6810
from
triton_bert
import
TritonBert
import
torch.nn.functional
as
F
import
torch
class
ChitchatIntentDetection
(
TritonBert
):
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
label_list
=
[
"闲聊"
,
"问答"
,
"扯淡"
]
def
proprocess
(
self
,
triton_output
):
logits
=
triton_output
[
0
]
label_ids
=
logits
.
argmax
(
axis
=-
1
)
logits
=
torch
.
tensor
(
logits
)
probs
=
F
.
softmax
(
logits
,
dim
=
1
).
numpy
()
ret
=
[]
for
i
,
label_id
in
enumerate
(
label_ids
):
prob
=
probs
[
i
][
label_id
]
if
label_id
==
2
and
prob
<
0.8
:
label_id
=
0
ret
.
append
({
"category"
:
self
.
label_list
[
label_id
],
"confidence"
:
float
(
prob
)})
return
ret
\ No newline at end of file
examples/crossencoder.py
浏览文件 @
98ed6810
from
triton_bert
.triton_bert
import
TritonBert
from
triton_bert
import
TritonBert
import
numpy
as
np
import
numpy
as
np
class
CrossEncoder
(
TritonBert
):
class
CrossEncoder
(
TritonBert
):
'''
'''
rank with text similarity
rank with text similarity
'''
'''
def
__init__
(
self
,
model
=
"rank"
,
vocab
=
"./examples/config/ernie"
):
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
(
model
,
vocab
)
super
().
__init__
(
**
kwargs
)
def
proprocess
(
self
,
triton_output
):
def
proprocess
(
self
,
triton_output
):
return
np
.
squeeze
(
triton_output
[
0
],
axis
=
1
).
tolist
()
return
np
.
squeeze
(
triton_output
[
0
],
axis
=
1
).
tolist
()
...
...
setup.py
浏览文件 @
98ed6810
...
@@ -4,8 +4,8 @@ with open("README.md", "r") as fh:
...
@@ -4,8 +4,8 @@ with open("README.md", "r") as fh:
long_description
=
fh
.
read
()
long_description
=
fh
.
read
()
setuptools
.
setup
(
setuptools
.
setup
(
name
=
"triton
_
bert"
,
name
=
"triton
-
bert"
,
version
=
"0.0.
1
"
,
version
=
"0.0.
2
"
,
author
=
"Yongwen Yan"
,
author
=
"Yongwen Yan"
,
author_email
=
"yyw794@126.com"
,
author_email
=
"yyw794@126.com"
,
description
=
"easy to use bert with nvidia triton server"
,
description
=
"easy to use bert with nvidia triton server"
,
...
@@ -13,10 +13,10 @@ setuptools.setup(
...
@@ -13,10 +13,10 @@ setuptools.setup(
long_description_content_type
=
"text/markdown"
,
long_description_content_type
=
"text/markdown"
,
url
=
"https://codechina.csdn.net/yyw794/triton_bert"
,
url
=
"https://codechina.csdn.net/yyw794/triton_bert"
,
packages
=
setuptools
.
find_packages
(),
packages
=
setuptools
.
find_packages
(),
install_requires
=
[
'tritonclient'
,
'transformers'
,
'more-itertools'
],
install_requires
=
[
'tritonclient
[all]
'
,
'transformers'
,
'more-itertools'
],
classifiers
=
(
classifiers
=
(
"Programming Language :: Python :: 3"
,
"Programming Language :: Python :: 3"
,
"License :: OSI Approved :: MIT License"
,
"License :: OSI Approved :: MIT License"
,
"Operating System :: OS Independent"
,
"Operating System :: OS Independent"
,
),
),
)
)
\ No newline at end of file
triton_bert
/triton_bert
.py
→
triton_bert.py
浏览文件 @
98ed6810
...
@@ -104,6 +104,8 @@ class TritonBert:
...
@@ -104,6 +104,8 @@ class TritonBert:
outputs
.
extend
(
self
.
_predict
(
_texts
))
outputs
.
extend
(
self
.
_predict
(
_texts
))
return
outputs
return
outputs
def
__call__
(
self
,
texts
,
text_pairs
=
[]):
def
__call__
(
self
,
texts
:
list
,
text_pairs
:
list
=
[]):
if
isinstance
(
texts
,
str
):
texts
=
[
texts
]
return
self
.
predict
(
texts
,
text_pairs
)
return
self
.
predict
(
texts
,
text_pairs
)