diff --git a/deploy/end2end_ppyoloe/README.md b/deploy/end2end_ppyoloe/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d470dccffe7c9927eac6946d3ee47ea96c346a56 --- /dev/null +++ b/deploy/end2end_ppyoloe/README.md @@ -0,0 +1,99 @@ +# Export ONNX Model +## Download pretrain paddle models + +* [ppyoloe-s](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco.pdparams) +* [ppyoloe-m](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_m_300e_coco.pdparams) +* [ppyoloe-l](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams) +* [ppyoloe-x](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_x_300e_coco.pdparams) +* [ppyoloe-s-400e](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_400e_coco.pdparams) + + +## Export paddle model for deploying + +```shell +python ./tools/export_model.py \ + -c configs/ppyoloe/ppyoloe_crn_s_300e_coco.yml \ + -o weights=ppyoloe_crn_s_300e_coco.pdparams \ + trt=True \ + exclude_nms=True \ + TestReader.inputs_def.image_shape=[3,640,640] \ + --output_dir ./ + +# if you want to try ppyoloe-s-400e model +python ./tools/export_model.py \ + -c configs/ppyoloe/ppyoloe_crn_s_400e_coco.yml \ + -o weights=ppyoloe_crn_s_400e_coco.pdparams \ + trt=True \ + exclude_nms=True \ + TestReader.inputs_def.image_shape=[3,640,640] \ + --output_dir ./ +``` + +## Check requirements +```shell +pip install onnx>=1.10.0 +pip install paddle2onnx +pip install onnx-simplifier +pip install onnx-graphsurgeon --index-url https://pypi.ngc.nvidia.com +# if use cuda-python infer, please install it +pip install cuda-python +# if use cupy infer, please install it +pip install cupy-cuda117 # cuda110-cuda117 are all available +``` + +## Export script +```shell +python ./deploy/end2end_ppyoloe/end2end.py \ + --model-dir ppyoloe_crn_s_300e_coco \ + --save-file ppyoloe_crn_s_300e_coco.onnx \ + --opset 11 \ + --batch-size 1 \ + --topk-all 100 \ + --iou-thres 0.6 \ + --conf-thres 0.4 +# if you want to try ppyoloe-s-400e model +python ./deploy/end2end_ppyoloe/end2end.py \ + --model-dir ppyoloe_crn_s_400e_coco \ + --save-file ppyoloe_crn_s_400e_coco.onnx \ + --opset 11 \ + --batch-size 1 \ + --topk-all 100 \ + --iou-thres 0.6 \ + --conf-thres 0.4 +``` +#### Description of all arguments + +- `--model-dir` : the path of ppyoloe export dir. +- `--save-file` : the path of export onnx. +- `--opset` : onnx opset version. +- `--img-size` : image size for exporting ppyoloe. +- `--batch-size` : batch size for exporting ppyoloe. +- `--topk-all` : topk objects for every image. +- `--iou-thres` : iou threshold for NMS algorithm. +- `--conf-thres` : confidence threshold for NMS algorithm. + +### TensorRT backend (TensorRT version>= 8.0.0) +#### TensorRT engine export +``` shell +/path/to/trtexec \ + --onnx=ppyoloe_crn_s_300e_coco.onnx \ + --saveEngine=ppyoloe_crn_s_300e_coco.engine \ + --fp16 # if export TensorRT fp16 model +# if you want to try ppyoloe-s-400e model +/path/to/trtexec \ + --onnx=ppyoloe_crn_s_400e_coco.onnx \ + --saveEngine=ppyoloe_crn_s_400e_coco.engine \ + --fp16 # if export TensorRT fp16 model +``` +#### TensorRT image infer + +``` shell +# cuda-python infer script +python ./deploy/end2end_ppyoloe/cuda-python.py ppyoloe_crn_s_300e_coco.engine +# cupy infer script +python ./deploy/end2end_ppyoloe/cupy-python.py ppyoloe_crn_s_300e_coco.engine +# if you want to try ppyoloe-s-400e model +python ./deploy/end2end_ppyoloe/cuda-python.py ppyoloe_crn_s_400e_coco.engine +# or +python ./deploy/end2end_ppyoloe/cuda-python.py ppyoloe_crn_s_400e_coco.engine +``` \ No newline at end of file diff --git a/deploy/end2end_ppyoloe/cuda-python.py b/deploy/end2end_ppyoloe/cuda-python.py new file mode 100644 index 0000000000000000000000000000000000000000..3c7bd7c84b3eeaa6bea55416d8a5eabd37ac4d33 --- /dev/null +++ b/deploy/end2end_ppyoloe/cuda-python.py @@ -0,0 +1,161 @@ +import sys +import requests +import cv2 +import random +import time +import numpy as np +import tensorrt as trt +from cuda import cudart +from pathlib import Path +from collections import OrderedDict, namedtuple + + +def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleup=True, stride=32): + # Resize and pad image while meeting stride-multiple constraints + shape = im.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better val mAP) + r = min(r, 1.0) + + # Compute padding + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + + if auto: # minimum rectangle + dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return im, r, (dw, dh) + + +w = Path(sys.argv[1]) + +assert w.exists() and w.suffix in ('.engine', '.plan'), 'Wrong engine path' + +names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', + 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', + 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', + 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', + 'hair drier', 'toothbrush'] +colors = {name: [random.randint(0, 255) for _ in range(3)] for i, name in enumerate(names)} + +url = 'https://oneflow-static.oss-cn-beijing.aliyuncs.com/tripleMu/image1.jpg' +file = requests.get(url) +img = cv2.imdecode(np.frombuffer(file.content, np.uint8), 1) + +_, stream = cudart.cudaStreamCreate() + +mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1) +std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1) + +# Infer TensorRT Engine +Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) +logger = trt.Logger(trt.Logger.ERROR) +trt.init_libnvinfer_plugins(logger, namespace="") +with open(w, 'rb') as f, trt.Runtime(logger) as runtime: + model = runtime.deserialize_cuda_engine(f.read()) +bindings = OrderedDict() +fp16 = False # default updated below +for index in range(model.num_bindings): + name = model.get_binding_name(index) + dtype = trt.nptype(model.get_binding_dtype(index)) + shape = tuple(model.get_binding_shape(index)) + data = np.empty(shape, dtype=np.dtype(dtype)) + _, data_ptr = cudart.cudaMallocAsync(data.nbytes, stream) + bindings[name] = Binding(name, dtype, shape, data, data_ptr) + if model.binding_is_input(index) and dtype == np.float16: + fp16 = True +binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) +context = model.create_execution_context() + +image = img.copy() +image, ratio, dwdh = letterbox(image, auto=False) +image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + +image_copy = image.copy() + +image = image.transpose((2, 0, 1)) +image = np.expand_dims(image, 0) +image = np.ascontiguousarray(image) + +im = image.astype(np.float32) +im /= 255 +im -= mean +im /= std + +_, image_ptr = cudart.cudaMallocAsync(im.nbytes, stream) +cudart.cudaMemcpyAsync(image_ptr, im.ctypes.data, im.nbytes, + cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream) + +# warmup for 10 times +for _ in range(10): + tmp = np.random.randn(1, 3, 640, 640).astype(np.float32) + _, tmp_ptr = cudart.cudaMallocAsync(tmp.nbytes, stream) + binding_addrs['image'] = tmp_ptr + context.execute_v2(list(binding_addrs.values())) + +start = time.perf_counter() +binding_addrs['image'] = image_ptr +context.execute_v2(list(binding_addrs.values())) +print(f'Cost {(time.perf_counter() - start) * 1000}ms') + +nums = bindings['num_dets'].data +boxes = bindings['det_boxes'].data +scores = bindings['det_scores'].data +classes = bindings['det_classes'].data + +cudart.cudaMemcpyAsync(nums.ctypes.data, + bindings['num_dets'].ptr, + nums.nbytes, + cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, + stream) +cudart.cudaMemcpyAsync(boxes.ctypes.data, + bindings['det_boxes'].ptr, + boxes.nbytes, + cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, + stream) +cudart.cudaMemcpyAsync(scores.ctypes.data, + bindings['det_scores'].ptr, + scores.nbytes, + cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, + stream) +cudart.cudaMemcpyAsync(classes.ctypes.data, + bindings['det_classes'].ptr, + classes.data.nbytes, + cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, + stream) + +cudart.cudaStreamSynchronize(stream) +cudart.cudaStreamDestroy(stream) + +for i in binding_addrs.values(): + cudart.cudaFree(i) + +num = int(nums[0][0]) +box_img = boxes[0, :num].round().astype(np.int32) +score_img = scores[0, :num] +clss_img = classes[0, :num] +for i, (box, score, clss) in enumerate(zip(box_img, score_img, clss_img)): + name = names[int(clss)] + color = colors[name] + cv2.rectangle(image_copy, box[:2].tolist(), box[2:].tolist(), color, 2) + cv2.putText(image_copy, name, (int(box[0]), int(box[1]) - 2), cv2.FONT_HERSHEY_SIMPLEX, + 0.75, [225, 255, 255], thickness=2) + +cv2.imshow('Result', cv2.cvtColor(image_copy, cv2.COLOR_RGB2BGR)) +cv2.waitKey(0) diff --git a/deploy/end2end_ppyoloe/cupy-python.py b/deploy/end2end_ppyoloe/cupy-python.py new file mode 100644 index 0000000000000000000000000000000000000000..a66eb77ecf3aa4c76c143050764429a2a06e8ba1 --- /dev/null +++ b/deploy/end2end_ppyoloe/cupy-python.py @@ -0,0 +1,131 @@ +import sys +import requests +import cv2 +import random +import time +import numpy as np +import cupy as cp +import tensorrt as trt +from PIL import Image +from collections import OrderedDict, namedtuple +from pathlib import Path + + +def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleup=True, stride=32): + # Resize and pad image while meeting stride-multiple constraints + shape = im.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better val mAP) + r = min(r, 1.0) + + # Compute padding + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + + if auto: # minimum rectangle + dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return im, r, (dw, dh) + + +names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', + 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', + 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', + 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', + 'hair drier', 'toothbrush'] +colors = {name: [random.randint(0, 255) for _ in range(3)] for i, name in enumerate(names)} + +url = 'https://oneflow-static.oss-cn-beijing.aliyuncs.com/tripleMu/image1.jpg' +file = requests.get(url) +img = cv2.imdecode(np.frombuffer(file.content, np.uint8), 1) + +w = Path(sys.argv[1]) + +assert w.exists() and w.suffix in ('.engine', '.plan'), 'Wrong engine path' + +mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1) +std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1) + +mean = cp.asarray(mean) +std = cp.asarray(std) + +# Infer TensorRT Engine +Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) +logger = trt.Logger(trt.Logger.INFO) +trt.init_libnvinfer_plugins(logger, namespace="") +with open(w, 'rb') as f, trt.Runtime(logger) as runtime: + model = runtime.deserialize_cuda_engine(f.read()) +bindings = OrderedDict() +fp16 = False # default updated below +for index in range(model.num_bindings): + name = model.get_binding_name(index) + dtype = trt.nptype(model.get_binding_dtype(index)) + shape = tuple(model.get_binding_shape(index)) + data = cp.empty(shape, dtype=cp.dtype(dtype)) + bindings[name] = Binding(name, dtype, shape, data, int(data.data.ptr)) + if model.binding_is_input(index) and dtype == np.float16: + fp16 = True +binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) +context = model.create_execution_context() + +image = img.copy() +image, ratio, dwdh = letterbox(image, auto=False) +image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + +image_copy = image.copy() + +image = image.transpose((2, 0, 1)) +image = np.expand_dims(image, 0) +image = np.ascontiguousarray(image) + +im = cp.asarray(image) +im = im.astype(cp.float32) +im /= 255 +im -= mean +im /= std + +# warmup for 10 times +for _ in range(10): + tmp = cp.random.randn(1, 3, 640, 640).astype(cp.float32) + binding_addrs['image'] = int(tmp.data.ptr) + context.execute_v2(list(binding_addrs.values())) + +start = time.perf_counter() +binding_addrs['image'] = int(im.data.ptr) +context.execute_v2(list(binding_addrs.values())) +print(f'Cost {(time.perf_counter() - start) * 1000}ms') + +nums = bindings['num_dets'].data +boxes = bindings['det_boxes'].data +scores = bindings['det_scores'].data +classes = bindings['det_classes'].data + +num = int(nums[0][0]) +box_img = boxes[0, :num].round().astype(cp.int32) +score_img = scores[0, :num] +clss_img = classes[0, :num] +for i, (box, score, clss) in enumerate(zip(box_img, score_img, clss_img)): + name = names[int(clss)] + color = colors[name] + cv2.rectangle(image_copy, box[:2].tolist(), box[2:].tolist(), color, 2) + cv2.putText(image_copy, name, (int(box[0]), int(box[1]) - 2), cv2.FONT_HERSHEY_SIMPLEX, + 0.75, [225, 255, 255], thickness=2) + +cv2.imshow('Result', cv2.cvtColor(image_copy, cv2.COLOR_RGB2BGR)) +cv2.waitKey(0) diff --git a/deploy/end2end_ppyoloe/end2end.py b/deploy/end2end_ppyoloe/end2end.py new file mode 100644 index 0000000000000000000000000000000000000000..fcfbf019a5d5755768e7defd573203a20a020ef7 --- /dev/null +++ b/deploy/end2end_ppyoloe/end2end.py @@ -0,0 +1,97 @@ +import argparse +import onnx +import onnx_graphsurgeon as gs +import numpy as np + +from pathlib import Path +from paddle2onnx.legacy.command import program2onnx +from collections import OrderedDict + + +def main(opt): + model_dir = Path(opt.model_dir) + save_file = Path(opt.save_file) + assert model_dir.exists() and model_dir.is_dir() + if save_file.is_dir(): + save_file = (save_file / model_dir.stem).with_suffix('.onnx') + elif save_file.is_file() and save_file.suffix != '.onnx': + save_file = save_file.with_suffix('.onnx') + input_shape_dict = {'image': [opt.batch_size, 3, *opt.img_size], + 'scale_factor': [opt.batch_size, 2]} + program2onnx(str(model_dir), str(save_file), + 'model.pdmodel', 'model.pdiparams', + opt.opset, input_shape_dict=input_shape_dict) + onnx_model = onnx.load(save_file) + try: + import onnxsim + onnx_model, check = onnxsim.simplify(onnx_model) + assert check, 'assert check failed' + except Exception as e: + print(f'Simplifier failure: {e}') + onnx.checker.check_model(onnx_model) + graph = gs.import_onnx(onnx_model) + graph.fold_constants() + graph.cleanup().toposort() + mul = concat = None + for node in graph.nodes: + if node.op == 'Div' and node.i(0).op == 'Mul': + mul = node.i(0) + if node.op == 'Concat' and node.o().op == 'Reshape' and node.o().o().op == 'ReduceSum': + concat = node + + assert mul.outputs[0].shape[1] == concat.outputs[0].shape[2], 'Something wrong in outputs shape' + + anchors = mul.outputs[0].shape[1] + classes = concat.outputs[0].shape[1] + + scores = gs.Variable(name='scores', shape=[opt.batch_size, anchors, classes], dtype=np.float32) + graph.layer(op='Transpose', name='lastTranspose', + inputs=[concat.outputs[0]], + outputs=[scores], + attrs=OrderedDict(perm=[0, 2, 1])) + + graph.inputs = [graph.inputs[0]] + + attrs = OrderedDict( + plugin_version="1", + background_class=-1, + max_output_boxes=opt.topk_all, + score_threshold=opt.conf_thres, + iou_threshold=opt.iou_thres, + score_activation=False, + box_coding=0, ) + outputs = [gs.Variable("num_dets", np.int32, [opt.batch_size, 1]), + gs.Variable("det_boxes", np.float32, [opt.batch_size, opt.topk_all, 4]), + gs.Variable("det_scores", np.float32, [opt.batch_size, opt.topk_all]), + gs.Variable("det_classes", np.int32, [opt.batch_size, opt.topk_all])] + graph.layer(op='EfficientNMS_TRT', name="batched_nms", + inputs=[mul.outputs[0], scores], + outputs=outputs, + attrs=attrs) + graph.outputs = outputs + graph.cleanup().toposort() + onnx.save(gs.export_onnx(graph), save_file) + + +def parse_opt(): + parser = argparse.ArgumentParser() + parser.add_argument('--model-dir', type=str, + default=None, + help='paddle static model') + parser.add_argument('--save-file', type=str, + default=None, + help='onnx model save path') + parser.add_argument('--opset', type=int, default=11, help='opset version') + parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') + parser.add_argument('--batch-size', type=int, default=1, help='batch size') + parser.add_argument('--topk-all', type=int, default=100, help='topk objects for every images') + parser.add_argument('--iou-thres', type=float, default=0.45, help='iou threshold for NMS') + parser.add_argument('--conf-thres', type=float, default=0.25, help='conf threshold for NMS') + opt = parser.parse_args() + opt.img_size *= 2 if len(opt.img_size) == 1 else 1 + return opt + + +if __name__ == '__main__': + opt = parse_opt() + main(opt)