提交 65eef023 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #414 from reyoung/feature/outputs_predict

Support output field for inference server
......@@ -22,7 +22,7 @@ if topology_filepath is None:
)
with_gpu = os.getenv('WITH_GPU', '0') != '0'
output_field = os.getenv('OUTPUT_FIELD', 'value')
port = int(os.getenv('PORT', '80'))
app = Flask(__name__)
......@@ -55,6 +55,9 @@ def infer():
# threads, so we create a single worker thread.
def worker():
paddle.init(use_gpu=with_gpu)
fields = filter(lambda x: len(x) != 0, output_field.split(","))
with open(tarfn) as param_f, open(topology_filepath) as topo_f:
params = paddle.parameters.Parameters.from_tar(param_f)
inferer = paddle.inference.Inference(parameters=params, fileobj=topo_f)
......@@ -67,12 +70,15 @@ def worker():
for i, key in enumerate(j):
d.append(j[key])
feeding[key] = i
r = inferer.infer([d], feeding=feeding)
r = inferer.infer([d], feeding=feeding, field=fields)
except:
trace = traceback.format_exc()
recv_queue.put((False, trace))
continue
recv_queue.put((True, r.tolist()))
if isinstance(r, list):
recv_queue.put((True, [elem.tolist() for elem in r]))
else:
recv_queue.put((True, r.tolist()))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册