diff --git a/serve/main.py b/serve/main.py index 8efc4554baa5fc405cee15715e437e1352a6ae88..abdaf5d5166a4564ba4f97021823d0eea8479f73 100644 --- a/serve/main.py +++ b/serve/main.py @@ -22,7 +22,7 @@ if topology_filepath is None: ) with_gpu = os.getenv('WITH_GPU', '0') != '0' - +output_field = os.getenv('OUTPUT_FIELD', 'value') port = int(os.getenv('PORT', '80')) app = Flask(__name__) @@ -55,6 +55,9 @@ def infer(): # threads, so we create a single worker thread. def worker(): paddle.init(use_gpu=with_gpu) + + fields = filter(lambda x: len(x) != 0, output_field.split(",")) + with open(tarfn) as param_f, open(topology_filepath) as topo_f: params = paddle.parameters.Parameters.from_tar(param_f) inferer = paddle.inference.Inference(parameters=params, fileobj=topo_f) @@ -67,7 +70,7 @@ def worker(): for i, key in enumerate(j): d.append(j[key]) feeding[key] = i - r = inferer.infer([d], feeding=feeding) + r = inferer.infer([d], feeding=feeding, field=fields) except: trace = traceback.format_exc() recvQ.put((False, trace))