Commit 8034ec02 authored by Yu Yang

Support output field for inference server

Parent b3754c77
@@ -22,7 +22,7 @@ if topology_filepath is None:
     )
 with_gpu = os.getenv('WITH_GPU', '0') != '0'
+output_field = os.getenv('OUTPUT_FIELD', 'value')
 port = int(os.getenv('PORT', '80'))
 app = Flask(__name__)
@@ -55,6 +55,9 @@ def infer():
 # threads, so we create a single worker thread.
 def worker():
     paddle.init(use_gpu=with_gpu)
+
+    fields = filter(lambda x: len(x) != 0, output_field.split(","))
+
     with open(tarfn) as param_f, open(topology_filepath) as topo_f:
         params = paddle.parameters.Parameters.from_tar(param_f)
         inferer = paddle.inference.Inference(parameters=params, fileobj=topo_f)
@@ -67,7 +70,7 @@ def worker():
             for i, key in enumerate(j):
                 d.append(j[key])
                 feeding[key] = i
-            r = inferer.infer([d], feeding=feeding)
+            r = inferer.infer([d], feeding=feeding, field=fields)
         except:
            trace = traceback.format_exc()
            recvQ.put((False, trace))
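
For context, here is a minimal sketch of how the new OUTPUT_FIELD environment variable flows through the server: the comma-separated value is split into a list of field names, which the worker thread then forwards to inferer.infer(..., field=fields). The 'value,id' setting below is purely an illustrative assumption; this commit does not specify which field names a given topology actually supports.

import os

# Illustrative only: ask the server for two output fields. The names used
# here are an assumption, not something defined by this commit.
os.environ['OUTPUT_FIELD'] = 'value,id'

# Same parsing the worker() thread performs: split on commas and drop
# empty entries, so a stray trailing comma is harmless.
output_field = os.getenv('OUTPUT_FIELD', 'value')
fields = [f for f in output_field.split(',') if f]  # equivalent to the filter() call in the diff

print(fields)  # ['value', 'id'] -- this list is what the worker passes as field=fields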