未验证 提交 7d74dc26 编写于 作者: H haoyuying 提交者: GitHub

Add paddle.no_grad in predict()

上级 b87f08f5
...@@ -92,6 +92,7 @@ class ImageClassifierModule(RunModule, ImageServing): ...@@ -92,6 +92,7 @@ class ImageClassifierModule(RunModule, ImageServing):
results(list[dict]) : The prediction result of each input image results(list[dict]) : The prediction result of each input image
''' '''
self.eval() self.eval()
with paddle.no_grad():
res = [] res = []
total_num = len(images) total_num = len(images)
loop_num = int(np.ceil(total_num / batch_size)) loop_num = int(np.ceil(total_num / batch_size))
...@@ -223,6 +224,7 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -223,6 +224,7 @@ class ImageColorizeModule(RunModule, ImageServing):
res(list[dict]) : The prediction result of each input image res(list[dict]) : The prediction result of each input image
''' '''
self.eval() self.eval()
with paddle.no_grad():
lab2rgb = T.LAB2RGB() lab2rgb = T.LAB2RGB()
res = [] res = []
total_num = len(images) total_num = len(images)
...@@ -393,6 +395,7 @@ class Yolov3Module(RunModule, ImageServing): ...@@ -393,6 +395,7 @@ class Yolov3Module(RunModule, ImageServing):
labels(np.ndarray): Predict labels. labels(np.ndarray): Predict labels.
''' '''
self.eval() self.eval()
with paddle.no_grad():
boxes = [] boxes = []
scores = [] scores = []
self.downsample = 32 self.downsample = 32
...@@ -521,6 +524,7 @@ class StyleTransferModule(RunModule, ImageServing): ...@@ -521,6 +524,7 @@ class StyleTransferModule(RunModule, ImageServing):
output(list[np.ndarray]) : The style transformed images with bgr mode. output(list[np.ndarray]) : The style transformed images with bgr mode.
''' '''
self.eval() self.eval()
with paddle.no_grad():
style = paddle.to_tensor(self.transform(style).astype('float32')) style = paddle.to_tensor(self.transform(style).astype('float32'))
style = style.unsqueeze(0) style = style.unsqueeze(0)
...@@ -655,6 +659,7 @@ class ImageSegmentationModule(ImageServing, RunModule): ...@@ -655,6 +659,7 @@ class ImageSegmentationModule(ImageServing, RunModule):
output(list[np.ndarray]) : The segmentation mask. output(list[np.ndarray]) : The segmentation mask.
''' '''
self.eval() self.eval()
with paddle.no_grad():
result = [] result = []
total_num = len(images) total_num = len(images)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册