main.py 2.0 KB
Newer Older
1 2 3 4 5 6
import os
import traceback

import paddle.v2 as paddle
from flask import Flask, jsonify, request
from flask_cors import CORS
7 8
from Queue import Queue
import threading
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39

def _require_env(var_name, missing_message):
    """Return the value of environment variable *var_name*.

    Raises ValueError(missing_message) when the variable is not set, so the
    server refuses to start without its required configuration.
    """
    value = os.getenv(var_name, None)
    if value is None:
        raise ValueError(missing_message)
    return value


# Required: path to the parameter tar archive loaded by the worker.
tarfn = _require_env(
    'PARAMETER_TAR_PATH',
    "please specify parameter tar file path with environment variable PARAMETER_TAR_PATH"
)

# Required: path to the topology file passed to paddle.inference.Inference.
topology_filepath = _require_env(
    'TOPOLOGY_FILE_PATH',
    "please specify topology file path with environment variable TOPOLOGY_FILE_PATH"
)

# Optional: any value other than '0' enables GPU inference; default is '0'.
with_gpu = os.getenv('WITH_GPU', '0') != '0'

# Optional: TCP port for the HTTP server; default is 80.
port = int(os.getenv('PORT', '80'))

app = Flask(__name__)
# Enable cross-origin resource sharing for the application's routes.
CORS(app)


def errorResp(msg):
    """Build the JSON body for a failed request.

    The response carries ``code`` -1 and *msg* as ``message``.
    """
    payload = dict(code=-1, message=msg)
    return jsonify(**payload)


def successResp(data):
    """Build the JSON body for a successful request.

    The response carries ``code`` 0, ``message`` "success", and the
    inference result under ``data``.
    """
    payload = {'code': 0, 'message': "success", 'data': data}
    return jsonify(**payload)


40 41 42
# Unbounded hand-off queue from the HTTP handlers to the single inference
# worker thread; items are (request_json, reply_queue) pairs.
sendQ = Queue(maxsize=0)


43 44
@app.route('/', methods=['POST'])
def infer():
    """Handle one POSTed inference request.

    The decoded JSON body is handed to the worker thread through ``sendQ``,
    together with a reply queue private to this request. The handler then
    blocks until the worker answers with a ``(success, payload)`` pair and
    wraps the payload with ``successResp`` or ``errorResp``.
    """
    # A fresh reply queue per request, so concurrent requests never pick up
    # each other's results.
    reply_q = Queue()
    sendQ.put((request.json, reply_q))
    ok, payload = reply_q.get()
    return successResp(payload) if ok else errorResp(payload)
52 53


54 55 56
# PaddlePaddle v0.10.0 does not support inference from different
# threads, so we create a single worker thread that owns the inferer and
# serves requests from sendQ one at a time.
def worker():
    """Load the model, then answer inference requests from ``sendQ`` forever.

    Each item taken from ``sendQ`` is a ``(j, recv_queue)`` pair:

    * ``j`` is the request's decoded JSON body. It is presumably a dict
      mapping input names to input data, since it is iterated for keys and
      indexed by them.
    * ``recv_queue`` is the queue the HTTP handler is blocked on.

    Exactly one ``(success, payload)`` tuple is put on ``recv_queue`` per
    request:

    * ``(True, result_as_list)`` on success.
    * ``(False, traceback_text)`` on any failure.

    The worker must never die. If it did, every handler would block on
    ``recv_queue.get()`` forever. For that reason, a model-load failure is
    logged and then reported to every subsequent request instead of
    propagating.
    """
    try:
        paddle.init(use_gpu=with_gpu)
        # Open in binary mode: the parameters are a tar archive and the
        # topology is read as a serialized object, not text.
        with open(tarfn, 'rb') as param_f, \
                open(topology_filepath, 'rb') as topo_f:
            params = paddle.parameters.Parameters.from_tar(param_f)
            inferer = paddle.inference.Inference(parameters=params,
                                                 fileobj=topo_f)
    except Exception:
        traceback.print_exc()
        load_error = traceback.format_exc()
        # Keep draining the queue so callers get an error, not a hang.
        while True:
            _, recv_queue = sendQ.get()
            recv_queue.put((False, load_error))

    while True:
        j, recv_queue = sendQ.get()
        try:
            # One input row `d`, plus `feeding`, which maps each input name
            # to its position in that row.
            feeding = {}
            d = []
            for i, key in enumerate(j):
                d.append(j[key])
                feeding[key] = i
            # Run inference once, after every input field has been collected.
            r = inferer.infer([d], feeding=feeding)
            result = r.tolist()
        except Exception:
            recv_queue.put((False, traceback.format_exc()))
            continue
        recv_queue.put((True, result))
76 77 78 79 80 81


if __name__ == '__main__':
    t = threading.Thread(target=worker)
    t.daemon = True
    t.start()
82
    print 'serving on port', port
83
    app.run(host='0.0.0.0', port=port, threaded=True)