提交 bf07b81c 编写于 作者: S syyxsxx

change deploy.py

上级 3a238aa0
...@@ -23,14 +23,8 @@ from six import text_type as _text_type ...@@ -23,14 +23,8 @@ from six import text_type as _text_type
from openvino.inference_engine import IECore from openvino.inference_engine import IECore
class Predictor: class Predictor:
def __init__(self, def __init__(self, model_xml, model_yaml, device="CPU"):
model_xml,
model_yaml,
device="CPU"):
self.device = device self.device = device
if not osp.exists(model_xml): if not osp.exists(model_xml):
print("model xml file is not exists in {}".format(model_xml)) print("model xml file is not exists in {}".format(model_xml))
...@@ -54,29 +48,28 @@ class Predictor: ...@@ -54,29 +48,28 @@ class Predictor:
to_rgb = True to_rgb = True
else: else:
to_rgb = False to_rgb = False
self.transforms = self.build_transforms(self.info['Transforms'], to_rgb) self.transforms = self.build_transforms(self.info['Transforms'],
to_rgb)
self.predictor, self.net = self.create_predictor() self.predictor, self.net = self.create_predictor()
self.total_time = 0 self.total_time = 0
self.count_num = 0 self.count_num = 0
def create_predictor(self): def create_predictor(self):
#initialization for specified device #initialization for specified device
print("Creating Inference Engine") print("Creating Inference Engine")
ie = IECore() ie = IECore()
print("Loading network files:\n\t{}\n\t{}".format(self.model_xml, self.model_bin)) print("Loading network files:\n\t{}\n\t{}".format(self.model_xml,
self.model_bin))
net = ie.read_network(model=self.model_xml, weights=self.model_bin) net = ie.read_network(model=self.model_xml, weights=self.model_bin)
net.batch_size = 1 net.batch_size = 1
network_config = {} network_config = {}
if self.device == "MYRIAD": if self.device == "MYRIAD":
network_config = {'VPU_HW_STAGES_OPTIMIZATION':'NO'} network_config = {'VPU_HW_STAGES_OPTIMIZATION': 'NO'}
exec_net = ie.load_network(network=net, device_name=self.device, network_config) exec_net = ie.load_network(
network=net, device_name=self.device, config=network_config)
return exec_net, net return exec_net, net
def build_transforms(self, transforms_info, to_rgb=True): def build_transforms(self, transforms_info, to_rgb=True):
if self.model_type == "classifier": if self.model_type == "classifier":
import transforms.cls_transforms as transforms import transforms.cls_transforms as transforms
...@@ -118,16 +111,15 @@ class Predictor: ...@@ -118,16 +111,15 @@ class Predictor:
else: else:
eval_transforms.transforms.append(arrange_transform(mode='test')) eval_transforms.transforms.append(arrange_transform(mode='test'))
def raw_predict(self, preprocessed_input): def raw_predict(self, preprocessed_input):
self.count_num += 1 self.count_num += 1
feed_dict = {} feed_dict = {}
if self.model_name == "YOLOv3": if self.model_name == "YOLOv3":
inputs = self.net.inputs inputs = self.net.inputs
for name in inputs: for name in inputs:
if(len(inputs[name].shape) == 2): if (len(inputs[name].shape) == 2):
feed_dict[name] = preprocessed_input['im_size'] feed_dict[name] = preprocessed_input['im_size']
elif(len(inputs[name].shape) == 4): elif (len(inputs[name].shape) == 4):
feed_dict[name] = preprocessed_input['image'] feed_dict[name] = preprocessed_input['image']
else: else:
pass pass
...@@ -142,7 +134,6 @@ class Predictor: ...@@ -142,7 +134,6 @@ class Predictor:
print("Processing output blob") print("Processing output blob")
return res return res
def preprocess(self, image): def preprocess(self, image):
res = dict() res = dict()
if self.model_type == "classifier": if self.model_type == "classifier":
...@@ -171,7 +162,6 @@ class Predictor: ...@@ -171,7 +162,6 @@ class Predictor:
res['im_info'] = im_info res['im_info'] = im_info
return res return res
def classifier_postprocess(self, preds, topk=1): def classifier_postprocess(self, preds, topk=1):
""" 对分类模型的预测结果做后处理 """ 对分类模型的预测结果做后处理
""" """
...@@ -218,14 +208,13 @@ class Predictor: ...@@ -218,14 +208,13 @@ class Predictor:
outputs = preds[output_name][0] outputs = preds[output_name][0]
result = [] result = []
for out in outputs: for out in outputs:
if(out[0] > 0): if (out[0] > 0):
result.append(out.tolist()) result.append(out.tolist())
else: else:
pass pass
print(result) print(result)
return result return result
def predict(self, image, topk=1, threshold=0.5): def predict(self, image, topk=1, threshold=0.5):
preprocessed_input = self.preprocess(image) preprocessed_input = self.preprocess(image)
model_pred = self.raw_predict(preprocessed_input) model_pred = self.raw_predict(preprocessed_input)
...@@ -236,4 +225,3 @@ class Predictor: ...@@ -236,4 +225,3 @@ class Predictor:
elif self.model_type == "segmenter": elif self.model_type == "segmenter":
results = self.segmenter_postprocess(model_pred, results = self.segmenter_postprocess(model_pred,
preprocessed_input) preprocessed_input)
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
from .ops import * from .ops import *
from .imgaug_support import execute_imgaug
import random import random
import os.path as osp import os.path as osp
import numpy as np import numpy as np
...@@ -48,8 +47,6 @@ class Compose(ClsTransform): ...@@ -48,8 +47,6 @@ class Compose(ClsTransform):
'must be equal or larger than 1!') 'must be equal or larger than 1!')
self.transforms = transforms self.transforms = transforms
def __call__(self, im, label=None): def __call__(self, im, label=None):
""" """
Args: Args:
...@@ -84,7 +81,9 @@ class Compose(ClsTransform): ...@@ -84,7 +81,9 @@ class Compose(ClsTransform):
transform_names = [type(x).__name__ for x in self.transforms] transform_names = [type(x).__name__ for x in self.transforms]
for aug in augmenters: for aug in augmenters:
if type(aug).__name__ in transform_names: if type(aug).__name__ in transform_names:
print("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__)) print(
"{} is already in ComposedTransforms, need to remove it from add_augmenters().".
format(type(aug).__name__))
self.transforms = augmenters + self.transforms self.transforms = augmenters + self.transforms
......
...@@ -25,7 +25,6 @@ import cv2 ...@@ -25,7 +25,6 @@ import cv2
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
from .ops import * from .ops import *
from .box_utils import *
class DetTransform: class DetTransform:
...@@ -143,7 +142,9 @@ class Compose(DetTransform): ...@@ -143,7 +142,9 @@ class Compose(DetTransform):
transform_names = [type(x).__name__ for x in self.transforms] transform_names = [type(x).__name__ for x in self.transforms]
for aug in augmenters: for aug in augmenters:
if type(aug).__name__ in transform_names: if type(aug).__name__ in transform_names:
print("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__)) print(
"{} is already in ComposedTransforms, need to remove it from add_augmenters().".
format(type(aug).__name__))
self.transforms = augmenters + self.transforms self.transforms = augmenters + self.transforms
...@@ -394,8 +395,6 @@ class Resize(DetTransform): ...@@ -394,8 +395,6 @@ class Resize(DetTransform):
return (im, im_info, label_info) return (im, im_info, label_info)
class Normalize(DetTransform): class Normalize(DetTransform):
"""对图像进行标准化。 """对图像进行标准化。
...@@ -440,8 +439,6 @@ class Normalize(DetTransform): ...@@ -440,8 +439,6 @@ class Normalize(DetTransform):
return (im, im_info, label_info) return (im, im_info, label_info)
class ArrangeYOLOv3(DetTransform): class ArrangeYOLOv3(DetTransform):
"""获取YOLOv3模型训练/验证/预测所需信息。 """获取YOLOv3模型训练/验证/预测所需信息。
...@@ -491,8 +488,6 @@ class ArrangeYOLOv3(DetTransform): ...@@ -491,8 +488,6 @@ class ArrangeYOLOv3(DetTransform):
return outputs return outputs
class ComposedYOLOv3Transforms(Compose): class ComposedYOLOv3Transforms(Compose):
"""YOLOv3模型的图像预处理流程,具体如下, """YOLOv3模型的图像预处理流程,具体如下,
训练阶段: 训练阶段:
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
from .ops import * from .ops import *
from .imgaug_support import execute_imgaug
import random import random
import os.path as osp import os.path as osp
import numpy as np import numpy as np
...@@ -48,8 +47,6 @@ class Compose(ClsTransform): ...@@ -48,8 +47,6 @@ class Compose(ClsTransform):
'must be equal or larger than 1!') 'must be equal or larger than 1!')
self.transforms = transforms self.transforms = transforms
def __call__(self, im, label=None): def __call__(self, im, label=None):
""" """
Args: Args:
...@@ -84,7 +81,9 @@ class Compose(ClsTransform): ...@@ -84,7 +81,9 @@ class Compose(ClsTransform):
transform_names = [type(x).__name__ for x in self.transforms] transform_names = [type(x).__name__ for x in self.transforms]
for aug in augmenters: for aug in augmenters:
if type(aug).__name__ in transform_names: if type(aug).__name__ in transform_names:
print("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__)) print(
"{} is already in ComposedTransforms, need to remove it from add_augmenters().".
format(type(aug).__name__))
self.transforms = augmenters + self.transforms self.transforms = augmenters + self.transforms
......
...@@ -25,7 +25,6 @@ import cv2 ...@@ -25,7 +25,6 @@ import cv2
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
from .ops import * from .ops import *
from .box_utils import *
class DetTransform: class DetTransform:
...@@ -143,7 +142,9 @@ class Compose(DetTransform): ...@@ -143,7 +142,9 @@ class Compose(DetTransform):
transform_names = [type(x).__name__ for x in self.transforms] transform_names = [type(x).__name__ for x in self.transforms]
for aug in augmenters: for aug in augmenters:
if type(aug).__name__ in transform_names: if type(aug).__name__ in transform_names:
print("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__)) print(
"{} is already in ComposedTransforms, need to remove it from add_augmenters().".
format(type(aug).__name__))
self.transforms = augmenters + self.transforms self.transforms = augmenters + self.transforms
...@@ -394,8 +395,6 @@ class Resize(DetTransform): ...@@ -394,8 +395,6 @@ class Resize(DetTransform):
return (im, im_info, label_info) return (im, im_info, label_info)
class Normalize(DetTransform): class Normalize(DetTransform):
"""对图像进行标准化。 """对图像进行标准化。
...@@ -440,8 +439,6 @@ class Normalize(DetTransform): ...@@ -440,8 +439,6 @@ class Normalize(DetTransform):
return (im, im_info, label_info) return (im, im_info, label_info)
class ArrangeYOLOv3(DetTransform): class ArrangeYOLOv3(DetTransform):
"""获取YOLOv3模型训练/验证/预测所需信息。 """获取YOLOv3模型训练/验证/预测所需信息。
...@@ -491,8 +488,6 @@ class ArrangeYOLOv3(DetTransform): ...@@ -491,8 +488,6 @@ class ArrangeYOLOv3(DetTransform):
return outputs return outputs
class ComposedYOLOv3Transforms(Compose): class ComposedYOLOv3Transforms(Compose):
"""YOLOv3模型的图像预处理流程,具体如下, """YOLOv3模型的图像预处理流程,具体如下,
训练阶段: 训练阶段:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册