提交 38c0dc4e 编写于 作者: T TeslaZhao 提交者: felixhjh

Merge pull request #1421 from felixhjh/develop

modify detection examples preprocess
上级 15a3dea8
...@@ -15,7 +15,7 @@ python3 -m paddle_serving_server.serve --model serving_server --port 9292 --gpu_ ...@@ -15,7 +15,7 @@ python3 -m paddle_serving_server.serve --model serving_server --port 9292 --gpu_
### Perform prediction ### Perform prediction
``` ```
python3 test_client.py python3 test_client.py 000000570688.jpg
``` ```
Image with bounding boxes and json result would be saved in `output` folder. Image with bounding boxes and json result would be saved in `output` folder.
...@@ -15,7 +15,7 @@ python3 -m paddle_serving_server.serve --model serving_server --port 9292 --gpu_ ...@@ -15,7 +15,7 @@ python3 -m paddle_serving_server.serve --model serving_server --port 9292 --gpu_
### 执行预测 ### 执行预测
``` ```
python3 test_client.py python3 test_client.py 000000570688.jpg
``` ```
客户端已经为图片做好了后处理,在`output`文件夹下存放各个框的json格式信息还有后处理结果图片。 客户端已经为图片做好了后处理,在`output`文件夹下存放各个框的json格式信息还有后处理结果图片。
...@@ -17,27 +17,30 @@ import numpy as np ...@@ -17,27 +17,30 @@ import numpy as np
from paddle_serving_client import Client from paddle_serving_client import Client
from paddle_serving_app.reader import * from paddle_serving_app.reader import *
import cv2 import cv2
preprocess = Sequential([
File2Image(), BGR2RGB(), Resize( preprocess = DetectionSequential([
(608, 608), interpolation=cv2.INTER_LINEAR), Div(255.0), Transpose( DetectionFile2Image(),
(2, 0, 1)) DetectionResize((800, 1333), True, interpolation=2),
DetectionNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True),
DetectionTranspose((2,0,1)),
DetectionPadStride(32)
]) ])
postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608]) postprocess = RCNNPostprocess("label_list.txt", "output")
client = Client() client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt") client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9292']) client.connect(['127.0.0.1:9292'])
im = preprocess('000000570688.jpg') im, im_info = preprocess(sys.argv[1])
fetch_map = client.predict( fetch_map = client.predict(
feed={ feed={
"image": im, "image": im,
"im_shape": np.array(list(im.shape[1:])).reshape(-1), "im_shape": np.array(list(im.shape[1:])).reshape(-1),
"scale_factor": np.array([1.0, 1.0]).reshape(-1), "scale_factor": im_info['scale_factor'],
}, },
fetch=["save_infer_model/scale_0.tmp_1"], fetch=["save_infer_model/scale_0.tmp_1"],
batch=False) batch=False)
print(fetch_map) print(fetch_map)
fetch_map["image"] = '000000570688.jpg' fetch_map["image"] = sys.argv[1]
postprocess(fetch_map) postprocess(fetch_map)
...@@ -9,7 +9,7 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/fast ...@@ -9,7 +9,7 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/fast
### Start the service ### Start the service
``` ```
tar xf faster_rcnn_hrnetv2p_w18_1x.tar tar xf faster_rcnn_hrnetv2p_w18_1x.tar.gz
python3 -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0 python3 -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0
``` ```
......
...@@ -10,7 +10,7 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/fast ...@@ -10,7 +10,7 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/fast
### 启动服务 ### 启动服务
``` ```
tar xf faster_rcnn_hrnetv2p_w18_1x.tar tar xf faster_rcnn_hrnetv2p_w18_1x.tar.gz
python3 -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0 python3 -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0
``` ```
该模型支持TensorRT,如果想要更快的预测速度,可以开启`--use_trt`选项,但此时需要额外设置子图的TRT变长最大最小最优shape. 该模型支持TensorRT,如果想要更快的预测速度,可以开启`--use_trt`选项,但此时需要额外设置子图的TRT变长最大最小最优shape.
......
...@@ -17,24 +17,27 @@ import numpy as np ...@@ -17,24 +17,27 @@ import numpy as np
from paddle_serving_client import Client from paddle_serving_client import Client
from paddle_serving_app.reader import * from paddle_serving_app.reader import *
import cv2 import cv2
preprocess = Sequential([
File2Image(), BGR2RGB(), Resize( preprocess = DetectionSequential([
(608, 608), interpolation=cv2.INTER_LINEAR), Div(255.0), Transpose( DetectionFile2Image(),
(2, 0, 1)) DetectionResize((800, 1333), True, interpolation=2),
DetectionNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True),
DetectionTranspose((2,0,1)),
DetectionPadStride(32)
]) ])
postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608]) postprocess = RCNNPostprocess("label_list.txt", "output")
client = Client() client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt") client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9494']) client.connect(['127.0.0.1:9494'])
im = preprocess(sys.argv[1]) im, im_info = preprocess(sys.argv[1])
fetch_map = client.predict( fetch_map = client.predict(
feed={ feed={
"image": im, "image": im,
"im_shape": np.array(list(im.shape[1:])).reshape(-1), "im_shape": np.array(list(im.shape[1:])).reshape(-1),
"scale_factor": np.array([1.0, 1.0]).reshape(-1), "scale_factor": im_info['scale_factor'],
}, },
fetch=["save_infer_model/scale_0.tmp_1"], fetch=["save_infer_model/scale_0.tmp_1"],
batch=False) batch=False)
......
...@@ -12,15 +12,19 @@ ...@@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import *
import sys import sys
import numpy as np import numpy as np
from paddle_serving_client import Client
from paddle_serving_app.reader import *
import cv2
preprocess = Sequential([ preprocess = DetectionSequential([
File2Image(), BGR2RGB(), Div(255.0), DetectionFile2Image(),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False), DetectionNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True),
Resize(640, 640), Transpose((2, 0, 1)) DetectionResize(
(800, 1333), True, interpolation=cv2.INTER_LINEAR),
DetectionTranspose((2,0,1)),
DetectionPadStride(128)
]) ])
postprocess = RCNNPostprocess("label_list.txt", "output") postprocess = RCNNPostprocess("label_list.txt", "output")
...@@ -29,15 +33,14 @@ client = Client() ...@@ -29,15 +33,14 @@ client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt") client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9494']) client.connect(['127.0.0.1:9494'])
im = preprocess(sys.argv[1]) im, im_info = preprocess(sys.argv[1])
fetch_map = client.predict( fetch_map = client.predict(
feed={ feed={
"image": im, "image": im,
"im_shape": np.array(list(im.shape[1:])).reshape(-1), "im_shape": np.array(list(im.shape[1:])).reshape(-1),
"scale_factor": np.array([1.0, 1.0]).reshape(-1), "scale_factor": im_info['scale_factor'],
}, },
fetch=["save_infer_model/scale_0.tmp_1"], fetch=["save_infer_model/scale_0.tmp_1"],
batch=False) batch=False)
print(fetch_map)
fetch_map["image"] = sys.argv[1] fetch_map["image"] = sys.argv[1]
postprocess(fetch_map) postprocess(fetch_map)
...@@ -17,23 +17,27 @@ import numpy as np ...@@ -17,23 +17,27 @@ import numpy as np
from paddle_serving_client import Client from paddle_serving_client import Client
from paddle_serving_app.reader import * from paddle_serving_app.reader import *
import cv2 import cv2
preprocess = Sequential([
File2Image(), BGR2RGB(), Resize( preprocess = DetectionSequential([
(608, 608), interpolation=cv2.INTER_LINEAR), Div(255.0), Transpose( DetectionFile2Image(),
(2, 0, 1)) DetectionNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True),
DetectionResize(
(800, 1333), True, interpolation=cv2.INTER_LINEAR),
DetectionTranspose((2,0,1)),
DetectionPadStride(128)
]) ])
postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608]) postprocess = RCNNPostprocess("label_list.txt", "output")
client = Client() client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt") client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9494']) client.connect(['127.0.0.1:9494'])
im = preprocess(sys.argv[1]) im, im_info = preprocess(sys.argv[1])
fetch_map = client.predict( fetch_map = client.predict(
feed={ feed={
"image": im, "image": im,
"scale_factor": np.array([1.0, 1.0]).reshape(-1), "scale_factor": im_info['scale_factor'],
}, },
fetch=["save_infer_model/scale_0.tmp_1"], fetch=["save_infer_model/scale_0.tmp_1"],
batch=False) batch=False)
......
...@@ -17,27 +17,29 @@ import numpy as np ...@@ -17,27 +17,29 @@ import numpy as np
from paddle_serving_client import Client from paddle_serving_client import Client
from paddle_serving_app.reader import * from paddle_serving_app.reader import *
import cv2 import cv2
preprocess = Sequential([
File2Image(), BGR2RGB(), Resize( preprocess = DetectionSequential([
(608, 608), interpolation=cv2.INTER_LINEAR), Div(255.0), Transpose( DetectionFile2Image(),
(2, 0, 1)) DetectionNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True),
DetectionResize(
(608, 608), False, interpolation=2),
DetectionTranspose((2,0,1))
]) ])
postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608]) postprocess = RCNNPostprocess("label_list.txt", "output")
client = Client() client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt") client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9494']) client.connect(['127.0.0.1:9494'])
im = preprocess(sys.argv[1]) im, im_info = preprocess(sys.argv[1])
fetch_map = client.predict( fetch_map = client.predict(
feed={ feed={
"image": im, "image": im,
"im_shape": np.array(list(im.shape[1:])).reshape(-1), "im_shape": np.array(list(im.shape[1:])).reshape(-1),
"scale_factor": np.array([1.0, 1.0]).reshape(-1), "scale_factor": im_info['scale_factor'],
}, },
fetch=["save_infer_model/scale_0.tmp_1"], fetch=["save_infer_model/scale_0.tmp_1"],
batch=False) batch=False)
print(fetch_map)
fetch_map["image"] = sys.argv[1] fetch_map["image"] = sys.argv[1]
postprocess(fetch_map) postprocess(fetch_map)
person aeroplane
bicycle bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird bird
boat
bottle
bus
car
cat cat
chair
cow
diningtable
dog dog
horse horse
motorbike
person
pottedplant
sheep sheep
cow sofa
elephant train
bear tvmonitor
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
...@@ -17,23 +17,27 @@ import numpy as np ...@@ -17,23 +17,27 @@ import numpy as np
from paddle_serving_client import Client from paddle_serving_client import Client
from paddle_serving_app.reader import * from paddle_serving_app.reader import *
import cv2 import cv2
preprocess = Sequential([
File2Image(), BGR2RGB(), Resize( preprocess = DetectionSequential([
(608, 608), interpolation=cv2.INTER_LINEAR), Div(255.0), Transpose( DetectionFile2Image(),
(2, 0, 1)) DetectionResize(
(300, 300), False, interpolation=cv2.INTER_LINEAR),
DetectionNormalize([104.0, 117.0, 123.0], [1.0, 1.0, 1.0], False),
DetectionTranspose((2,0,1)),
]) ])
postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608]) postprocess = RCNNPostprocess("label_list.txt", "output")
client = Client() client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt") client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9494']) client.connect(['127.0.0.1:9494'])
im = preprocess(sys.argv[1]) im, im_info = preprocess(sys.argv[1])
fetch_map = client.predict( fetch_map = client.predict(
feed={ feed={
"image": im, "image": im,
"scale_factor": np.array([1.0, 1.0]).reshape(-1), "im_shape": np.array(list(im.shape[1:])).reshape(-1),
"scale_factor": im_info['scale_factor'],
}, },
fetch=["save_infer_model/scale_0.tmp_1"], fetch=["save_infer_model/scale_0.tmp_1"],
batch=False) batch=False)
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
### Get Model ### Get Model
``` ```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/ttfnet_darknet53_1x_coco.tar wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/ttfnet_darknet53_1x_coco.tar
``` ```
### Start the service ### Start the service
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
## 获得模型 ## 获得模型
``` ```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/ttfnet_darknet53_1x_coco.tar wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/ttfnet_darknet53_1x_coco.tar
``` ```
......
...@@ -11,16 +11,18 @@ ...@@ -11,16 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import *
import sys import sys
import numpy as np import numpy as np
from paddle_serving_client import Client
from paddle_serving_app.reader import *
import cv2
preprocess = Sequential([ preprocess = DetectionSequential([
File2Image(), BGR2RGB(), DetectionFile2Image(),
Normalize([123.675, 116.28, 103.53], [58.395, 57.12, 57.375], False), DetectionResize(
Resize((512, 512)), Transpose((2, 0, 1)) (512, 512), False, interpolation=cv2.INTER_LINEAR),
DetectionNormalize([123.675, 116.28, 103.53], [58.395, 57.12, 57.375], False),
DetectionTranspose((2,0,1))
]) ])
postprocess = RCNNPostprocess("label_list.txt", "output") postprocess = RCNNPostprocess("label_list.txt", "output")
...@@ -29,11 +31,14 @@ client = Client() ...@@ -29,11 +31,14 @@ client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt") client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9494']) client.connect(['127.0.0.1:9494'])
im = preprocess(sys.argv[1]) im, im_info = preprocess(sys.argv[1])
fetch_map = client.predict( fetch_map = client.predict(
feed={ feed={
"image": im, "image": im,
"scale_factor": np.array([1.0, 1.0]).reshape(-1), "im_shape": np.array(list(im.shape[1:])).reshape(-1),
"scale_factor": im_info['scale_factor'],
}, },
fetch=["save_infer_model/scale_0.tmp_1"], fetch=["save_infer_model/scale_0.tmp_1"],
batch=False) batch=False)
......
...@@ -17,27 +17,29 @@ import numpy as np ...@@ -17,27 +17,29 @@ import numpy as np
from paddle_serving_client import Client from paddle_serving_client import Client
from paddle_serving_app.reader import * from paddle_serving_app.reader import *
import cv2 import cv2
preprocess = Sequential([
File2Image(), BGR2RGB(), Resize( preprocess = DetectionSequential([
(608, 608), interpolation=cv2.INTER_LINEAR), Div(255.0), Transpose( DetectionFile2Image(),
(2, 0, 1)) DetectionResize(
(608, 608), False, interpolation=2),
DetectionNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True),
DetectionTranspose((2,0,1)),
]) ])
postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608]) postprocess = RCNNPostprocess("label_list.txt", "output")
client = Client() client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt") client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9494']) client.connect(['127.0.0.1:9494'])
im = preprocess(sys.argv[1]) im, im_info = preprocess(sys.argv[1])
fetch_map = client.predict( fetch_map = client.predict(
feed={ feed={
"image": im, "image": im,
"im_shape": np.array(list(im.shape[1:])).reshape(-1), "im_shape": np.array(list(im.shape[1:])).reshape(-1),
"scale_factor": np.array([1.0, 1.0]).reshape(-1), "scale_factor": im_info['scale_factor'],
}, },
fetch=["save_infer_model/scale_0.tmp_1"], fetch=["save_infer_model/scale_0.tmp_1"],
batch=False) batch=False)
print(fetch_map)
fetch_map["image"] = sys.argv[1] fetch_map["image"] = sys.argv[1]
postprocess(fetch_map) postprocess(fetch_map)
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from .chinese_bert_reader import ChineseBertReader from .chinese_bert_reader import ChineseBertReader
from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize, Base64ToImage from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize, Base64ToImage
from .image_reader import DetectionFile2Image, DetectionSequential, DetectionNormalize, DetectionTranspose, DetectionResize, DetectionBGR2RGB, DetectionPadStride
from .image_reader import CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, ResizeByFactor from .image_reader import CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, ResizeByFactor
from .image_reader import RCNNPostprocess, SegPostprocess, PadStride, BlazeFacePostprocess from .image_reader import RCNNPostprocess, SegPostprocess, PadStride, BlazeFacePostprocess
from .image_reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes from .image_reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
......
...@@ -498,6 +498,42 @@ class Sequential(object): ...@@ -498,6 +498,42 @@ class Sequential(object):
return format_string_ return format_string_
class DetectionSequential(object):
"""
Args:
sequence (sequence of ``Transform`` objects): list of transforms to chain.
This API references some of the design pattern of torchvision
Users can simply use this API in training as well
Example:
>>> image_reader.Sequnece([
>>> transforms.CenterCrop(10),
>>> ])
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, im):
im_info = {
'scale_factor': np.array(
[1., 1.], dtype=np.float32),
'im_shape': None,
}
for t in self.transforms:
im, im_info = t(im, im_info)
return im, im_info
def __repr__(self):
format_string_ = self.__class__.__name__ + '('
for t in self.transforms:
format_string_ += '\n'
format_string_ += ' {0}'.format(t)
format_string_ += '\n)'
return format_string_
class RGB2BGR(object): class RGB2BGR(object):
def __init__(self): def __init__(self):
pass pass
...@@ -520,6 +556,17 @@ class BGR2RGB(object): ...@@ -520,6 +556,17 @@ class BGR2RGB(object):
return self.__class__.__name__ + "()" return self.__class__.__name__ + "()"
class DetectionBGR2RGB(object):
def __init__(self):
pass
def __call__(self, img, img_info=None):
return img[:, :, ::-1], img_info
def __repr__(self):
return self.__class__.__name__ + "()"
class String2Image(object): class String2Image(object):
def __init__(self): def __init__(self):
pass pass
...@@ -556,6 +603,33 @@ class File2Image(object): ...@@ -556,6 +603,33 @@ class File2Image(object):
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + "()" return self.__class__.__name__ + "()"
class DetectionFile2Image(object):
def __init__(self):
pass
def __call__(self, img_path, im_info=None):
if py_version == 2:
fin = open(img_path)
else:
fin = open(img_path, "rb")
sample = fin.read()
data = np.fromstring(sample, np.uint8)
img = cv2.imdecode(data, cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
'''
img = cv2.imread(img_path, -1)
channels = img.shape[2]
ori_h = img.shape[0]
ori_w = img.shape[1]
'''
if im_info is not None:
im_info['im_shape'] = np.array(img.shape[:2], dtype=np.float32)
im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
return img, im_info
def __repr__(self):
return self.__class__.__name__ + "()"
class URL2Image(object): class URL2Image(object):
def __init__(self): def __init__(self):
...@@ -607,6 +681,27 @@ class Div(object): ...@@ -607,6 +681,27 @@ class Div(object):
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + "({})".format(self.value) return self.__class__.__name__ + "({})".format(self.value)
class DetectionDiv(object):
""" divide by some float number """
def __init__(self, value):
self.value = value
def __call__(self, img, img_info=None):
"""
Args:
img (numpy array): (int8 numpy array)
Returns:
img (numpy array): (float32 numpy array)
"""
img = img.astype('float32') / self.value
return img, img_info
def __repr__(self):
return self.__class__.__name__ + "({})".format(self.value)
class Normalize(object): class Normalize(object):
"""Normalize a tensor image with mean and standard deviation. """Normalize a tensor image with mean and standard deviation.
...@@ -643,6 +738,51 @@ class Normalize(object): ...@@ -643,6 +738,51 @@ class Normalize(object):
self.std) self.std)
class DetectionNormalize(object):
"""Normalize a tensor image with mean and standard deviation.
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
will normalize each channel of the input ``torch.*Tensor`` i.e.
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
.. note::
This transform acts out of place, i.e., it does not mutate the input tensor.
Args:
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
is_scale (bool): whether need im / 255
"""
def __init__(self, mean, std, is_scale=True):
self.mean = mean
self.std = std
self.is_scale = is_scale
def __call__(self, im, im_info=None):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.is_scale:
im = im / 255.0
im -= mean
im /= std
return im, im_info
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean,
self.std)
class Lambda(object): class Lambda(object):
"""Apply a user-defined lambda as a transform. """Apply a user-defined lambda as a transform.
Very shame to just copy from Very shame to just copy from
...@@ -716,6 +856,124 @@ class Resize(object): ...@@ -716,6 +856,124 @@ class Resize(object):
self.size, self.max_size, self.size, self.max_size,
_cv2_interpolation_to_str[self.interpolation]) _cv2_interpolation_to_str[self.interpolation])
class DetectionResize(object):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def __init__(self, target_size, keep_ratio=True, interpolation=cv2.INTER_LINEAR):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interpolation = interpolation
def __call__(self, im, im_info=None):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interpolation)
if im_info is not None:
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array(
[im_scale_y, im_scale_x]).astype('float32')
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
def __repr__(self):
return self.__class__.__name__ + '(size={0}, max_size={1}, interpolation={2})'.format(
self.size, self.max_size,
_cv2_interpolation_to_str[self.interpolation])
class PadStride(object):
def __init__(self, stride):
self.coarsest_stride = stride
def __call__(self, img):
coarsest_stride = self.coarsest_stride
if coarsest_stride == 0:
return img
im_c, im_h, im_w = img.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = img
im_info = {}
im_info['resize_shape'] = padding_im.shape[1:]
return padding_im
class DetectionPadStride(object):
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info=None):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride <= 0:
return im
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
return padding_im, im_info
class ResizeByFactor(object): class ResizeByFactor(object):
"""Resize the input numpy array Image to a size multiple of factor which is usually required by a network """Resize the input numpy array Image to a size multiple of factor which is usually required by a network
...@@ -768,24 +1026,6 @@ class ResizeByFactor(object): ...@@ -768,24 +1026,6 @@ class ResizeByFactor(object):
self.factor, self.max_side_len) self.factor, self.max_side_len)
class PadStride(object):
def __init__(self, stride):
self.coarsest_stride = stride
def __call__(self, img):
coarsest_stride = self.coarsest_stride
if coarsest_stride == 0:
return img
im_c, im_h, im_w = img.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = img
im_info = {}
im_info['resize_shape'] = padding_im.shape[1:]
return padding_im
class Transpose(object): class Transpose(object):
def __init__(self, transpose_target): def __init__(self, transpose_target):
self.transpose_target = transpose_target self.transpose_target = transpose_target
...@@ -799,6 +1039,19 @@ class Transpose(object): ...@@ -799,6 +1039,19 @@ class Transpose(object):
"({})".format(self.transpose_target) "({})".format(self.transpose_target)
return format_string return format_string
class DetectionTranspose(object):
def __init__(self, transpose_target):
self.transpose_target = transpose_target
def __call__(self, im, im_info=None):
im = F.transpose(im, self.transpose_target)
return im, im_info
def __repr__(self):
format_string = self.__class__.__name__ + \
"({})".format(self.transpose_target)
return format_string
class SortedBoxes(object): class SortedBoxes(object):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册