提交 bf07b81c 编写于 作者: S syyxsxx

change deploy.py

上级 3a238aa0
...@@ -23,14 +23,8 @@ from six import text_type as _text_type ...@@ -23,14 +23,8 @@ from six import text_type as _text_type
from openvino.inference_engine import IECore from openvino.inference_engine import IECore
class Predictor: class Predictor:
def __init__(self, def __init__(self, model_xml, model_yaml, device="CPU"):
model_xml,
model_yaml,
device="CPU"):
self.device = device self.device = device
if not osp.exists(model_xml): if not osp.exists(model_xml):
print("model xml file is not exists in {}".format(model_xml)) print("model xml file is not exists in {}".format(model_xml))
...@@ -54,29 +48,28 @@ class Predictor: ...@@ -54,29 +48,28 @@ class Predictor:
to_rgb = True to_rgb = True
else: else:
to_rgb = False to_rgb = False
self.transforms = self.build_transforms(self.info['Transforms'], to_rgb) self.transforms = self.build_transforms(self.info['Transforms'],
to_rgb)
self.predictor, self.net = self.create_predictor() self.predictor, self.net = self.create_predictor()
self.total_time = 0 self.total_time = 0
self.count_num = 0 self.count_num = 0
def create_predictor(self): def create_predictor(self):
#initialization for specified device #initialization for specified device
print("Creating Inference Engine") print("Creating Inference Engine")
ie = IECore() ie = IECore()
print("Loading network files:\n\t{}\n\t{}".format(self.model_xml, self.model_bin)) print("Loading network files:\n\t{}\n\t{}".format(self.model_xml,
self.model_bin))
net = ie.read_network(model=self.model_xml, weights=self.model_bin) net = ie.read_network(model=self.model_xml, weights=self.model_bin)
net.batch_size = 1 net.batch_size = 1
network_config = {} network_config = {}
if self.device == "MYRIAD": if self.device == "MYRIAD":
network_config = {'VPU_HW_STAGES_OPTIMIZATION':'NO'} network_config = {'VPU_HW_STAGES_OPTIMIZATION': 'NO'}
exec_net = ie.load_network(network=net, device_name=self.device, network_config) exec_net = ie.load_network(
network=net, device_name=self.device, config=network_config)
return exec_net, net return exec_net, net
def build_transforms(self, transforms_info, to_rgb=True): def build_transforms(self, transforms_info, to_rgb=True):
if self.model_type == "classifier": if self.model_type == "classifier":
import transforms.cls_transforms as transforms import transforms.cls_transforms as transforms
...@@ -97,8 +90,8 @@ class Predictor: ...@@ -97,8 +90,8 @@ class Predictor:
if hasattr(eval_transforms, 'to_rgb'): if hasattr(eval_transforms, 'to_rgb'):
eval_transforms.to_rgb = to_rgb eval_transforms.to_rgb = to_rgb
self.arrange_transforms(eval_transforms) self.arrange_transforms(eval_transforms)
return eval_transforms return eval_transforms
def arrange_transforms(self, eval_transforms): def arrange_transforms(self, eval_transforms):
if self.model_type == 'classifier': if self.model_type == 'classifier':
import transforms.cls_transforms as transforms import transforms.cls_transforms as transforms
...@@ -118,16 +111,15 @@ class Predictor: ...@@ -118,16 +111,15 @@ class Predictor:
else: else:
eval_transforms.transforms.append(arrange_transform(mode='test')) eval_transforms.transforms.append(arrange_transform(mode='test'))
def raw_predict(self, preprocessed_input): def raw_predict(self, preprocessed_input):
self.count_num += 1 self.count_num += 1
feed_dict = {} feed_dict = {}
if self.model_name == "YOLOv3": if self.model_name == "YOLOv3":
inputs = self.net.inputs inputs = self.net.inputs
for name in inputs: for name in inputs:
if(len(inputs[name].shape) == 2): if (len(inputs[name].shape) == 2):
feed_dict[name] = preprocessed_input['im_size'] feed_dict[name] = preprocessed_input['im_size']
elif(len(inputs[name].shape) == 4): elif (len(inputs[name].shape) == 4):
feed_dict[name] = preprocessed_input['image'] feed_dict[name] = preprocessed_input['image']
else: else:
pass pass
...@@ -137,14 +129,13 @@ class Predictor: ...@@ -137,14 +129,13 @@ class Predictor:
#Start sync inference #Start sync inference
print("Starting inference in synchronous mode") print("Starting inference in synchronous mode")
res = self.predictor.infer(inputs=feed_dict) res = self.predictor.infer(inputs=feed_dict)
#Processing output blob #Processing output blob
print("Processing output blob") print("Processing output blob")
return res return res
def preprocess(self, image): def preprocess(self, image):
res = dict() res = dict()
if self.model_type == "classifier": if self.model_type == "classifier":
im, = self.transforms(image) im, = self.transforms(image)
im = np.expand_dims(im, axis=0).copy() im = np.expand_dims(im, axis=0).copy()
...@@ -170,7 +161,6 @@ class Predictor: ...@@ -170,7 +161,6 @@ class Predictor:
res['image'] = im res['image'] = im
res['im_info'] = im_info res['im_info'] = im_info
return res return res
def classifier_postprocess(self, preds, topk=1): def classifier_postprocess(self, preds, topk=1):
""" 对分类模型的预测结果做后处理 """ 对分类模型的预测结果做后处理
...@@ -184,7 +174,7 @@ class Predictor: ...@@ -184,7 +174,7 @@ class Predictor:
'score': preds[output_name][0][l], 'score': preds[output_name][0][l],
} for l in pred_label] } for l in pred_label]
print(result) print(result)
return result return result
def segmenter_postprocess(self, preds, preprocessed_inputs): def segmenter_postprocess(self, preds, preprocessed_inputs):
""" 对语义分割结果做后处理 """ 对语义分割结果做后处理
...@@ -210,7 +200,7 @@ class Predictor: ...@@ -210,7 +200,7 @@ class Predictor:
raise Exception("Unexpected info '{}' in im_info".format(info[ raise Exception("Unexpected info '{}' in im_info".format(info[
0])) 0]))
return {'label_map': label_map, 'score_map': score_map} return {'label_map': label_map, 'score_map': score_map}
def detector_postprocess(self, preds, preprocessed_inputs): def detector_postprocess(self, preds, preprocessed_inputs):
"""对图像检测结果做后处理 """对图像检测结果做后处理
""" """
...@@ -218,14 +208,13 @@ class Predictor: ...@@ -218,14 +208,13 @@ class Predictor:
outputs = preds[output_name][0] outputs = preds[output_name][0]
result = [] result = []
for out in outputs: for out in outputs:
if(out[0] > 0): if (out[0] > 0):
result.append(out.tolist()) result.append(out.tolist())
else: else:
pass pass
print(result) print(result)
return result return result
def predict(self, image, topk=1, threshold=0.5): def predict(self, image, topk=1, threshold=0.5):
preprocessed_input = self.preprocess(image) preprocessed_input = self.preprocess(image)
model_pred = self.raw_predict(preprocessed_input) model_pred = self.raw_predict(preprocessed_input)
...@@ -235,5 +224,4 @@ class Predictor: ...@@ -235,5 +224,4 @@ class Predictor:
results = self.detector_postprocess(model_pred, preprocessed_input) results = self.detector_postprocess(model_pred, preprocessed_input)
elif self.model_type == "segmenter": elif self.model_type == "segmenter":
results = self.segmenter_postprocess(model_pred, results = self.segmenter_postprocess(model_pred,
preprocessed_input) preprocessed_input)
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
from .ops import * from .ops import *
from .imgaug_support import execute_imgaug
import random import random
import os.path as osp import os.path as osp
import numpy as np import numpy as np
...@@ -48,8 +47,6 @@ class Compose(ClsTransform): ...@@ -48,8 +47,6 @@ class Compose(ClsTransform):
'must be equal or larger than 1!') 'must be equal or larger than 1!')
self.transforms = transforms self.transforms = transforms
def __call__(self, im, label=None): def __call__(self, im, label=None):
""" """
Args: Args:
...@@ -84,7 +81,9 @@ class Compose(ClsTransform): ...@@ -84,7 +81,9 @@ class Compose(ClsTransform):
transform_names = [type(x).__name__ for x in self.transforms] transform_names = [type(x).__name__ for x in self.transforms]
for aug in augmenters: for aug in augmenters:
if type(aug).__name__ in transform_names: if type(aug).__name__ in transform_names:
print("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__)) print(
"{} is already in ComposedTransforms, need to remove it from add_augmenters().".
format(type(aug).__name__))
self.transforms = augmenters + self.transforms self.transforms = augmenters + self.transforms
......
...@@ -25,7 +25,6 @@ import cv2 ...@@ -25,7 +25,6 @@ import cv2
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
from .ops import * from .ops import *
from .box_utils import *
class DetTransform: class DetTransform:
...@@ -143,7 +142,9 @@ class Compose(DetTransform): ...@@ -143,7 +142,9 @@ class Compose(DetTransform):
transform_names = [type(x).__name__ for x in self.transforms] transform_names = [type(x).__name__ for x in self.transforms]
for aug in augmenters: for aug in augmenters:
if type(aug).__name__ in transform_names: if type(aug).__name__ in transform_names:
print("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__)) print(
"{} is already in ComposedTransforms, need to remove it from add_augmenters().".
format(type(aug).__name__))
self.transforms = augmenters + self.transforms self.transforms = augmenters + self.transforms
...@@ -394,8 +395,6 @@ class Resize(DetTransform): ...@@ -394,8 +395,6 @@ class Resize(DetTransform):
return (im, im_info, label_info) return (im, im_info, label_info)
class Normalize(DetTransform): class Normalize(DetTransform):
"""对图像进行标准化。 """对图像进行标准化。
...@@ -440,8 +439,6 @@ class Normalize(DetTransform): ...@@ -440,8 +439,6 @@ class Normalize(DetTransform):
return (im, im_info, label_info) return (im, im_info, label_info)
class ArrangeYOLOv3(DetTransform): class ArrangeYOLOv3(DetTransform):
"""获取YOLOv3模型训练/验证/预测所需信息。 """获取YOLOv3模型训练/验证/预测所需信息。
...@@ -491,8 +488,6 @@ class ArrangeYOLOv3(DetTransform): ...@@ -491,8 +488,6 @@ class ArrangeYOLOv3(DetTransform):
return outputs return outputs
class ComposedYOLOv3Transforms(Compose): class ComposedYOLOv3Transforms(Compose):
"""YOLOv3模型的图像预处理流程,具体如下, """YOLOv3模型的图像预处理流程,具体如下,
训练阶段: 训练阶段:
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
from .ops import * from .ops import *
from .imgaug_support import execute_imgaug
import random import random
import os.path as osp import os.path as osp
import numpy as np import numpy as np
...@@ -48,8 +47,6 @@ class Compose(ClsTransform): ...@@ -48,8 +47,6 @@ class Compose(ClsTransform):
'must be equal or larger than 1!') 'must be equal or larger than 1!')
self.transforms = transforms self.transforms = transforms
def __call__(self, im, label=None): def __call__(self, im, label=None):
""" """
Args: Args:
...@@ -84,7 +81,9 @@ class Compose(ClsTransform): ...@@ -84,7 +81,9 @@ class Compose(ClsTransform):
transform_names = [type(x).__name__ for x in self.transforms] transform_names = [type(x).__name__ for x in self.transforms]
for aug in augmenters: for aug in augmenters:
if type(aug).__name__ in transform_names: if type(aug).__name__ in transform_names:
print("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__)) print(
"{} is already in ComposedTransforms, need to remove it from add_augmenters().".
format(type(aug).__name__))
self.transforms = augmenters + self.transforms self.transforms = augmenters + self.transforms
......
...@@ -25,7 +25,6 @@ import cv2 ...@@ -25,7 +25,6 @@ import cv2
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
from .ops import * from .ops import *
from .box_utils import *
class DetTransform: class DetTransform:
...@@ -143,7 +142,9 @@ class Compose(DetTransform): ...@@ -143,7 +142,9 @@ class Compose(DetTransform):
transform_names = [type(x).__name__ for x in self.transforms] transform_names = [type(x).__name__ for x in self.transforms]
for aug in augmenters: for aug in augmenters:
if type(aug).__name__ in transform_names: if type(aug).__name__ in transform_names:
print("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__)) print(
"{} is already in ComposedTransforms, need to remove it from add_augmenters().".
format(type(aug).__name__))
self.transforms = augmenters + self.transforms self.transforms = augmenters + self.transforms
...@@ -394,8 +395,6 @@ class Resize(DetTransform): ...@@ -394,8 +395,6 @@ class Resize(DetTransform):
return (im, im_info, label_info) return (im, im_info, label_info)
class Normalize(DetTransform): class Normalize(DetTransform):
"""对图像进行标准化。 """对图像进行标准化。
...@@ -440,8 +439,6 @@ class Normalize(DetTransform): ...@@ -440,8 +439,6 @@ class Normalize(DetTransform):
return (im, im_info, label_info) return (im, im_info, label_info)
class ArrangeYOLOv3(DetTransform): class ArrangeYOLOv3(DetTransform):
"""获取YOLOv3模型训练/验证/预测所需信息。 """获取YOLOv3模型训练/验证/预测所需信息。
...@@ -491,8 +488,6 @@ class ArrangeYOLOv3(DetTransform): ...@@ -491,8 +488,6 @@ class ArrangeYOLOv3(DetTransform):
return outputs return outputs
class ComposedYOLOv3Transforms(Compose): class ComposedYOLOv3Transforms(Compose):
"""YOLOv3模型的图像预处理流程,具体如下, """YOLOv3模型的图像预处理流程,具体如下,
训练阶段: 训练阶段:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册