mania_yan / Triton Bert
Commit 98ed6810
Authored on Dec 07, 2021 by yanyongwen712

add necessary files

Parent: c1b461a7
Showing 3 changed files with 132 additions and 0 deletions:

- __init__.py (+0, -0)
- examples/chitchat.py (+21, -0)
- triton_bert.py (+111, -0)
__init__.py (new file, mode 100644, empty)
examples/chitchat.py (new file, mode 100644)
```python
from triton_bert import TritonBert
import torch.nn.functional as F
import torch


class ChitchatIntentDetection(TritonBert):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.label_list = ["闲聊", "问答", "扯淡"]

    def proprocess(self, triton_output):
        logits = triton_output[0]
        label_ids = logits.argmax(axis=-1)
        logits = torch.tensor(logits)
        probs = F.softmax(logits, dim=1).numpy()
        ret = []
        for i, label_id in enumerate(label_ids):
            prob = probs[i][label_id]
            if label_id == 2 and prob < 0.8:
                label_id = 0
            ret.append({"category": self.label_list[label_id],
                        "confidence": float(prob)})
        return ret
```
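A minimal usage sketch for the subclass above, assuming a Triton server is reachable on localhost:8001, serves a three-class classification model named "chitchat_intent", and that "vocab/" is a local directory holding a BERT vocabulary; the model name, vocab path, and import path are hypothetical, not part of this commit:

```python
# Hypothetical usage; run from the repo root so that triton_bert is importable.
# "chitchat_intent" and "vocab/" are illustrative names.
from examples.chitchat import ChitchatIntentDetection

clf = ChitchatIntentDetection(model="chitchat_intent", vocab="vocab/")
print(clf("今天天气怎么样"))
# -> e.g. [{"category": "问答", "confidence": 0.93}]  (illustrative output)
```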
triton_bert.py (new file, mode 100644)
```python
from typing import Optional

import tritonclient.grpc as grpcclient
from more_itertools.more import chunked
from transformers import BertTokenizerFast
from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import TruncationStrategy
from tritonclient.utils import InferenceServerException, triton_to_np_dtype


class TritonBert:
    def __init__(self, model: str, vocab: str,
                 triton_host: str = "localhost",
                 triton_grpc_port: int = 8001,
                 model_max_len: int = 512,
                 padding: Optional[PaddingStrategy] = None,
                 truncation: TruncationStrategy = TruncationStrategy.LONGEST_FIRST):
        self.triton_url = f"{triton_host}:{triton_grpc_port}"
        self.connect_triton()
        self.model = model
        self.model_max_len = model_max_len
        self.padding = padding
        self.truncation = truncation
        self.parse_triton_model_config()
        self.tokenizer = BertTokenizerFast.from_pretrained(vocab)

    def connect_triton(self):
        self.triton_client = grpcclient.InferenceServerClient(url=self.triton_url)

    def parse_triton_model_config(self):
        model_config = self.triton_client.get_model_config(self.model, as_json=True)
        self.model_config = model_config
        self.max_batch_size = int(model_config['config'].get('max_batch_size', 1))
        # assume the input dim is (sequence) or (batch, sequence)
        self.sequence_len = int(model_config['config']['input'][0]['dims'][-1])
        self.dynamic_sequence = True if self.sequence_len == -1 else False
        if not self.padding:
            self.padding = PaddingStrategy.LONGEST if self.dynamic_sequence \
                else PaddingStrategy.MAX_LENGTH
        # assume a dynamic sequence is typically in [0, 512]
        self.model_max_len = self.model_max_len if self.dynamic_sequence \
            else self.sequence_len
        self.model_max_sequence_len = self.model_max_len
        self.model_input_names = [_input['name']
                                  for _input in model_config['config']['input']]
        self.model_input_data_types = [_input['data_type'].replace("TYPE_", "")
                                       for _input in model_config['config']['input']]
        self.model_output_names = [_output['name']
                                   for _output in model_config['config']['output']]

    def triton_infer(self, encoded_input):
        if not encoded_input:
            return None
        batch = len(encoded_input['input_ids'])
        if self.padding == PaddingStrategy.MAX_LENGTH:
            max_sequence_len = self.model_max_len
        else:
            max_sequence_len = len(max(encoded_input['input_ids'], key=lambda x: len(x)))
        # assume the BERT model input dim is (batch, sequence)
        inputs = [grpcclient.InferInput(input_name, [batch, max_sequence_len], data_type)
                  for input_name, data_type in zip(self.model_input_names,
                                                   self.model_input_data_types)]
        outputs = [grpcclient.InferRequestedOutput(output_name)
                   for output_name in self.model_output_names]
        # bert: ['input_ids', 'attention_mask', 'token_type_ids']
        # TODO: use encoded_input.keys(). key order ??
        for i, k in enumerate(['input_ids', 'attention_mask', 'token_type_ids']):
            # logger.debug(f"{encoded_input[k]}")
            inputs[i].set_data_from_numpy(
                encoded_input[k].astype(triton_to_np_dtype(self.model_input_data_types[i])))
        try:
            triton_ret = self.triton_client.infer(model_name=self.model,
                                                  inputs=inputs, outputs=outputs)
        except InferenceServerException as error:
            # if Triton restarted we lost the connection, so reconnect and retry once
            self.connect_triton()
            triton_ret = self.triton_client.infer(model_name=self.model,
                                                  inputs=inputs, outputs=outputs)
        return [triton_ret.as_numpy(output_name)
                for output_name in self.model_output_names]

    def preprocess(self, texts, text_pairs=[]):
        if not texts:
            return
        if text_pairs:
            encoded_input = self.tokenizer(text=texts, text_pair=text_pairs,
                                           padding=self.padding,
                                           truncation=self.truncation,
                                           max_length=self.model_max_len,
                                           return_tensors='np')
        else:
            encoded_input = self.tokenizer(text=texts, padding=self.padding,
                                           truncation=self.truncation,
                                           max_length=self.model_max_len,
                                           return_tensors='np')
        return encoded_input

    def proprocess(self, triton_output):
        raise NotImplementedError

    def _predict(self, texts, text_pairs=[]):
        if not texts:
            return []
        encoded_input = self.preprocess(texts, text_pairs)
        if not encoded_input:
            return []
        triton_output = self.triton_infer(encoded_input)
        if not triton_output:
            return []
        return self.proprocess(triton_output)

    def predict(self, texts, text_pairs=[]):
        outputs = []
        if text_pairs:
            for _texts, _text_pairs in zip(chunked(texts, self.max_batch_size),
                                           chunked(text_pairs, self.max_batch_size)):
                outputs.extend(self._predict(_texts, _text_pairs))
        else:
            for _texts in chunked(texts, self.max_batch_size):
                outputs.extend(self._predict(_texts))
        return outputs

    def __call__(self, texts: list, text_pairs: list = []):
        if isinstance(texts, str):
            texts = [texts]
        return self.predict(texts, text_pairs)
```
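For reference, `parse_triton_model_config` expects the JSON returned by `get_model_config(..., as_json=True)` to look roughly like the sketch below. The field values are illustrative, not from this repo; the input names and order match the list hard-coded in `triton_infer`:

```python
# Illustrative shape of the config dict the parser reads. In Triton's
# protobuf-JSON serialization, int64 fields such as dims typically arrive
# as strings, which is why the parser wraps values in int().
example_model_config = {
    "config": {
        "max_batch_size": 32,
        "input": [
            {"name": "input_ids",      "data_type": "TYPE_INT64", "dims": ["-1"]},
            {"name": "attention_mask", "data_type": "TYPE_INT64", "dims": ["-1"]},
            {"name": "token_type_ids", "data_type": "TYPE_INT64", "dims": ["-1"]},
        ],
        "output": [
            {"name": "logits", "data_type": "TYPE_FP32", "dims": ["3"]},
        ],
    }
}
```

With `dims` ending in -1 the class sets `dynamic_sequence = True` and pads each batch to its longest sequence; a fixed last dim would instead pin `model_max_len` to that value and pad every batch to MAX_LENGTH.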