# coding:utf-8
# bert_client.py — benchmark client that extracts BERT features through a
# Paddle Serving endpoint. (Web-viewer/blame residue removed from this header.)
import os
import sys
import numpy as np
import paddlehub as hub
import ujson
import random
import time
from paddlehub.common.logger import logger
import socket
from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args

args = benchmark_args()

_ver = sys.version_info
is_py2 = (_ver[0] == 2)
is_py3 = (_ver[0] == 3)

# httplib was renamed to http.client in Python 3.
if is_py2:
    import httplib
if is_py3:
    import http.client as httplib

class BertService():
    """Client-side helper that converts raw text into BERT input tensors and
    queries a Paddle Serving endpoint for the model's outputs.

    Tokenization and padding happen locally via a PaddleHub ``ClassifyReader``;
    the resulting id lists are shipped to the server with
    ``paddle_serving_client.Client``.
    """

    def __init__(self,
                 max_seq_len=128,
                 model_name="bert_uncased_L-12_H-768_A-12",
                 show_ids=False,
                 do_lower_case=True,
                 process_id=0,
                 retry=3):
        """Download/load the hub module and build the text reader.

        Args:
            max_seq_len: fixed sequence length every sample is padded/cut to.
            model_name: PaddleHub BERT module name to pull the vocab from.
            show_ids: kept for interface compatibility (not used here).
            do_lower_case: lowercase text before tokenization.
            process_id: kept for interface compatibility (not used here).
            retry: kept for interface compatibility (not used here).
        """
        self.process_id = process_id
        self.reader_flag = False
        self.batch_size = 0
        self.max_seq_len = max_seq_len
        self.model_name = model_name
        self.show_ids = show_ids
        self.do_lower_case = do_lower_case
        self.retry = retry
        self.pid = os.getpid()
        # Client-side profiling is opt-in via an environment variable.
        self.profile = True if ("FLAGS_profile_client" in os.environ and
                                os.environ["FLAGS_profile_client"]) else False

        # The hub module is only needed for its vocabulary; the reader does
        # the actual text -> ids conversion.
        module = hub.Module(name=self.model_name)
        inputs, outputs, program = module.context(
            trainable=True, max_seq_len=self.max_seq_len)
        self.reader = hub.reader.ClassifyReader(
            vocab_path=module.get_vocab_path(),
            dataset=None,
            max_seq_len=self.max_seq_len,
            do_lower_case=self.do_lower_case)
        self.reader_flag = True

    def load_client(self, config_file, server_addr):
        """Create the serving client and connect to the endpoint list."""
        self.client = Client()
        self.client.load_client_config(config_file)
        self.client.connect(server_addr)

    def run_general(self, text, fetch):
        """Send one request per sample and return the last fetch map.

        Args:
            text: list of samples, each itself a list of text fields
                  (e.g. ``[["some sentence"]]``).
            fetch: list of output variable names to fetch from the server.

        Returns:
            The fetch map of the last sample, or None if ``text`` is empty.
        """
        self.batch_size = len(text)
        data_generator = self.reader.data_generator(
            batch_size=self.batch_size, phase='predict', data=text)
        fetch_map = None  # avoid NameError when text is empty
        prepro_start = time.time()
        for run_step, batch in enumerate(data_generator(), start=1):
            token_list = batch[0][0].reshape(-1).tolist()
            pos_list = batch[0][1].reshape(-1).tolist()
            sent_list = batch[0][2].reshape(-1).tolist()
            mask_list = batch[0][3].reshape(-1).tolist()
            for si in range(self.batch_size):
                # BUG FIX: slice out sample si's ids. The original sent the
                # whole flattened batch for every sample, which is only
                # correct when batch_size == 1 (cf. run_batch_general).
                begin = si * self.max_seq_len
                end = (si + 1) * self.max_seq_len
                feed = {
                    "input_ids": token_list[begin:end],
                    "position_ids": pos_list[begin:end],
                    "segment_ids": sent_list[begin:end],
                    "input_mask": mask_list[begin:end]
                }
                prepro_end = time.time()
                if self.profile:
                    print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
                        self.pid,
                        int(round(prepro_start * 1000000)),
                        int(round(prepro_end * 1000000))))
                fetch_map = self.client.predict(feed=feed, fetch=fetch)

        return fetch_map

    def run_batch_general(self, text, fetch):
        """Send all samples in one batched request and return its fetch map.

        Args:
            text: list of samples, each itself a list of text fields.
            fetch: list of output variable names to fetch from the server.

        Returns:
            The batched fetch map, or None if ``text`` is empty.
        """
        self.batch_size = len(text)
        data_generator = self.reader.data_generator(
            batch_size=self.batch_size, phase='predict', data=text)
        fetch_map_batch = None  # avoid NameError when text is empty
        prepro_start = time.time()
        for run_step, batch in enumerate(data_generator(), start=1):
            token_list = batch[0][0].reshape(-1).tolist()
            pos_list = batch[0][1].reshape(-1).tolist()
            sent_list = batch[0][2].reshape(-1).tolist()
            mask_list = batch[0][3].reshape(-1).tolist()
            feed_batch = []
            # Each sample occupies a max_seq_len-sized slice of the flattened
            # batch arrays.
            for si in range(self.batch_size):
                feed = {
                    "input_ids": token_list[si * self.max_seq_len:(si + 1) *
                                            self.max_seq_len],
                    "position_ids":
                    pos_list[si * self.max_seq_len:(si + 1) * self.max_seq_len],
                    "segment_ids": sent_list[si * self.max_seq_len:(si + 1) *
                                             self.max_seq_len],
                    "input_mask":
                    mask_list[si * self.max_seq_len:(si + 1) * self.max_seq_len]
                }
                feed_batch.append(feed)
            prepro_end = time.time()
            if self.profile:
                print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
                    self.pid,
                    int(round(prepro_start * 1000000)),
                    int(round(prepro_end * 1000000))))
            fetch_map_batch = self.client.batch_predict(
                feed_batch=feed_batch, fetch=fetch)
        return fetch_map_batch


def single_func(idx, resource):
    """Benchmark worker for one thread.

    Builds a BertService client, connects it to the endpoint assigned to this
    thread, runs one prediction per line of ``data-c.txt``, and reports the
    total wall time.

    Args:
        idx: worker index, used to pick an endpoint from resource["endpoint"].
        resource: dict with an "endpoint" key listing "host:port" strings.

    Returns:
        ``[[elapsed_seconds]]`` in the shape MultiThreadRunner expects.
    """
    bc = BertService(
        model_name='bert_chinese_L-12_H-768_A-12',
        max_seq_len=20,
        show_ids=False,
        do_lower_case=True)
    config_file = './serving_client_conf/serving_client_conf.prototxt'
    fetch = ["pooled_output"]
    # ROBUSTNESS: wrap with modulo so running more threads than endpoints
    # round-robins instead of raising IndexError.
    endpoints = resource["endpoint"]
    server_addr = [endpoints[idx % len(endpoints)]]
    bc.load_client(config_file, server_addr)
    start = time.time()
    # BUG FIX: context manager guarantees the data file is closed.
    with open("data-c.txt") as fin:
        for line in fin:
            bc.run_general([[line.strip()]], fetch)
    end = time.time()
    return [[end - start]]
if __name__ == '__main__':
    # Fan the benchmark out over args.thread workers; each worker picks its
    # endpoint from this list by index.
    endpoints = [
        "127.0.0.1:9494", "127.0.0.1:9495", "127.0.0.1:9496", "127.0.0.1:9497"
    ]
    multi_thread_runner = MultiThreadRunner()
    result = multi_thread_runner.run(single_func, args.thread,
                                     {"endpoint": endpoints})