diff --git a/deploy/TENSOR_RT.md b/deploy/TENSOR_RT.md
new file mode 100644
index 0000000000000000000000000000000000000000..1ef67ce4256f9bf42f3fc1c5d5f769f973879cb2
--- /dev/null
+++ b/deploy/TENSOR_RT.md
@@ -0,0 +1,61 @@
+# TensorRT Inference Deployment Tutorial
+TensorRT is NVIDIA's acceleration library for unified model deployment. It runs on hardware such as V100 and Jetson Xavier and can greatly speed up inference. For the Paddle TensorRT tutorial, see [Inference with the Paddle-TensorRT library](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_trt.html#).
+
+## 1. Install the Paddle Inference library
+- Python package: download a wheel built with TensorRT from [here](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-release) and install it.
+
+- C++ inference library: download a build compiled with TensorRT from [here](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/05_inference_deployment/inference/build_and_install_lib_cn.html).
+
+- If no prebuilt Python package or C++ library is provided on the official site, build one yourself following [compile from source](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/linux-compile.html).
+
+**Note:**
+- The TensorRT version on your machine must match the TensorRT version your inference library was built with.
+- Deployment with PaddleGAN requires TensorRT > 7.0.
+
+## 2. Export the model
+See the [PaddleGAN model export tutorial](../EXPORT_MODEL.md) for details.
+
+## 3. Enable TensorRT acceleration
+### 3.1 Configure TensorRT
+When building the predictor config with the Paddle inference library, just turn on the TensorRT engine:
+
+```
+config->EnableUseGpu(100, 0); // reserve 100 MB of GPU memory, use GPU 0
+config->GpuDeviceId(); // returns the GPU ID currently in use
+// Enable TensorRT inference; improves GPU performance and requires a library built with TensorRT
+config->EnableTensorRtEngine(1 << 20 /*workspace_size*/,
+                             batch_size /*max_batch_size*/,
+                             3 /*min_subgraph_size*/,
+                             AnalysisConfig::Precision::kFloat32 /*precision*/,
+                             false /*use_static*/,
+                             false /*use_calib_mode*/);
+```
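+
+The same switches are exposed through the Python inference API. A minimal sketch, assuming a hypothetical exported model prefix `model` (the file names are placeholders, not part of this repo):
+
+```
+import paddle.inference as paddle_infer
+
+# Placeholder model files; substitute your own exported model.
+config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
+config.enable_use_gpu(100, 0)  # reserve 100 MB of GPU memory, use GPU 0
+
+# Python counterpart of the C++ call above.
+config.enable_tensorrt_engine(
+    workspace_size=1 << 20,
+    max_batch_size=1,
+    min_subgraph_size=3,
+    precision_mode=paddle_infer.Config.Precision.Float32,
+    use_static=False,
+    use_calib_mode=False)
+predictor = paddle_infer.create_predictor(config)
+```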
+
+### 3.2 TensorRT fixed-size inference
+
+Using `msvsr` as an example, run inference with a fixed input size:
+```
+python tools/inference.py --model_path=/path/to/model --config-file /path/to/config --run_mode trt_fp32 --min_subgraph_size 20 --model_type msvsr
+```
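+
+For inputs whose size varies at runtime, dynamic shapes can be enabled instead. A hedged example using the flags added to `tools/inference.py` in this PR (the paths and shape bounds are illustrative):
+```
+python tools/inference.py --model_path=/path/to/model --config-file /path/to/config --run_mode trt_fp16 --use_dynamic_shape --trt_min_shape 100 --trt_opt_shape 640 --trt_max_shape 1280 --min_subgraph_size 20 --model_type msvsr
+```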
+
+## 4. FAQ
+**Q:** The error says there is no `tensorrt_op`<br>
+**A:** Check that you are using a Paddle Python package or inference library built with TensorRT.
+
+**Q:** The error says `op out of memory`<br>
+**A:** Check whether someone else is also using the GPU, and try a free GPU.
+
+**Q:** The error says `some trt inputs dynamic shape info not set`<br>
+**A:** TensorRT splits the network into multiple subgraphs; we only set dynamic shapes for the network input, not for the inputs of the other subgraphs. There are two fixes:
+
+- Option 1: increase `min_subgraph_size` so these subgraphs are skipped. Following the error hint, set `min_subgraph_size` larger than the number of OPs in any subgraph whose inputs lack dynamic shape info. `min_subgraph_size` means that when the TensorRT engine is loaded, only contiguous runs of TensorRT-optimizable OPs longer than `min_subgraph_size` get optimized.
+
+- Option 2: find those subgraph inputs and set dynamic shapes for them the same way as for the network input; see the sketch after this list.
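+
+A minimal sketch of option 2, using the Python `config` from the sketch in 3.1. The input names here (`lqs`, `extra_input`) and the shapes are placeholders; copy the real ones from the error message:
+```
+min_input_shape = {"lqs": [1, 2, 3, 100, 100], "extra_input": [1, 3, 100, 100]}
+max_input_shape = {"lqs": [1, 2, 3, 720, 1280], "extra_input": [1, 3, 720, 1280]}
+opt_input_shape = {"lqs": [1, 2, 3, 180, 320], "extra_input": [1, 3, 180, 320]}
+config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, opt_input_shape)
+```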
+
+**Q:** How do I turn logging on?<br>
+**A:** The inference library logs by default; if the code calls `config.disable_glog_info()`, comment it out to see the logs.
+
+**Q:** With TensorRT enabled, inference reports `Slice on batch axis is not supported in TensorRT`<br>
+**A:** Try dynamic-shape input.
diff --git a/deploy/serving/README.md b/deploy/serving/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9d16e022b5b4f44bc4a697a2a7973822c18ec2b8
--- /dev/null
+++ b/deploy/serving/README.md
@@ -0,0 +1,101 @@
+# Server-Side Inference Deployment
+
+Models trained with `PaddleGAN` can be deployed on a server with [Serving](https://github.com/PaddlePaddle/Serving).
+This tutorial deploys a model trained on the REDS dataset with `configs/msvsr_reds.yaml`.
+The pretrained weights are [PP-MSVSR_reds_x4.pdparams](https://paddlegan.bj.bcebos.com/models/PP-MSVSR_reds_x4.pdparams).
+
+## 1. Install Paddle Serving
+Install following the guide in [PaddleServing](https://github.com/PaddlePaddle/Serving/tree/v0.6.0) (version >= 0.6.0).
+
+## 2. Export the model
+Training in PaddleGAN involves the forward network plus optimizer state, while deployment only needs the forward parameters; see [export model](https://github.com/PaddlePaddle/PaddleGAN/blob/develop/deploy/EXPORT_MODEL.md) for details.
+
+```
+python tools/export_model.py -c configs/msvsr_reds.yaml --inputs_size="1,2,3,180,320" --load /path/to/model --export_serving_model True --output_dir /path/to/output
+```
+
+The command above creates a `multistagevsrmodel_generator` folder under `/path/to/output`:
+```
+output
+│ ├── multistagevsrmodel_generator
+│ │ ├── multistagevsrmodel_generator.pdiparams
+│ │ ├── multistagevsrmodel_generator.pdiparams.info
+│ │ ├── multistagevsrmodel_generator.pdmodel
+│ │ ├── serving_client
+│ │ │ ├── serving_client_conf.prototxt
+│ │ │ ├── serving_client_conf.stream.prototxt
+│ │ ├── serving_server
+│ │ │ ├── __model__
+│ │ │ ├── __params__
+│ │ │ ├── serving_server_conf.prototxt
+│ │ │ ├── serving_server_conf.stream.prototxt
+│ │ │ ├── ...
+```
+
+`serving_client/serving_client_conf.prototxt` describes the model's inputs and outputs in detail. Its content is:
+```
+feed_var {
+  name: "lqs"
+  alias_name: "lqs"
+  is_lod_tensor: false
+  feed_type: 1
+  shape: 1
+  shape: 2
+  shape: 3
+  shape: 180
+  shape: 320
+}
+fetch_var {
+  name: "stack_18.tmp_0"
+  alias_name: "stack_18.tmp_0"
+  is_lod_tensor: false
+  fetch_type: 1
+  shape: 1
+  shape: 2
+  shape: 3
+  shape: 720
+  shape: 1280
+}
+fetch_var {
+  name: "stack_19.tmp_0"
+  alias_name: "stack_19.tmp_0"
+  is_lod_tensor: false
+  fetch_type: 1
+  shape: 1
+  shape: 3
+  shape: 720
+  shape: 1280
+}
+```
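+
+The `feed_var`/`fetch_var` names above are exactly what a client must use. A minimal request sketch based on `test_client.py` from this PR (the zero array is a placeholder input):
+```
+import numpy as np
+from paddle_serving_client import Client
+
+client = Client()
+client.load_client_config("serving_client/serving_client_conf.prototxt")
+client.connect(["127.0.0.1:9393"])
+
+# Two preprocessed 3x180x320 frames; with batch=False the client adds the batch axis.
+lqs = np.zeros((2, 3, 180, 320), dtype="float32")
+fetch_map = client.predict(feed={"lqs": lqs}, fetch=["stack_19.tmp_0"], batch=False)
+print(fetch_map["stack_19.tmp_0"].shape)
+```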
+
+## 3. Start the Paddle Serving service
+
+```
+cd output/multistagevsrmodel_generator/
+
+# GPU
+python -m paddle_serving_server.serve --model serving_server --port 9393 --gpu_ids 0
+
+# CPU
+python -m paddle_serving_server.serve --model serving_server --port 9393
+```
+
+## 4. Test the deployed service
+
+Set the `prototxt` file path to `serving_client/serving_client_conf.prototxt` and set `fetch` to `fetch=["stack_19.tmp_0"]`, then run the test:
+
+```
+# enter the exported model folder
+cd output/multistagevsrmodel_generator/
+
+# test_client.py creates an output folder and writes `output.mp4` under it
+python ../../deploy/serving/test_client.py input_video frame_num
+```
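+
+For example, with a hypothetical low-resolution clip `lq_input.mp4` and 2 frames per request (2 matches the exported input size `1,2,3,180,320`):
+```
+python ../../deploy/serving/test_client.py lq_input.mp4 2
+```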
diff --git a/deploy/serving/test_client.py b/deploy/serving/test_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5e0ae73a50673840a376cdca37b69e2bfd74933
--- /dev/null
+++ b/deploy/serving/test_client.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+import cv2
+import imageio
+import numpy as np
+from paddle_serving_app.reader import BGR2RGB, Div, Resize, Sequential, Transpose
+from paddle_serving_client import Client
+
+
+def get_img(pred):
+    """Convert a CHW float prediction in [0, 1] to an HWC uint8 image."""
+    pred = pred.squeeze()
+    pred = np.clip(pred, a_min=0., a_max=1.0)
+    pred = pred * 255
+    pred = pred.round()
+    pred = pred.astype('uint8')
+    pred = np.transpose(pred, (1, 2, 0))  # chw -> hwc
+    return pred
+
+
+# Resize each frame to 320x180, scale to [0, 1] and convert to CHW RGB.
+preprocess = Sequential([
+    BGR2RGB(), Resize((320, 180)), Div(255.0), Transpose((2, 0, 1))
+])
+
+client = Client()
+client.load_client_config("serving_client/serving_client_conf.prototxt")
+client.connect(['127.0.0.1:9393'])
+
+# Usage: python test_client.py <input_video> <frame_num>
+frame_num = int(sys.argv[2])
+
+cap = cv2.VideoCapture(sys.argv[1])
+fps = cap.get(cv2.CAP_PROP_FPS)
+success, frame = cap.read()
+read_end = False
+res_frames = []
+output_dir = "./output"
+if not os.path.exists(output_dir):
+    os.makedirs(output_dir)
+
+while success:
+    # Collect frame_num frames per request; a trailing chunk shorter than
+    # frame_num is discarded.
+    frames = []
+    for i in range(frame_num):
+        if success:
+            frames.append(preprocess(frame))
+            success, frame = cap.read()
+        else:
+            read_end = True
+    if read_end:
+        break
+
+    frames = np.stack(frames, axis=0)
+    fetch_map = client.predict(
+        feed={"lqs": frames},
+        fetch=["stack_19.tmp_0"],
+        batch=False)
+    res_frames.extend(
+        [fetch_map["stack_19.tmp_0"][0][i] for i in range(frame_num)])
+
+imageio.mimsave("output/output.mp4",
+                [get_img(frame) for frame in res_frames],
+                fps=fps)
+
diff --git a/docs/zh_CN/tutorials/wav2lip.md b/docs/zh_CN/tutorials/wav2lip.md
index 6217c4dd3977626ee76a877eb6996e739d86ea50..fcc29dc6869b4febd24e17bd24605527950680d3 100644
--- a/docs/zh_CN/tutorials/wav2lip.md
+++ b/docs/zh_CN/tutorials/wav2lip.md
@@ -74,7 +74,7 @@ python -m paddle.distributed.launch \
 ### 2.3 模型
 Model|Dataset|BatchSize|Inference speed|Download
 ---|:--:|:--:|:--:|:--:
-wa2lip_hq|LRS2| 1 | 0.2853s/image (GPU:P40) | [model](https://paddlegan.bj.bcebos.com/models/psgan_weight.pdparam://paddlegan.bj.bcebos.com/models/wav2lip_hq.pdparams)
+wav2lip_hq|LRS2| 1 | 0.2853s/image (GPU:P40) | [model](https://paddlegan.bj.bcebos.com/models/wav2lip_hq.pdparams)
 ## 3. 结果展示
diff --git a/ppgan/models/base_model.py b/ppgan/models/base_model.py
index ae4ecd2bdc7ddd077640c5752fd7dabf14436ba7..8f8cf90c9d126283665cb58061e3eb58670d6ac4 100755
--- a/ppgan/models/base_model.py
+++ b/ppgan/models/base_model.py
@@ -183,7 +183,7 @@ class BaseModel(ABC):
             for param in net.parameters():
                 param.trainable = requires_grad
 
-    def export_model(self, export_model, output_dir=None, inputs_size=[]):
+    def export_model(self, export_model, output_dir=None, inputs_size=[], export_serving_model=False):
         inputs_num = 0
         for net in export_model:
             input_spec = [
@@ -201,3 +201,16 @@
                 os.path.join(
                     output_dir, '{}_{}'.format(self.__class__.__name__.lower(),
                                                net["name"])))
+            if export_serving_model:
+                from paddle_serving_client.io import inference_model_to_serving
+                model_name = '{}_{}'.format(self.__class__.__name__.lower(),
+                                            net["name"])
+
+                inference_model_to_serving(
+                    dirname=output_dir,
+                    serving_server="{}/{}/serving_server".format(output_dir,
+                                                                 model_name),
+                    serving_client="{}/{}/serving_client".format(output_dir,
+                                                                 model_name),
+                    model_filename="{}.pdmodel".format(model_name),
+                    params_filename="{}.pdiparams".format(model_name))
diff --git a/test_tipc/results/python_msvsr_results_fp32.txt b/test_tipc/results/python_msvsr_results_fp32.txt
index 1462d709b7bf9fdf62504a7789cc9965e8fe8fa6..c29d1cf421f6c8254b18e72faf8db8c376e1d41d 100644
--- a/test_tipc/results/python_msvsr_results_fp32.txt
+++ b/test_tipc/results/python_msvsr_results_fp32.txt
@@ -1,2 +1,2 @@
-Metric psnr: 24.3250
-Metric ssim: 0.6497
\ No newline at end of file
+Metric psnr: 27.2885
+Metric ssim: 0.7969
diff --git a/tools/export_model.py b/tools/export_model.py
index 77120823de0f134de274c166afd4fbcdddf968f1..2c055c1d2893f1eee2c5dea9328c9846b864e56e 100644
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -51,6 +51,12 @@ def parse_args():
         type=str,
         help="The path prefix of inference model to be used.",
     )
+    parser.add_argument(
+        "--export_serving_model",
+        default=False,
+        type=bool,
+        help="Whether to export a serving model.",
+    )
     args = parser.parse_args()
     return args
@@ -64,7 +70,7 @@ def main(args, cfg):
     for net_name, net in model.nets.items():
         if net_name in state_dicts:
             net.set_state_dict(state_dicts[net_name])
-    model.export_model(cfg.export_model, args.output_dir, inputs_size)
+    model.export_model(cfg.export_model, args.output_dir, inputs_size,
+                       args.export_serving_model)
diff --git a/tools/inference.py b/tools/inference.py
index 5c4883a1b475455b18a536baaf604bc45dd76400..6322acd52c06c55e5b5a46c0aa69a083e81cf54a 100644
--- a/tools/inference.py
+++ b/tools/inference.py
@@ -58,11 +58,61 @@ def parse_args():
         default=None,
         help='fix random numbers by setting seed\".'
     )
+    # TensorRT options
+    parser.add_argument(
+        "--run_mode",
+        default="fluid",
+        type=str,
+        choices=["fluid", "trt_fp32", "trt_fp16", "trt_int8"],
+        help="running mode (fluid/trt_fp32/trt_fp16/trt_int8)")
+    parser.add_argument(
+        "--trt_min_shape",
+        default=1,
+        type=int,
+        help="trt_min_shape for TensorRT")
+    parser.add_argument(
+        "--trt_max_shape",
+        default=1280,
+        type=int,
+        help="trt_max_shape for TensorRT")
+    parser.add_argument(
+        "--trt_opt_shape",
+        default=640,
+        type=int,
+        help="trt_opt_shape for TensorRT")
+    parser.add_argument(
+        "--min_subgraph_size",
+        default=3,
+        type=int,
+        help="min_subgraph_size for TensorRT")
+    parser.add_argument(
+        "--batch_size",
+        default=1,
+        type=int,
+        help="batch_size for TensorRT")
+    parser.add_argument(
+        "--use_dynamic_shape",
+        dest="use_dynamic_shape",
+        action="store_true",
+        help="use dynamic shapes with TensorRT")
+    parser.add_argument(
+        "--trt_calib_mode",
+        dest="trt_calib_mode",
+        action="store_true",
+        help="enable TensorRT calibration mode (use with trt_int8)")
     args = parser.parse_args()
     return args
 
 
-def create_predictor(model_path, device="gpu"):
+def create_predictor(model_path,
+                     device="gpu",
+                     run_mode='fluid',
+                     batch_size=1,
+                     min_subgraph_size=3,
+                     use_dynamic_shape=False,
+                     trt_min_shape=1,
+                     trt_max_shape=1280,
+                     trt_opt_shape=640,
+                     trt_calib_mode=False):
     config = paddle.inference.Config(model_path + ".pdmodel",
                                      model_path + ".pdiparams")
     if device == "gpu":
@@ -73,6 +123,34 @@ def create_predictor(model_path, device="gpu"):
         config.enable_xpu(100)
     else:
         config.disable_gpu()
+
+    precision_map = {
+        'trt_int8': paddle.inference.Config.Precision.Int8,
+        'trt_fp32': paddle.inference.Config.Precision.Float32,
+        'trt_fp16': paddle.inference.Config.Precision.Half
+    }
+    if run_mode in precision_map.keys():
+        config.enable_tensorrt_engine(
+            workspace_size=1 << 25,
+            max_batch_size=batch_size,
+            min_subgraph_size=min_subgraph_size,
+            precision_mode=precision_map[run_mode],
+            use_static=False,
+            use_calib_mode=trt_calib_mode)
+
+        if use_dynamic_shape:
+            min_input_shape = {
+                'image': [batch_size, 3, trt_min_shape, trt_min_shape]
+            }
+            max_input_shape = {
+                'image': [batch_size, 3, trt_max_shape, trt_max_shape]
+            }
+            opt_input_shape = {
+                'image': [batch_size, 3, trt_opt_shape, trt_opt_shape]
+            }
+            config.set_trt_dynamic_shape_info(min_input_shape,
+                                              max_input_shape,
+                                              opt_input_shape)
+            print('trt set dynamic shape done!')
     predictor = paddle.inference.create_predictor(config)
     return predictor
@@ -95,11 +173,21 @@ def main():
         random.seed(args.seed)
         np.random.seed(args.seed)
     cfg = get_config(args.config_file, args.opt)
-    predictor = create_predictor(args.model_path, args.device)
+    predictor = create_predictor(args.model_path,
+                                 args.device,
+                                 args.run_mode,
+                                 args.batch_size,
+                                 args.min_subgraph_size,
+                                 args.use_dynamic_shape,
+                                 args.trt_min_shape,
+                                 args.trt_max_shape,
+                                 args.trt_opt_shape,
+                                 args.trt_calib_mode)
     input_handles = [
         predictor.get_input_handle(name)
         for name in predictor.get_input_names()
     ]
+    output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
     test_dataloader = build_dataloader(cfg.dataset.test,
                                        is_train=False,
@@ -196,9 +284,12 @@ def main():
             lq = data['lq'].numpy()
             input_handles[0].copy_from_cpu(lq)
             predictor.run()
+            if len(predictor.get_output_names()) > 1:
+                output_handle = predictor.get_output_handle(predictor.get_output_names()[-1])
             prediction = output_handle.copy_to_cpu()
             prediction = paddle.to_tensor(prediction)
             _, t, _, _, _ = prediction.shape
+            out_img = []
             gt_img = []
             for ti in range(t):
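
Taken together, a TensorRT FP16 predictor with dynamic shapes could be built through the new `create_predictor` signature like this, a sketch only; the model prefix and shape bounds are placeholders:
```
# Sketch: build a TensorRT FP16 predictor via the new create_predictor signature.
predictor = create_predictor("/path/to/model",  # prefix of .pdmodel/.pdiparams
                             device="gpu",
                             run_mode="trt_fp16",
                             batch_size=1,
                             min_subgraph_size=20,
                             use_dynamic_shape=True,
                             trt_min_shape=100,
                             trt_max_shape=1280,
                             trt_opt_shape=640)
```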