提交 946b10c1 编写于 作者: W wangjiawei04

zero copy run and faster web service

上级 f3ba77e0
......@@ -72,7 +72,7 @@ class Debugger(object):
config.enable_profile()
config.set_cpu_math_library_num_threads(cpu_num)
config.switch_ir_optim(False)
config.switch_use_feed_fetch_ops(False)
self.predictor = create_paddle_predictor(config)
def predict(self, feed=None, fetch=None):
......@@ -113,23 +113,30 @@ class Debugger(object):
"Fetch names should not be empty or out of saved fetch list.")
return {}
inputs = []
for name in self.feed_names_:
input_names = self.predictor.get_input_names()
for name in input_names:
if isinstance(feed[name], list):
feed[name] = np.array(feed[name]).reshape(self.feed_shapes_[
name])
if self.feed_types_[name] == 0:
feed[name] = feed[name].astype("int64")
else:
feed[name] = feed[name].astype("float32")
inputs.append(PaddleTensor(feed[name]))
outputs = self.predictor.run(inputs)
if self.feed_types_[name] == 0:
feed[name] = feed[name].astype("int64")
else:
feed[name] = feed[name].astype("float32")
input_tensor = self.predictor.get_input_tensor(name)
input_tensor.copy_from_cpu(feed[name])
output_tensors = []
output_names = self.predictor.get_output_names()
for output_name in output_names:
output_tensor = self.predictor.get_output_tensor(output_name)
output_tensors.append(output_tensor)
outputs = []
self.predictor.zero_copy_run()
for output_tensor in output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
fetch_map = {}
for name in fetch:
fetch_map[name] = outputs[self.fetch_names_to_idx_[
name]].as_ndarray()
if len(outputs[self.fetch_names_to_idx_[name]].lod) > 0:
fetch_map[name + ".lod"] = outputs[self.fetch_names_to_idx_[
name]].lod[0]
for i, name in enumerate(fetch):
fetch_map[name] = outputs[i]
if len(output_tensors[i].lod()) > 0:
fetch_map[name + ".lod"] = output_tensors[i].lod()[0]
return fetch_map
......@@ -127,10 +127,9 @@ class WebService(object):
request.json["fetch"])
if isinstance(feed, dict) and "fetch" in feed:
del feed["fetch"]
if len(feed) == 0:
raise ValueError("empty input")
fetch_map = self.client.predict(feed=feed, fetch=fetch)
for key in fetch_map:
if isinstance(fetch_map[key], np.ndarray):
fetch_map[key] = fetch_map[key].tolist()
result = self.postprocess(
feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map)
result = {"result": result}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册