提交 3b2810d3 编写于 作者: M MRXLT

local_predict support list

上级 cc1f9ba0
...@@ -115,6 +115,13 @@ class Debugger(object): ...@@ -115,6 +115,13 @@ class Debugger(object):
inputs = [] inputs = []
for name in self.feed_names_: for name in self.feed_names_:
if isinstance(feed[name], list):
feed[name] = np.array(feed[name]).reshape(self.feed_shapes_[
name])
if self.feed_types_[name] == 0:
feed[name] = feed[name].astype("int64")
else:
feed[name] = feed[name].astype("float32")
inputs.append(PaddleTensor(feed[name][np.newaxis, :])) inputs.append(PaddleTensor(feed[name][np.newaxis, :]))
outputs = self.predictor.run(inputs) outputs = self.predictor.run(inputs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册