提交 690624f5 编写于 作者: M MRXLT

change lod info to numpy

上级 f798f761
...@@ -276,18 +276,17 @@ class Client(object): ...@@ -276,18 +276,17 @@ class Client(object):
result_map[name] = np.array(result_map[name]) result_map[name] = np.array(result_map[name])
result_map[name].shape = shape result_map[name].shape = shape
if name in self.lod_tensor_set: if name in self.lod_tensor_set:
result_map["{}.lod".format( result_map["{}.lod".format(name)] = np.array(
name)] = result_batch.get_lod(mi, name) result_batch.get_lod(mi, name))
elif self.fetch_names_to_type_[name] == float_type: elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = result_batch.get_float_by_name(mi, name) result_map[name] = result_batch.get_float_by_name(mi, name)
shape = result_batch.get_shape(mi, name) shape = result_batch.get_shape(mi, name)
result_map[name] = np.array(result_map[name]) result_map[name] = np.array(result_map[name])
result_map[name].shape = shape result_map[name].shape = shape
if name in self.lod_tensor_set: if name in self.lod_tensor_set:
result_map["{}.lod".format( result_map["{}.lod".format(name)] = np.array(
name)] = result_batch.get_lod(mi, name) result_batch.get_lod(mi, name))
multi_result_map.append(result_map) multi_result_map.append(result_map)
ret = None ret = None
if len(model_engine_names) == 1: if len(model_engine_names) == 1:
# If only one model result is returned, the format of ret is result_map # If only one model result is returned, the format of ret is result_map
...@@ -298,7 +297,6 @@ class Client(object): ...@@ -298,7 +297,6 @@ class Client(object):
engine_name: multi_result_map[mi] engine_name: multi_result_map[mi]
for mi, engine_name in enumerate(model_engine_names) for mi, engine_name in enumerate(model_engine_names)
} }
# When using the A/B test, the tag of variant needs to be returned # When using the A/B test, the tag of variant needs to be returned
return ret if not need_variant_tag else [ return ret if not need_variant_tag else [
ret, self.result_handle_.variant_tag() ret, self.result_handle_.variant_tag()
......
...@@ -66,7 +66,7 @@ class WebService(object): ...@@ -66,7 +66,7 @@ class WebService(object):
del feed["fetch"] del feed["fetch"]
fetch_map = self.client_service.predict(feed=feed, fetch=fetch) fetch_map = self.client_service.predict(feed=feed, fetch=fetch)
for key in fetch_map: for key in fetch_map:
fetch_map[key] = fetch_map[key][0].tolist() fetch_map[key] = fetch_map[key].tolist()
result = self.postprocess( result = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map) feed=request.json, fetch=fetch, fetch_map=fetch_map)
result = {"result": result} result = {"result": result}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册