# text_classification.py
# -*- coding: utf-8 -*-
import json
import httplib
import sys
import os

BATCH_SIZE = 10


def data_reader(data_file, samples, labels):
    """Load integer samples and their labels from a text file.

    Each non-blank line is parsed as a comma-separated list of integers.
    Parentheses and square brackets are treated as whitespace, so lines such
    as ``([1, 2, 3], 0)`` are accepted. The last integer on a line is taken
    as the label; the integers before it form the sample.

    Args:
        data_file: Path of the file to read.
        samples: List extended in place with one list of ints per line.
        labels: List extended in place with one int label per line.

    Returns:
        -1 (after printing a message) if ``data_file`` does not exist;
        None otherwise.

    Raises:
        ValueError: If a non-blank line contains a field that is not an
            integer.
    """
    if not os.path.exists(data_file):
        # Single-argument parenthesised form prints the same text on
        # Python 2 and is also valid Python 3 syntax.
        print("Path %s not exist" % data_file)
        return -1

    with open(data_file, "r") as f:
        for line in f:
            for bracket in "()[]":
                line = line.replace(bracket, " ")
            # Blank lines (e.g. a trailing empty line) carry no sample and
            # would otherwise make int('') raise ValueError.
            if not line.strip():
                continue
            values = [int(x) for x in line.split(",")]
            samples.append(values[:-1])
            labels.append(values[-1])


if __name__ == "__main__":
    # Script entry point: read samples from DATA_FILE and POST them, in
    # batches of BATCH_SIZE, to the inference endpoint on 127.0.0.1:8010,
    # printing each raw response body.
    if len(sys.argv) != 2:
        print("Usage: python text_classification.py DATA_FILE")
        sys.exit(-1)

    samples = []
    labels = []
    if data_reader(sys.argv[1], samples, labels) == -1:
        # data_reader has already reported the missing file; there is
        # nothing to send.
        sys.exit(-1)

    conn = httplib.HTTPConnection("127.0.0.1", 8010)
    # Original author's note: for server versions after r31987 a separate
    # conn.putheader('Content-Type', 'application/json') call is not needed.
    headers = {"Content-Type": "application/json"}

    try:
        # Step over every sample; the final slice may be shorter than
        # BATCH_SIZE. (Stopping at len(samples) - BATCH_SIZE always dropped
        # the last batch.)
        # NOTE(review): assumes the server accepts a short final batch.
        for i in range(0, len(samples), BATCH_SIZE):
            batch = samples[i:i + BATCH_SIZE]
            request_json = json.dumps(
                {"instances": [{"ids": x} for x in batch]})

            try:
                conn.request('POST', "/TextClassificationService/inference",
                             request_json, headers)
                response = conn.getresponse()
                print(response.read())
            except httplib.HTTPException as e:
                # HTTPException has no `reason` attribute; print the
                # exception itself.
                print(e)
                # Drop the possibly half-used connection so the next
                # request starts from a fresh one.
                conn.close()
    finally:
        conn.close()