# main.py -- HTTP inference server for a PaddlePaddle v2 model.
import os
import traceback

import paddle.v2 as paddle
from flask import Flask, jsonify, request
from flask_cors import CORS
from Queue import Queue
import threading

# --- configuration, read once from the environment at import time ---

# Required: path to the trained parameter tar archive.
tarfn = os.environ.get('PARAMETER_TAR_PATH')
if tarfn is None:
    raise ValueError(
        "please specify parameter tar file path with environment variable PARAMETER_TAR_PATH"
    )

# Required: path to the serialized network topology file.
topology_filepath = os.environ.get('TOPOLOGY_FILE_PATH')
if topology_filepath is None:
    raise ValueError(
        "please specify topology file path with environment variable TOPOLOGY_FILE_PATH"
    )

# Optional knobs with defaults: any value other than '0' enables GPU.
with_gpu = os.environ.get('WITH_GPU', '0') != '0'
output_field = os.environ.get('OUTPUT_FIELD', 'value')
port = int(os.environ.get('PORT', '80'))

app = Flask(__name__)
CORS(app)  # allow cross-origin requests from browser clients


def errorResp(msg):
    """Build the JSON error envelope: code -1 plus the failure message."""
    return jsonify(message=msg, code=-1)


def successResp(data):
    """Build the JSON success envelope: code 0 wrapping the payload."""
    return jsonify(data=data, message="success", code=0)


# Hand-off queues between Flask request handlers and the single
# inference worker thread: requests go in sendQ, replies come back
# on recvQ as (success, payload) tuples.
sendQ, recvQ = Queue(), Queue()


@app.route('/', methods=['POST'])
def infer():
    """Hand the request JSON to the worker thread and wait for its reply.

    NOTE(review): app.run(threaded=True) allows concurrent requests, yet
    all of them share the single recvQ -- a reply could be picked up by
    the wrong request. Consider a per-request reply queue; verify before
    relying on this under concurrent load.
    """
    sendQ.put(request.json)
    ok, payload = recvQ.get()
    return successResp(payload) if ok else errorResp(payload)


54 55 56
# PaddlePaddle v0.10.0 does not support inference from different
# threads, so we create a single worker thread.
def worker():
57
    paddle.init(use_gpu=with_gpu)
58 59 60

    fields = filter(lambda x: len(x) != 0, output_field.split(","))

61 62 63
    with open(tarfn) as param_f, open(topology_filepath) as topo_f:
        params = paddle.parameters.Parameters.from_tar(param_f)
        inferer = paddle.inference.Inference(parameters=params, fileobj=topo_f)
64 65 66 67 68 69 70 71 72

    while True:
        j = sendQ.get()
        try:
            feeding = {}
            d = []
            for i, key in enumerate(j):
                d.append(j[key])
                feeding[key] = i
73
                r = inferer.infer([d], feeding=feeding, field=fields)
74 75 76 77
        except:
            trace = traceback.format_exc()
            recvQ.put((False, trace))
            continue
Y
Yu Yang 已提交
78 79 80 81
        if isinstance(r, list):
            recvQ.put((True, [elem.tolist() for elem in r]))
        else:
            recvQ.put((True, r.tolist()))
82 83 84 85 86 87


if __name__ == '__main__':
    # Single daemon worker thread: PaddlePaddle v0.10.0 inference is not
    # thread-safe, so all requests funnel through this one thread.
    t = threading.Thread(target=worker)
    t.daemon = True
    t.start()

    # The original used the Python-2-only statement form
    # `print 'serving on port', port`; this call form emits the same
    # output and parses under both Python 2 and 3.
    print('serving on port %d' % port)

    # threaded=True lets Flask accept concurrent requests; inference
    # itself is still serialized through the worker queues.
    app.run(host='0.0.0.0', port=port, threaded=True)