提交 29d7de56 编写于 作者: J jiangjiajun

use program cache in predict

上级 d0bbe3ed
...@@ -264,7 +264,8 @@ class BaseClassifier(BaseAPI): ...@@ -264,7 +264,8 @@ class BaseClassifier(BaseAPI):
im = self.test_transforms(img_file) im = self.test_transforms(img_file)
result = self.exe.run(self.test_prog, result = self.exe.run(self.test_prog,
feed={'image': im}, feed={'image': im},
fetch_list=list(self.test_outputs.values())) fetch_list=list(self.test_outputs.values()),
use_program_cache=True)
pred_label = np.argsort(result[0][0])[::-1][:true_topk] pred_label = np.argsort(result[0][0])[::-1][:true_topk]
res = [{ res = [{
'category_id': l, 'category_id': l,
......
...@@ -398,7 +398,8 @@ class DeepLabv3p(BaseAPI): ...@@ -398,7 +398,8 @@ class DeepLabv3p(BaseAPI):
im = np.expand_dims(im, axis=0) im = np.expand_dims(im, axis=0)
result = self.exe.run(self.test_prog, result = self.exe.run(self.test_prog,
feed={'image': im}, feed={'image': im},
fetch_list=list(self.test_outputs.values())) fetch_list=list(self.test_outputs.values()),
use_program_cache=True)
pred = result[0] pred = result[0]
pred = np.squeeze(pred).astype('uint8') pred = np.squeeze(pred).astype('uint8')
logit = result[1] logit = result[1]
......
...@@ -389,7 +389,8 @@ class FasterRCNN(BaseAPI): ...@@ -389,7 +389,8 @@ class FasterRCNN(BaseAPI):
'im_shape': im_shape 'im_shape': im_shape
}, },
fetch_list=list(self.test_outputs.values()), fetch_list=list(self.test_outputs.values()),
return_numpy=False) return_numpy=False,
use_program_cache=True)
res = { res = {
k: (np.array(v), v.recursive_sequence_lengths()) k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(self.test_outputs.keys()), outputs) for k, v in zip(list(self.test_outputs.keys()), outputs)
......
...@@ -357,7 +357,8 @@ class MaskRCNN(FasterRCNN): ...@@ -357,7 +357,8 @@ class MaskRCNN(FasterRCNN):
'im_shape': im_shape 'im_shape': im_shape
}, },
fetch_list=list(self.test_outputs.values()), fetch_list=list(self.test_outputs.values()),
return_numpy=False) return_numpy=False,
use_program_cache=True)
res = { res = {
k: (np.array(v), v.recursive_sequence_lengths()) k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(self.test_outputs.keys()), outputs) for k, v in zip(list(self.test_outputs.keys()), outputs)
......
...@@ -363,7 +363,8 @@ class YOLOv3(BaseAPI): ...@@ -363,7 +363,8 @@ class YOLOv3(BaseAPI):
feed={'image': im, feed={'image': im,
'im_size': im_size}, 'im_size': im_size},
fetch_list=list(self.test_outputs.values()), fetch_list=list(self.test_outputs.values()),
return_numpy=False) return_numpy=False,
use_program_cache=True)
res = { res = {
k: (np.array(v), v.recursive_sequence_lengths()) k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(self.test_outputs.keys()), outputs) for k, v in zip(list(self.test_outputs.keys()), outputs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册