...
 
Commits (3)
    https://gitcode.net/yyw794/triton_bert/-/commit/dce0756d220644e540b0b119fe743e0e5bb15e8b version 0.0.2. __call__ support str type input 2021-12-07T16:42:46+08:00 yanyongwen712 yanyongwen712@pingan.com.cn https://gitcode.net/yyw794/triton_bert/-/commit/c1b461a7cd7a30cfa871a0be98a5ffac12457f38 update examples; add code link 2021-12-07T16:53:34+08:00 yanyongwen712 yanyongwen712@pingan.com.cn https://gitcode.net/yyw794/triton_bert/-/commit/98ed6810962a61b5e3c614fc053a5a9d59441e61 add necessary files 2021-12-07T16:56:08+08:00 yanyongwen712 yanyongwen712@pingan.com.cn
It is now easy to use BERT with Triton.
Algorithm engineers only need to write a `proprocess` function to make their model work.
Please see the examples.
[code](https://codechina.csdn.net/yyw794/triton_bert)
from triton_bert.triton_bert import TritonBert from triton_bert import TritonBert
import numpy as np import numpy as np
class Biencoder(TritonBert): class Biencoder(TritonBert):
''' '''
this is sentence sbert whose vector will be stored in milvus this is sentence sbert whose vector will be stored in milvus
''' '''
def __init__(self, model="sbert", vocab="./examples/config/ernie"): def __init__(self, **kwargs):
super().__init__(model, vocab) super().__init__(**kwargs)
self.normalize_vector = False self.normalize_vector = False
def proprocess(self, triton_output): def proprocess(self, triton_output):
......
from triton_bert import TritonBert
import torch.nn.functional as F
import torch
class ChitchatIntentDetection(TritonBert):
    """Intent classifier that converts model logits into labelled predictions.

    The three labels are, in class-id order: chit-chat, question answering,
    and nonsense.
    """

    def __init__(self, **kwargs):
        """Forward every keyword argument to ``TritonBert`` and set the labels."""
        super().__init__(**kwargs)
        # The position of a label in this list is the class id it stands for.
        self.label_list = ["闲聊", "问答", "扯淡"]

    def proprocess(self, triton_output):
        """Build one ``{"category", "confidence"}`` dict per row of logits.

        Args:
            triton_output: Sequence whose first element holds the logits.
                Assumed to be a 2-D numpy array of shape
                (batch, num_classes) -- TODO confirm against the caller.

        Returns:
            A list of dicts, each with the label name under ``"category"``
            and the softmax probability of the predicted class under
            ``"confidence"``.
        """
        raw_logits = triton_output[0]
        predicted_ids = raw_logits.argmax(axis=-1)
        probabilities = F.softmax(torch.tensor(raw_logits), dim=1).numpy()

        results = []
        for row, predicted in zip(probabilities, predicted_ids):
            confidence = row[predicted]
            # A class-2 prediction below 0.8 is relabelled as class 0; the
            # reported confidence stays that of the originally predicted class.
            if predicted == 2 and confidence < 0.8:
                category_id = 0
            else:
                category_id = predicted
            results.append(
                {
                    "category": self.label_list[category_id],
                    "confidence": float(confidence),
                }
            )
        return results
\ No newline at end of file
from triton_bert.triton_bert import TritonBert from triton_bert import TritonBert
import numpy as np import numpy as np
class CrossEncoder(TritonBert): class CrossEncoder(TritonBert):
''' '''
rank with text similarity rank with text similarity
''' '''
def __init__(self, model="rank", vocab="./examples/config/ernie"): def __init__(self, **kwargs):
super().__init__(model, vocab) super().__init__(**kwargs)
def proprocess(self, triton_output): def proprocess(self, triton_output):
return np.squeeze(triton_output[0], axis=1).tolist() return np.squeeze(triton_output[0], axis=1).tolist()
......
...@@ -4,8 +4,8 @@ with open("README.md", "r") as fh: ...@@ -4,8 +4,8 @@ with open("README.md", "r") as fh:
long_description = fh.read() long_description = fh.read()
setuptools.setup( setuptools.setup(
name="triton_bert", name="triton-bert",
version="0.0.1", version="0.0.2",
author="Yongwen Yan", author="Yongwen Yan",
author_email="yyw794@126.com", author_email="yyw794@126.com",
description="easy to use bert with nvidia triton server", description="easy to use bert with nvidia triton server",
...@@ -13,10 +13,10 @@ setuptools.setup( ...@@ -13,10 +13,10 @@ setuptools.setup(
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
url="https://codechina.csdn.net/yyw794/triton_bert", url="https://codechina.csdn.net/yyw794/triton_bert",
packages=setuptools.find_packages(), packages=setuptools.find_packages(),
install_requires=['tritonclient', 'transformers', 'more-itertools'], install_requires=['tritonclient[all]', 'transformers', 'more-itertools'],
classifiers=( classifiers=(
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
), ),
) )
\ No newline at end of file
...@@ -104,6 +104,8 @@ class TritonBert: ...@@ -104,6 +104,8 @@ class TritonBert:
outputs.extend(self._predict(_texts)) outputs.extend(self._predict(_texts))
return outputs return outputs
def __call__(self, texts, text_pairs=[]): def __call__(self, texts: list, text_pairs: list=[]):
if isinstance(texts, str):
texts = [texts]
return self.predict(texts, text_pairs) return self.predict(texts, text_pairs)