diff --git a/applications/tools/first-order-demo.py b/applications/tools/first-order-demo.py index 588e9c364271ce9573444862a4d179cb65a23011..c74b965f2a296a47bf5d34ca31356e81845129be 100644 --- a/applications/tools/first-order-demo.py +++ b/applications/tools/first-order-demo.py @@ -83,9 +83,15 @@ parser.add_argument( dest="face_enhancement", action="store_true", help="use face enhance for face") +parser.add_argument( + "--mobile_net", + dest="mobile_net", + action="store_true", + help="use mobile_net for fom") parser.set_defaults(relative=False) parser.set_defaults(adapt_scale=False) parser.set_defaults(face_enhancement=False) +parser.set_defaults(mobile_net=False) if __name__ == "__main__": args = parser.parse_args() @@ -105,5 +111,7 @@ if __name__ == "__main__": multi_person=args.multi_person, image_size=args.image_size, batch_size=args.batch_size, - face_enhancement=args.face_enhancement) + face_enhancement=args.face_enhancement, + mobile_net=args.mobile_net) predictor.run(args.source_image, args.driving_video) + diff --git a/configs/firstorder_vox_mobile_256.yaml b/configs/firstorder_vox_mobile_256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c1eb26cd8bfcbd7ba620ee446a267ccc72afde1d --- /dev/null +++ b/configs/firstorder_vox_mobile_256.yaml @@ -0,0 +1,130 @@ +epochs: 100 +output_dir: output_dir + +dataset: + train: + name: FirstOrderDataset + batch_size: 1 + num_workers: 1 + use_shared_memory: False + phase: train + dataroot: data/first_order/Voxceleb/ + frame_shape: [256, 256, 3] + id_sampling: True + pairs_list: None + time_flip: True + num_repeats: 75 + create_frames_folder: False + transforms: + - name: PairedRandomHorizontalFlip + prob: 0.5 + keys: [image, image] + - name: PairedColorJitter + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 + keys: [image, image] + test: + name: FirstOrderDataset + dataroot: data/first_order/Voxceleb/ + phase: test + batch_size: 1 + num_workers: 1 + time_flip: False + id_sampling: False + create_frames_folder: False + frame_shape: [ 256, 256, 3 ] + + +model: + name: FirstOrderModel + common_params: + num_kp: 10 + num_channels: 3 + estimate_jacobian: True + generator: + name: FirstOrderGenerator + kp_detector_cfg: + temperature: 0.1 + block_expansion: 32 + max_features: 256 + scale_factor: 0.25 + num_blocks: 5 + mobile_net: True + generator_cfg: + block_expansion: 32 + max_features: 256 + num_down_blocks: 2 + num_bottleneck_blocks: 6 + estimate_occlusion_map: True + dense_motion_params: + block_expansion: 32 + max_features: 256 + num_blocks: 5 + scale_factor: 0.25 + mobile_net: True + discriminator: + name: FirstOrderDiscriminator + discriminator_cfg: + scales: [1] + block_expansion: 32 + max_features: 512 + num_blocks: 4 + sn: True + train_params: + num_epochs: 100 + scales: [1, 0.5, 0.25, 0.125] + checkpoint_freq: 50 + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + generator_gan: 1 + discriminator_gan: 1 + feature_matching: [10, 10, 10, 10] + perceptual: [10, 10, 10, 10, 10] + equivariance_value: 10 + equivariance_jacobian: 10 + +lr_scheduler: + name: MultiStepDecay + epoch_milestones: [237360, 356040] + lr_generator: 2.0e-4 + lr_discriminator: 2.0e-4 + lr_kp_detector: 2.0e-4 + +reconstruction_params: + num_videos: 1000 + format: '.mp4' + +animate_params: + num_pairs: 50 + format: '.mp4' + normalization_params: + adapt_movement_scale: False + use_relative_movement: True + use_relative_jacobian: True + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' + +log_config: + interval: 10 + visiual_interval: 10 + +validate: + interval: 3000 + save_img: true + +snapshot_config: + interval: 1 + +optimizer: + name: Adam + +export_model: + - {} diff --git a/deploy/lite/README.md b/deploy/lite/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7fade218c2a49e2fdfe36a7259fa4d739351bd38 --- /dev/null +++ b/deploy/lite/README.md @@ -0,0 +1,177 @@ +# Paddle-Lite端侧部署 + +本教程将介绍基于[Paddle Lite](https://github.com/PaddlePaddle/Paddle-Lite) 在移动端部署PaddleDetection模型的详细步骤。 + +Paddle Lite是飞桨轻量化推理引擎,为手机、IOT端提供高效推理能力,并广泛整合跨平台硬件,为端侧部署及应用落地问题提供轻量化的部署方案。 + +## 1. 准备环境 + +### 运行准备 +- 电脑(编译Paddle Lite) +- 安卓手机(armv7或armv8) + +### 1.1 准备交叉编译环境 +交叉编译环境用于编译 Paddle Lite 和 PaddleDetection 的C++ demo。 +支持多种开发环境,不同开发环境的编译流程请参考对应文档,请确保安装完成Java jdk、Android NDK(R17以上)。 + +1. [Docker](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_env.html#docker) +2. [Linux](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_env.html#linux) +3. [MAC OS](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_env.html#mac-os) + +### 1.2 准备预测库 + +预测库有两种获取方式: +1. 直接下载,预测库下载链接如下: + |平台|预测库下载链接| + |-|-| + |Android|[arm7](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.8/inference_lite_lib.android.armv7.gcc.c++_static.with_extra.with_cv.tar.gz) / [arm8](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.8/inference_lite_lib.android.armv8.gcc.c++_static.with_extra.with_cv.tar.gz)| + +**注意**:1. 目前FOM的算子只在PaddleLite的develop版本中支持,需要自行下载编译 2.如果是从 Paddle-Lite [官方文档](https://paddle-lite.readthedocs.io/zh/latest/quick_start/release_lib.html#android-toolchain-gcc)下载的预测库,注意选择`with_extra=ON,with_cv=ON`的下载链接。3. 目前只提供Android端demo. + + +2. 编译Paddle-Lite得到预测库,Paddle-Lite的编译方式如下: +```shell +git clone https://github.com/PaddlePaddle/Paddle-Lite.git +cd Paddle-Lite +# 如果使用编译方式,建议使用develop分支编译预测库 +git checkout develop +./lite/tools/build_android.sh --arch=armv8 --with_cv=ON --with_extra=ON +``` + +**注意**:编译Paddle-Lite获得预测库时,需要打开`--with_cv=ON --with_extra=ON`两个选项,`--arch`表示`arm`版本,这里指定为armv8,更多编译命令介绍请参考[链接](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_andriod.html#id2)。 + +直接下载预测库并解压后,可以得到`inference_lite_lib.android.armv8.gcc.c++_static.with_extra.with_cv/`文件夹,通过编译Paddle-Lite得到的预测库位于`Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/`文件夹下。 +预测库的文件目录如下: + +``` +inference_lite_lib.android.armv8/ +|-- cxx C++ 预测库和头文件 +| |-- include C++ 头文件 +| | |-- paddle_api.h +| | |-- paddle_image_preprocess.h +| | |-- paddle_lite_factory_helper.h +| | |-- paddle_place.h +| | |-- paddle_use_kernels.h +| | |-- paddle_use_ops.h +| | `-- paddle_use_passes.h +| `-- lib C++预测库 +| |-- libpaddle_api_light_bundled.a C++静态库 +| `-- libpaddle_light_api_shared.so C++动态库 +|-- java Java预测库 +| |-- jar +| | `-- PaddlePredictor.jar +| |-- so +| | `-- libpaddle_lite_jni.so +| `-- src +|-- demo C++和Java示例代码 +| |-- cxx C++ 预测库demo +| `-- java Java 预测库demo +``` + +## 2 开始运行 + +### 2.1 模型优化 + +Paddle-Lite 提供了多种策略来自动优化原始的模型,其中包括量化、子图融合、混合调度、Kernel优选等方法,使用Paddle-Lite的`opt`工具可以自动对inference模型进行优化,目前支持两种优化方式,优化后的模型更轻量,模型运行速度更快。 + +**注意**:如果已经准备好了 `.nb` 结尾的模型文件,可以跳过此步骤。 + +#### 2.1.1 安装paddle_lite_opt工具 +安装paddle_lite_opt工具有如下两种方法: +1. [**建议**]pip安装paddlelite并进行转换 + ```shell + pip install paddlelite + ``` + +2. 源码编译Paddle-Lite生成opt工具 + + 模型优化需要Paddle-Lite的`opt`可执行文件,可以通过编译Paddle-Lite源码获得,编译步骤如下: + ```shell + # 如果准备环境时已经clone了Paddle-Lite,则不用重新clone Paddle-Lite + git clone https://github.com/PaddlePaddle/Paddle-Lite.git + cd Paddle-Lite + git checkout develop + # 启动编译 + ./lite/tools/build.sh build_optimize_tool + ``` + + 编译完成后,`opt`文件位于`build.opt/lite/api/`下,可通过如下方式查看`opt`的运行选项和使用方式; + ```shell + cd build.opt/lite/api/ + ./opt + ``` + + `opt`的使用方式与参数与上面的`paddle_lite_opt`完全一致。 + +之后使用`paddle_lite_opt`工具可以进行inference模型的转换。`paddle_lite_opt`的部分参数如下: + +|选项|说明| +|-|-| +|--model_file|待优化的PaddlePaddle模型(combined形式)的网络结构文件路径| +|--param_file|待优化的PaddlePaddle模型(combined形式)的权重文件路径| +|--optimize_out_type|输出模型类型,目前支持两种类型:protobuf和naive_buffer,其中naive_buffer是一种更轻量级的序列化/反序列化实现,默认为naive_buffer| +|--optimize_out|优化模型的输出路径| +|--valid_targets|指定模型可执行的backend,默认为arm。目前可支持x86、arm、opencl、npu、xpu,可以同时指定多个backend(以空格分隔),Model Optimize Tool将会自动选择最佳方式。如果需要支持华为NPU(Kirin 810/990 Soc搭载的达芬奇架构NPU),应当设置为npu, arm| + +更详细的`paddle_lite_opt`工具使用说明请参考[使用opt转化模型文档](https://paddle-lite.readthedocs.io/zh/latest/user_guides/opt/opt_bin.html) + +`--model_file`表示inference模型的model文件地址,`--param_file`表示inference模型的param文件地址;`optimize_out`用于指定输出文件的名称(不需要添加`.nb`的后缀)。直接在命令行中运行`paddle_lite_opt`,也可以查看所有参数及其说明。 + + +#### 2.1.3 FOM转换示例 +```shell +# 将inference模型转化为Paddle-Lite优化模型 +paddle_lite_opt --model_file=output_inference/fom_dy2st/generator.pdmodel \ + --param_file=output_inference/fom_dy2st/generator.pdiparams \ + --optimize_out=output_inference/fom_dy2st/generator_lite \ + --optimize_out_type=naive_buffer \ + --valid_targets=arm +paddle_lite_opt --model_file=output_inference/fom_dy2st/kp_detector.pdmodel \ + --param_file=output_inference/fom_dy2st/kp_detector.pdiparams \ + --optimize_out=output_inference/fom_dy2st/kp_detector_lite \ + --optimize_out_type=naive_buffer \ + --valid_targets=arm +``` + +最终在当前文件夹下生成`generator_lite.nb`和`kp_detector_lite.nb`的文件。 + +**注意**:`--optimize_out` 参数为优化后模型的保存路径,无需加后缀`.nb`;`--model_file` 参数为模型结构信息文件的路径,`--param_file` 参数为模型权重信息文件的路径,请注意文件名。 + +### 2.2 与手机联调 + +首先需要进行一些准备工作。 +1. 准备一台arm8的安卓手机,如果编译的预测库和opt文件是armv7,则需要arm7的手机,并修改Makefile中`ARM_ABI = arm7`。 +2. 电脑上安装ADB工具,用于调试。 ADB安装方式如下: + + 2.1. MAC电脑安装ADB: + + ```shell + brew cask install android-platform-tools + ``` + 2.2. Linux安装ADB + ```shell + sudo apt update + sudo apt install -y wget adb + ``` + 2.3. Window安装ADB + + win上安装需要去谷歌的安卓平台下载ADB软件包进行安装:[链接](https://developer.android.com/studio) + +3. 手机连接电脑后,开启手机`USB调试`选项,选择`文件传输`模式,在电脑终端中输入: + +```shell +adb devices +``` +如果有device输出,则表示安装成功,如下所示: +``` +List of devices attached +744be294 device +``` + +4. 准备优化后的模型、预测库文件、测试图像和类别映射文件, 导入手机等运行,目前apk还存在一些效果问题,还在优化中。 + + +## FAQ +Q1:如果想更换模型怎么办,需要重新按照流程走一遍吗? +A1:如果已经走通了上述步骤,更换模型只需要替换 `.nb` 模型文件即可,同时要注意修改下配置文件中的 `.nb` 文件路径以及类别映射文件(如有必要)。 + diff --git a/docs/imgs/father_23.jpg b/docs/imgs/father_23.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5f9422b4b79eb45d462867eb694ff7d02c39772d Binary files /dev/null and b/docs/imgs/father_23.jpg differ diff --git a/docs/imgs/mayiyahei.MP4 b/docs/imgs/mayiyahei.MP4 new file mode 100644 index 0000000000000000000000000000000000000000..1092c616e0ba7b3c793f67d5c4333f0585f1f95e Binary files /dev/null and b/docs/imgs/mayiyahei.MP4 differ diff --git a/docs/zh_CN/tutorials/motion_driving.md b/docs/zh_CN/tutorials/motion_driving.md index 65876c4eef82e1ebc031e0c92e4bcd0a41e74091..cd6c00938bd85271a940c7bba1b3f03a5c79ba58 100644 --- a/docs/zh_CN/tutorials/motion_driving.md +++ b/docs/zh_CN/tutorials/motion_driving.md @@ -68,6 +68,7 @@ python -u tools/first-order-demo.py \ ``` #### 参数说明: + | 参数 | 使用说明 | | ---------------- | ------------------------------------------------------------ | | driving_video | 驱动视频,视频中人物的表情动作作为待迁移的对象。 | @@ -131,6 +132,52 @@ python -m paddle.distributed.launch \ +### 2. 模型压缩 +数据处理同上述,模型分为kp_detector和generator,首先固定原始generator部分,训练压缩版的kp_detector部分,然后固定原始kp_detector部分,去训练generator部分,最后将两个压缩的模型一起训练,同时添加中间的蒸馏loss。 + +**预测:** +``` +cd applications/ +python -u tools/first-order-demo.py \ + --driving_video ../docs/imgs/mayiyahei.MP4 \ + --source_image ../docs/imgs/father_23.jpg \ + --config ../configs/firstorder_vox_mobile_256.yaml \ + --ratio 0.4 \ + --relative \ + --adapt_scale \ + --mobile_net +``` +目前压缩采用mobilenet+剪枝的方法,和之前对比: +| | 大小(M) | reconstruction loss | +| :--------------: | :--------------: | :-----------------: | +| 原始 | 229 | 0.012058867 | +| 压缩 | 6.1 | 0.015025159 | + +### 3. 模型部署 +#### 3.1 导出模型 +使用`tools/fom_export.py`脚本导出模型已经部署时使用的配置文件,配置文件名字为`firstorder_vox_mobile_256.yml`。模型导出脚本如下: +```bash +# 导出FOM模型 +需要将 “/ppgan/modules/first_order.py”中的nn.SyncBatchNorm 改为nn.BatchNorm,因为export目前不支持SyncBatchNorm +将 out = out[:, :, ::int_inv_scale, ::int_inv_scale] 改为 +out = paddle.fluid.layers.resize_nearest(out, scale=self.scale) + +python tools/export_model.py \ + --config-file configs/firstorder_vox_mobile_256.yaml \ + --load /root/.cache/ppgan/vox_mobile.pdparams \ + --inputs_size "1,3,256,256;1,3,256,256;1,10,2;1,10,2,2" \ + --export_model output_inference/ +``` +预测模型会导出到`output_inference/fom_dy2st/`目录下,分别为`model.pdiparams`, `model.pdiparams.info`, `model.pdmodel`。 + + +#### 3.3 PaddleLite部署 +- [使用PaddleLite部署FOM模型](./lite/README.md) +- [FOM-Lite-Demo](https://paddlegan.bj.bcebos.com/applications/first_order_model/paddle_lite/apk/face_detection_demo%202.zip)。更多内容,请参考[Paddle-Lite](https://github.com/PaddlePaddle/Paddle-Lite) +目前问题: +(a).paddlelite运行效果略差于inference,正在优化中 +(b).单线程跑generator,帧数多了会跑到小核不跑大核 + ## 参考文献 ``` @@ -143,3 +190,4 @@ python -m paddle.distributed.launch \ } ``` + diff --git a/ppgan/apps/first_order_predictor.py b/ppgan/apps/first_order_predictor.py index a225d20b61680d7fed9a5125a7a17c3bedb1045b..50e3eecd1728bb70dc5107df31b9210f67660295 100644 --- a/ppgan/apps/first_order_predictor.py +++ b/ppgan/apps/first_order_predictor.py @@ -49,7 +49,8 @@ class FirstOrderPredictor(BasePredictor): multi_person=False, image_size=256, face_enhancement=False, - batch_size=1): + batch_size=1, + mobile_net=False): if config is not None and isinstance(config, str): with open(config) as f: self.cfg = yaml.load(f, Loader=yaml.SafeLoader) @@ -87,13 +88,17 @@ class FirstOrderPredictor(BasePredictor): } } } - self.image_size = image_size - if weight_path is None: + self.image_size = image_size + if weight_path is None: + if mobile_net: + vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox_mobile.pdparams' + + else: if self.image_size == 512: vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk-512.pdparams' else: vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk.pdparams' - weight_path = get_path_from_url(vox_cpk_weight_url) + weight_path = get_path_from_url(vox_cpk_weight_url) self.weight_path = weight_path if not os.path.exists(output): diff --git a/ppgan/models/firstorder_model.py b/ppgan/models/firstorder_model.py index 482450d91e27b4484372e4ee6eae106b59445c38..6626f7b994da45d57baf8a818744089e3275bebc 100755 --- a/ppgan/models/firstorder_model.py +++ b/ppgan/models/firstorder_model.py @@ -29,7 +29,7 @@ import numpy as np from paddle.utils import try_import import paddle.nn.functional as F import cv2 - +import os def init_weight(net): def reset_func(m): @@ -186,6 +186,47 @@ class FirstOrderModel(BaseModel): self.nets['kp_detector'].train() self.nets['generator'].train() + class InferGenerator(paddle.nn.Layer): + def set_generator(self, generator): + self.generator = generator + + def forward(self, source, kp_source, kp_driving, kp_driving_initial): + kp_norm = {k: v for k, v in kp_driving.items()} + + kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) + kp_norm['value'] = kp_value_diff + kp_source['value'] + + jacobian_diff = paddle.matmul( + kp_driving['jacobian'], + paddle.inverse(kp_driving_initial['jacobian'])) + kp_norm['jacobian'] = paddle.matmul(jacobian_diff, + kp_source['jacobian']) + out = self.generator(source, kp_source=kp_source, kp_driving=kp_norm) + return out['prediction'] + + + def export_model(self, export_model=None, output_dir=None, inputs_size=[]): + + source = paddle.rand(shape=inputs_size[0], dtype='float32') + driving = paddle.rand(shape=inputs_size[1], dtype='float32') + value = paddle.rand(shape=inputs_size[2], dtype='float32') + j = paddle.rand(shape=inputs_size[3], dtype='float32') + value2 = paddle.rand(shape=inputs_size[2], dtype='float32') + j2 = paddle.rand(shape=inputs_size[3], dtype='float32') + driving1 = {'value': value, 'jacobian': j} + driving2 = {'value': value2, 'jacobian': j2} + driving3 = {'value': value, 'jacobian': j} + + outpath = os.path.join(output_dir, "fom_dy2st") + if not os.path.exists(outpath): + os.makedirs(outpath) + paddle.jit.save(self.nets['Gen_Full'].kp_extractor, os.path.join(outpath, "kp_detector"), input_spec=[source]) + infer_generator = self.InferGenerator() + infer_generator.set_generator(self.nets['Gen_Full'].generator) + paddle.jit.save(infer_generator, os.path.join(outpath, "generator"), input_spec=[source, driving1, driving2, driving3]) + + + class Visualizer: def __init__(self, kp_size=3, draw_border=False, colormap='gist_rainbow'): diff --git a/ppgan/models/generators/occlusion_aware.py b/ppgan/models/generators/occlusion_aware.py index 1ce8aa14c3080e53f050ff918f24c576ba12099a..abf5195948151441145e4d29c694b85d5c6f4748 100644 --- a/ppgan/models/generators/occlusion_aware.py +++ b/ppgan/models/generators/occlusion_aware.py @@ -18,6 +18,7 @@ import paddle from paddle import nn import paddle.nn.functional as F from ...modules.first_order import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, make_coordinate_grid +from ...modules.first_order import MobileResBlock2d, MobileUpBlock2d, MobileDownBlock2d from ...modules.dense_motion import DenseMotionNetwork import numpy as np import cv2 @@ -38,7 +39,8 @@ class OcclusionAwareGenerator(nn.Layer): estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False, - inference=False): + inference=False, + mobile_net=False): super(OcclusionAwareGenerator, self).__init__() if dense_motion_params is not None: @@ -46,45 +48,74 @@ class OcclusionAwareGenerator(nn.Layer): num_kp=num_kp, num_channels=num_channels, estimate_occlusion_map=estimate_occlusion_map, - **dense_motion_params) + **dense_motion_params, mobile_net=mobile_net) else: self.dense_motion_network = None self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), - padding=(3, 3)) + padding=(3, 3), + mobile_net=mobile_net) down_blocks = [] - for i in range(num_down_blocks): - in_features = min(max_features, block_expansion * (2**i)) - out_features = min(max_features, block_expansion * (2**(i + 1))) - down_blocks.append( - DownBlock2d(in_features, - out_features, - kernel_size=(3, 3), - padding=(1, 1))) + if mobile_net: + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2**i)) + out_features = min(max_features, block_expansion * (2**(i + 1))) + down_blocks.append( + MobileDownBlock2d(in_features, + out_features, + kernel_size=(3, 3), + padding=(1, 1))) + else: + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2**i)) + out_features = min(max_features, block_expansion * (2**(i + 1))) + down_blocks.append( + DownBlock2d(in_features, + out_features, + kernel_size=(3, 3), + padding=(1, 1))) self.down_blocks = nn.LayerList(down_blocks) up_blocks = [] - for i in range(num_down_blocks): - in_features = min(max_features, - block_expansion * (2**(num_down_blocks - i))) - out_features = min(max_features, - block_expansion * (2**(num_down_blocks - i - 1))) - up_blocks.append( - UpBlock2d(in_features, - out_features, - kernel_size=(3, 3), - padding=(1, 1))) + if mobile_net: + for i in range(num_down_blocks): + in_features = min(max_features, + block_expansion * (2**(num_down_blocks - i))) + out_features = min(max_features, + block_expansion * (2**(num_down_blocks - i - 1))) + up_blocks.append( + MobileUpBlock2d(in_features, + out_features, + kernel_size=(3, 3), + padding=(1, 1))) + else: + for i in range(num_down_blocks): + in_features = min(max_features, + block_expansion * (2**(num_down_blocks - i))) + out_features = min(max_features, + block_expansion * (2**(num_down_blocks - i - 1))) + up_blocks.append( + UpBlock2d(in_features, + out_features, + kernel_size=(3, 3), + padding=(1, 1))) self.up_blocks = nn.LayerList(up_blocks) self.bottleneck = paddle.nn.Sequential() in_features = min(max_features, block_expansion * (2**num_down_blocks)) - for i in range(num_bottleneck_blocks): - self.bottleneck.add_sublayer( - 'r' + str(i), - ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))) + if mobile_net: + for i in range(num_bottleneck_blocks): + self.bottleneck.add_sublayer( + 'r' + str(i), + MobileResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))) + else: + for i in range(num_bottleneck_blocks): + self.bottleneck.add_sublayer( + 'r' + str(i), + ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))) self.final = nn.Conv2D(block_expansion, num_channels, diff --git a/ppgan/modules/dense_motion.py b/ppgan/modules/dense_motion.py index 2ec0d789e7b974839385af9181b72decf6a216d0..54222131b6b8529e772735cdd82e0a08371b2ca3 100644 --- a/ppgan/modules/dense_motion.py +++ b/ppgan/modules/dense_motion.py @@ -33,13 +33,15 @@ class DenseMotionNetwork(nn.Layer): num_channels, estimate_occlusion_map=False, scale_factor=1, - kp_variance=0.01): + kp_variance=0.01, + mobile_net=False): super(DenseMotionNetwork, self).__init__() self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1), max_features=max_features, - num_blocks=num_blocks) + num_blocks=num_blocks, + mobile_net=mobile_net) self.mask = nn.Conv2D(self.hourglass.out_filters, num_kp + 1, @@ -155,10 +157,10 @@ class DenseMotionNetwork(nn.Layer): source_image, sparse_motion) out_dict['sparse_deformed'] = deformed_source - input = paddle.concat([heatmap_representation, deformed_source], axis=2) - input = input.reshape([bs, -1, h, w]) + temp = paddle.concat([heatmap_representation, deformed_source], axis=2) + temp = temp.reshape([bs, -1, h, w]) - prediction = self.hourglass(input) + prediction = self.hourglass(temp) mask = self.mask(prediction) mask = F.softmax(mask, axis=1) diff --git a/ppgan/modules/first_order.py b/ppgan/modules/first_order.py index 11d3f83c1b25b1474cebaa18f2445065d4c6fd21..af7c1241eee7b63feec98ff17dda49a433ad6baf 100644 --- a/ppgan/modules/first_order.py +++ b/ppgan/modules/first_order.py @@ -24,7 +24,7 @@ def SyncBatchNorm(*args, **kwargs): if paddle.get_device() == 'cpu': return nn.BatchNorm(*args, **kwargs) else: - return nn.SyncBatchNorm(*args, **kwargs) + return nn.BatchNorm(*args, **kwargs) class ImagePyramide(nn.Layer): @@ -123,6 +123,40 @@ class ResBlock2d(nn.Layer): out += x return out +class MobileResBlock2d(nn.Layer): + """ + Res block, preserve spatial resolution. + """ + + def __init__(self, in_features, kernel_size, padding): + super(MobileResBlock2d, self).__init__() + out_features = in_features * 2 + self.conv_pw = nn.Conv2D(in_channels=in_features, out_channels=out_features, kernel_size=1, + padding=0, bias_attr=False) + self.conv_dw = nn.Conv2D(in_channels=out_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=out_features, bias_attr=False) + self.conv_pw_linear = nn.Conv2D(in_channels=out_features, out_channels=in_features, kernel_size=1, + padding=0, bias_attr=False) + self.norm1 = SyncBatchNorm(in_features) + self.norm_pw = SyncBatchNorm(out_features) + self.norm_dw = SyncBatchNorm(out_features) + self.norm_pw_linear = SyncBatchNorm(in_features) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv_pw(out) + out = self.norm_pw(out) + + out = self.conv_dw(out) + out = self.norm_dw(out) + out = F.relu(out) + + out = self.conv_pw_linear(out) + out = self.norm_pw_linear(out) + out += x + return out + class UpBlock2d(nn.Layer): """ @@ -150,6 +184,32 @@ class UpBlock2d(nn.Layer): out = F.relu(out) return out +class MobileUpBlock2d(nn.Layer): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(MobileUpBlock2d, self).__init__() + + self.conv = nn.Conv2D(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding, groups=in_features, bias_attr=False) + self.conv1 = nn.Conv2D(in_channels=in_features, out_channels=out_features, kernel_size=1, + padding=0, bias_attr=False) + self.norm = SyncBatchNorm(in_features) + self.norm1 = SyncBatchNorm(out_features) + + def forward(self, x): + out = F.interpolate(x, scale_factor=2) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + out = self.conv1(out) + out = self.norm1(out) + out = F.relu(out) + return out + + class DownBlock2d(nn.Layer): """ @@ -178,6 +238,33 @@ class DownBlock2d(nn.Layer): return out +class MobileDownBlock2d(nn.Layer): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(MobileDownBlock2d, self).__init__() + self.conv = nn.Conv2D(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding, groups=in_features, bias_attr=False) + self.norm = SyncBatchNorm(in_features) + self.pool = nn.AvgPool2D(kernel_size=(2, 2)) + + self.conv1 = nn.Conv2D(in_features, out_features, kernel_size=1, padding=0, stride=1, bias_attr=False) + self.norm1 = SyncBatchNorm(out_features) + + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.conv1(out) + out = self.norm1(out) + out = F.relu(out) + out = self.pool(out) + return out + + class SameBlock2d(nn.Layer): """ Simple block, preserve spatial resolution. @@ -187,13 +274,15 @@ class SameBlock2d(nn.Layer): out_features, groups=1, kernel_size=3, - padding=1): + padding=1, + mobile_net=False): super(SameBlock2d, self).__init__() self.conv = nn.Conv2D(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, - groups=groups) + groups=groups, + bias_attr=(mobile_net==False)) self.norm = SyncBatchNorm(out_features) def forward(self, x): @@ -211,17 +300,25 @@ class Encoder(nn.Layer): block_expansion, in_features, num_blocks=3, - max_features=256): + max_features=256, + mobile_net = False): super(Encoder, self).__init__() down_blocks = [] for i in range(num_blocks): - down_blocks.append( - DownBlock2d(in_features if i == 0 else min( - max_features, block_expansion * (2**i)), - min(max_features, block_expansion * (2**(i + 1))), - kernel_size=3, - padding=1)) + if mobile_net: + down_blocks.append( + MobileDownBlock2d(in_features if i == 0 else min( + max_features, block_expansion * (2**i)), + min(max_features, block_expansion * (2**(i + 1))), + kernel_size=3, padding=1)) + else: + down_blocks.append( + DownBlock2d(in_features if i == 0 else min( + max_features, block_expansion * (2**i)), + min(max_features, block_expansion * (2**(i + 1))), + kernel_size=3, + padding=1)) self.down_blocks = nn.LayerList(down_blocks) def forward(self, x): @@ -239,17 +336,24 @@ class Decoder(nn.Layer): block_expansion, in_features, num_blocks=3, - max_features=256): + max_features=256, + mobile_net = False): super(Decoder, self).__init__() up_blocks = [] for i in range(num_blocks)[::-1]: - in_filters = (1 if i == num_blocks - 1 else 2) * min( - max_features, block_expansion * (2**(i + 1))) out_filters = min(max_features, block_expansion * (2**i)) - up_blocks.append( - UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1)) + if mobile_net: + in_filters = (1 if i == num_blocks - 1 else 2) * min( + max_features, block_expansion * (2**(i + 1))) + up_blocks.append( + MobileUpBlock2d(in_filters, out_filters, kernel_size=3, padding=1)) + else: + in_filters = (1 if i == num_blocks - 1 else 2) * min( + max_features, block_expansion * (2**(i + 1))) + up_blocks.append( + UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1)) self.up_blocks = nn.LayerList(up_blocks) self.out_filters = block_expansion + in_features @@ -271,12 +375,13 @@ class Hourglass(nn.Layer): block_expansion, in_features, num_blocks=3, - max_features=256): + max_features=256, + mobile_net=False): super(Hourglass, self).__init__() self.encoder = Encoder(block_expansion, in_features, num_blocks, - max_features) + max_features, mobile_net=mobile_net) self.decoder = Decoder(block_expansion, in_features, num_blocks, - max_features) + max_features, mobile_net=mobile_net) self.out_filters = self.decoder.out_filters def forward(self, x): @@ -325,7 +430,7 @@ class AntiAliasInterpolation2d(nn.Layer): inv_scale = 1 / self.scale int_inv_scale = int(inv_scale) assert (inv_scale == int_inv_scale) - out = out[:, :, ::int_inv_scale, ::int_inv_scale] + #out = out[:, :, ::int_inv_scale, ::int_inv_scale] # patch end - + out = paddle.fluid.layers.resize_nearest(out, scale=self.scale) return out diff --git a/ppgan/modules/keypoint_detector.py b/ppgan/modules/keypoint_detector.py index b31c0f763666e94be796e409bcde247b9708cf99..60d4e02224be830c537114aa7dbeb15a23fa69a7 100644 --- a/ppgan/modules/keypoint_detector.py +++ b/ppgan/modules/keypoint_detector.py @@ -34,13 +34,15 @@ class KPDetector(nn.Layer): estimate_jacobian=False, scale_factor=1, single_jacobian_map=False, - pad=0): + pad=0, + mobile_net=False): super(KPDetector, self).__init__() self.predictor = Hourglass(block_expansion, in_features=num_channels, max_features=max_features, - num_blocks=num_blocks) + num_blocks=num_blocks, + mobile_net=mobile_net) self.kp = nn.Conv2D(in_channels=self.predictor.out_filters, out_channels=num_kp,