1.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
import os
import traceback

import paddle.v2 as paddle
from flask import Flask, jsonify, request
from flask_cors import CORS

# Path to the tar archive holding the trained model parameters (required).
tarfn = os.getenv('PARAMETER_TAR_PATH', None)

if tarfn is None:
    raise ValueError(
        "please specify parameter tar file path with environment variable PARAMETER_TAR_PATH"
    )

# Path to the model topology file used to build the inferer (required).
topology_filepath = os.getenv('TOPOLOGY_FILE_PATH', None)

if topology_filepath is None:
    raise ValueError(
        "please specify topology file path with environment variable TOPOLOGY_FILE_PATH"
    )

# Any WITH_GPU value other than the literal string '0' sets with_gpu to True.
with_gpu = os.getenv('WITH_GPU', '0') != '0'

# TCP port for the HTTP server; defaults to 80. A non-integer value raises
# ValueError at import time.
port = int(os.getenv('PORT', '80'))

# Flask application that serves the inference endpoint.
app = Flask(__name__)
# Allow cross-origin requests so browser clients hosted on other origins can
# POST to the endpoint.
CORS(app)

def errorResp(msg):
    """Return the JSON error envelope: ``code`` is -1 and ``message`` is *msg*."""
    body = dict(code=-1, message=msg)
    return jsonify(**body)

def successResp(data):
    """Return the JSON success envelope: ``code`` 0, ``message`` "success", plus *data*."""
    body = dict(code=0, message="success", data=data)
    return jsonify(**body)

@app.route('/', methods=['POST'])
def infer():
    """Run inference on the JSON object POSTed to ``/``.

    The request body must be a non-empty JSON object. Its values are
    collected, in iteration order, into a single sample ``d``. ``feeding``
    maps each key to that value's position in the sample, and both are
    passed to ``inferer.infer``.

    Returns:
        ``successResp(...)`` wrapping the inference result converted with
        ``tolist()``. ``errorResp(...)`` with an explanatory message if the
        body is not a non-empty JSON object. ``errorResp(...)`` with the
        formatted traceback if request parsing or inference raises.
    """
    global inferer  # assigned in the ``__main__`` block before app.run()
    try:
        payload = request.json  # read once; may raise on malformed JSON
        if not isinstance(payload, dict) or not payload:
            return errorResp("request body must be a non-empty JSON object")
        feeding = {}
        d = []
        for i, key in enumerate(payload):
            # Keep d and feeding in lockstep so feeding[key] indexes d.
            d.append(payload[key])
            feeding[key] = i
        result = inferer.infer([d], feeding=feeding).tolist()
    except Exception:
        # Top-level request boundary: report the failure to the client as a
        # JSON error instead of an opaque HTTP 500.
        # NOTE(review): this exposes server-side tracebacks to callers.
        trace = traceback.format_exc()
        return errorResp(trace)
    return successResp(result)

if __name__ == '__main__':
    # Initialise the PaddlePaddle v2 runtime before building the inferer.
    # with_gpu comes from the WITH_GPU environment variable.
    paddle.init(use_gpu=with_gpu)
    # Open both files in binary mode. The parameters are a tar archive, and
    # the topology file is presumably a serialized binary blob (confirm
    # against how it was exported).
    with open(tarfn, 'rb') as param_f, open(topology_filepath, 'rb') as topo_f:
        params = paddle.parameters.Parameters.from_tar(param_f)
        # Module-level binding: infer() reads this via ``global inferer``.
        inferer = paddle.inference.Inference(parameters=params, fileobj=topo_f)

    print('serving on port %d' % port)
    # Listen on all interfaces so the server is reachable from outside the
    # host. threaded=True lets Flask handle requests concurrently.
    app.run(host='0.0.0.0', port=port, threaded=True)