Commit 946b10c1 authored by: W wangjiawei04

zero copy run and faster web service

Parent f3ba77e0
@@ -72,7 +72,7 @@ class Debugger(object):
         config.enable_profile()
         config.set_cpu_math_library_num_threads(cpu_num)
         config.switch_ir_optim(False)
+        config.switch_use_feed_fetch_ops(False)
         self.predictor = create_paddle_predictor(config)

     def predict(self, feed=None, fetch=None):
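The zero-copy path only works when the feed/fetch ops are switched off in the config, which is what the added line does. A minimal sketch of building such a predictor, assuming the paddle.fluid.core inference API used above and a placeholder model directory "uci_housing_model":

# Minimal sketch: create a predictor that allows zero_copy_run().
# The model directory name is a placeholder; GPU setup is omitted.
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor

config = AnalysisConfig("uci_housing_model")
config.disable_gpu()                        # CPU-only for this example
config.switch_ir_optim(False)
config.switch_use_feed_fetch_ops(False)     # required before zero_copy_run()
predictor = create_paddle_predictor(config)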
@@ -113,8 +113,8 @@ class Debugger(object):
                 "Fetch names should not be empty or out of saved fetch list.")
             return {}
-        inputs = []
-        for name in self.feed_names_:
+        input_names = self.predictor.get_input_names()
+        for name in input_names:
             if isinstance(feed[name], list):
                 feed[name] = np.array(feed[name]).reshape(self.feed_shapes_[
                     name])
@@ -122,14 +122,21 @@ class Debugger(object):
                 feed[name] = feed[name].astype("int64")
             else:
                 feed[name] = feed[name].astype("float32")
-            inputs.append(PaddleTensor(feed[name]))
-
-        outputs = self.predictor.run(inputs)
+            input_tensor = self.predictor.get_input_tensor(name)
+            input_tensor.copy_from_cpu(feed[name])
+
+        output_tensors = []
+        output_names = self.predictor.get_output_names()
+        for output_name in output_names:
+            output_tensor = self.predictor.get_output_tensor(output_name)
+            output_tensors.append(output_tensor)
+        outputs = []
+        self.predictor.zero_copy_run()
+        for output_tensor in output_tensors:
+            output = output_tensor.copy_to_cpu()
+            outputs.append(output)
         fetch_map = {}
-        for name in fetch:
-            fetch_map[name] = outputs[self.fetch_names_to_idx_[
-                name]].as_ndarray()
-            if len(outputs[self.fetch_names_to_idx_[name]].lod) > 0:
-                fetch_map[name + ".lod"] = outputs[self.fetch_names_to_idx_[
-                    name]].lod[0]
+        for i, name in enumerate(fetch):
+            fetch_map[name] = outputs[i]
+            if len(output_tensors[i].lod()) > 0:
+                fetch_map[name + ".lod"] = output_tensors[i].lod()[0]
         return fetch_map
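Taken together, the new predict() body copies each feed array directly into an input tensor, runs the graph without feed/fetch ops, and copies every output back as a numpy array. A standalone sketch of that flow, assuming a predictor built as in the sketch above and a single float32 input whose name "x" and shape are placeholders:

import numpy as np

feed = {"x": np.random.rand(1, 13).astype("float32")}  # placeholder input
for name in predictor.get_input_names():
    input_tensor = predictor.get_input_tensor(name)
    input_tensor.copy_from_cpu(feed[name])        # write straight into the tensor

predictor.zero_copy_run()                         # run without feed/fetch ops

outputs = []
for name in predictor.get_output_names():
    output_tensor = predictor.get_output_tensor(name)
    outputs.append(output_tensor.copy_to_cpu())   # one numpy array per output
    # output_tensor.lod() exposes the LoD info used for the ".lod" keys above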
@@ -127,10 +127,9 @@ class WebService(object):
                                                   request.json["fetch"])
             if isinstance(feed, dict) and "fetch" in feed:
                 del feed["fetch"]
+            if len(feed) == 0:
+                raise ValueError("empty input")
             fetch_map = self.client.predict(feed=feed, fetch=fetch)
-            for key in fetch_map:
-                if isinstance(fetch_map[key], np.ndarray):
-                    fetch_map[key] = fetch_map[key].tolist()
             result = self.postprocess(
                 feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map)
             result = {"result": result}
......
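On the web service side, the request body still carries "feed" and "fetch" keys, and an empty feed now raises ValueError before reaching the predictor. A hedged usage example, assuming a service started with name "uci" on port 9292; the endpoint path, the input name "x", and the feature values are illustrative only:

import requests

payload = {"feed": {"x": [0.0137, -0.1136, 0.2553, -0.0692]},  # placeholder features
           "fetch": ["price"]}
resp = requests.post("http://127.0.0.1:9292/uci/prediction", json=payload)
print(resp.json())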