提交 8034ec02 编写于 作者: Y Yu Yang

Support output field for inference server

上级 b3754c77
...@@ -22,7 +22,7 @@ if topology_filepath is None: ...@@ -22,7 +22,7 @@ if topology_filepath is None:
) )
with_gpu = os.getenv('WITH_GPU', '0') != '0' with_gpu = os.getenv('WITH_GPU', '0') != '0'
output_field = os.getenv('OUTPUT_FIELD', 'value')
port = int(os.getenv('PORT', '80')) port = int(os.getenv('PORT', '80'))
app = Flask(__name__) app = Flask(__name__)
...@@ -55,6 +55,9 @@ def infer(): ...@@ -55,6 +55,9 @@ def infer():
# threads, so we create a single worker thread. # threads, so we create a single worker thread.
def worker(): def worker():
paddle.init(use_gpu=with_gpu) paddle.init(use_gpu=with_gpu)
fields = filter(lambda x: len(x) != 0, output_field.split(","))
with open(tarfn) as param_f, open(topology_filepath) as topo_f: with open(tarfn) as param_f, open(topology_filepath) as topo_f:
params = paddle.parameters.Parameters.from_tar(param_f) params = paddle.parameters.Parameters.from_tar(param_f)
inferer = paddle.inference.Inference(parameters=params, fileobj=topo_f) inferer = paddle.inference.Inference(parameters=params, fileobj=topo_f)
...@@ -67,7 +70,7 @@ def worker(): ...@@ -67,7 +70,7 @@ def worker():
for i, key in enumerate(j): for i, key in enumerate(j):
d.append(j[key]) d.append(j[key])
feeding[key] = i feeding[key] = i
r = inferer.infer([d], feeding=feeding) r = inferer.infer([d], feeding=feeding, field=fields)
except: except:
trace = traceback.format_exc() trace = traceback.format_exc()
recvQ.put((False, trace)) recvQ.put((False, trace))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册