From 3a026b6a7ff1b66a76ce0bcd9e56644dd29eee6e Mon Sep 17 00:00:00 2001
From: whs
Date: Fri, 1 Jul 2022 10:45:54 +0800
Subject: [PATCH] Fix eval function in segmentation demo of ACT (#1218)

---
 .../semantic_segmentation/README.md |  10 ++
 .../semantic_segmentation/run.py    |  10 +-
 paddleslim/utils/__init__.py        |  15 ++
 paddleslim/utils/download.py        | 163 ++++++++++++++++++
 4 files changed, 194 insertions(+), 4 deletions(-)
 create mode 100644 paddleslim/utils/__init__.py
 create mode 100644 paddleslim/utils/download.py

diff --git a/demo/auto_compression/semantic_segmentation/README.md b/demo/auto_compression/semantic_segmentation/README.md
index 092fe398..d6767627 100644
--- a/demo/auto_compression/semantic_segmentation/README.md
+++ b/demo/auto_compression/semantic_segmentation/README.md
@@ -44,10 +44,12 @@
 - PP-HumanSeg-Lite dataset
   - Dataset: AISegment + PP-HumanSeg14K + an internally built dataset. AISegment is an open-source dataset available from [this link](https://github.com/aisegmentcn/matting_human_datasets); PP-HumanSeg14K is a dataset built by PaddleSeg, available through the [official channel](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/contrib/PP-HumanSeg/paper.md#pp-humanseg14k-a-large-scale-teleconferencing-video-dataset); the internal dataset is not publicly released.
+  - Example dataset: used to quickly run through the compression and inference pipeline for portrait segmentation; it cannot reproduce the compression results in the benchmark table. [Download link](https://paddleseg.bj.bcebos.com/humanseg/data/mini_supervisely.zip)
 - PP-LiteSeg, HRNet, UNet, and Deeplabv3-ResNet50 datasets
   - cityscapes: please download the full dataset from the [cityscapes website](https://www.cityscapes-dataset.com/login/)
+  - Example dataset: a subset of cityscapes used to quickly run through the compression and inference pipeline; it cannot reproduce the compression results in the benchmark table. [Download link](https://bj.bcebos.com/v1/paddle-slim-models/data/mini_cityscapes/mini_cityscapes.tar)
 
 The following takes the open-source dataset as an example to show how to run auto compression on PP-HumanSeg-Lite.
@@ -85,6 +87,14 @@ pip install paddleseg
 
 Developers can download an open-source dataset (such as [AISegment](https://github.com/aisegmentcn/matting_human_datasets)) or prepare a custom semantic segmentation dataset. Refer to the [PaddleSeg data preparation documentation](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/data/marker/marker_cn.md) to check and align the data format.
 
+The portrait segmentation example data can be downloaded with the following commands:
+
+```shell
+cd ./data
+python download_data.py mini_humanseg
+
+```
+
 #### 3.3 Prepare the inference model
 
 An inference model consists of two files, `model.pdmodel` and `model.pdiparams`: the file with the `pdmodel` suffix holds the model topology and the file with the `pdiparams` suffix holds the weights.
diff --git a/demo/auto_compression/semantic_segmentation/run.py b/demo/auto_compression/semantic_segmentation/run.py
index 9a6eea1b..a5d0be12 100644
--- a/demo/auto_compression/semantic_segmentation/run.py
+++ b/demo/auto_compression/semantic_segmentation/run.py
@@ -105,9 +105,11 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
             ori_shape,
             eval_dataset.transforms.transforms,
             mode='bilinear')
-
-        pred = paddle.argmax(
-            paddle.to_tensor(logit), axis=1, keepdim=True, dtype='int32')
+        pred = paddle.to_tensor(logit)
+        # HumanSeg models output a per-class probability distribution
+        # rather than class ids, so reduce 4-D logits with argmax first.
+        if len(pred.shape) == 4:
+            pred = paddle.argmax(pred, axis=1, keepdim=True, dtype='int32')
 
         intersect_area, pred_area, label_area = metrics.calculate_area(
             pred,
@@ -166,7 +168,7 @@ def reader_wrapper(reader):
 
 if __name__ == '__main__':
     args = parse_args()
-
+    paddle.enable_static()
     # step1: load dataset config and create dataloader
     data_cfg = PaddleSegDataConfig(args.dataset_config)
     train_dataset = data_cfg.train_dataset
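
The `eval_function` fix above keys off the rank of the model output: a 4-D tensor (N, C, H, W) is a per-class probability distribution (as produced by the humanseg models) and is reduced to class ids with `argmax`, while an output that already contains class ids is passed straight to `metrics.calculate_area`. A minimal NumPy sketch of the same dispatch, with a hypothetical `to_class_ids` helper that is not part of this patch:

```python
import numpy as np


def to_class_ids(pred):
    """Reduce a segmentation output to integer class ids.

    4-D outputs (N, C, H, W) are per-class distributions and are reduced
    with argmax over the channel axis; the added axis mirrors the patch's
    keepdim=True. Lower-rank outputs are assumed to already be class ids.
    """
    if pred.ndim == 4:
        return np.argmax(pred, axis=1).astype(np.int32)[:, np.newaxis]
    return pred.astype(np.int32)


# A distribution output: batch of 1, 2 classes, 4x4 image -> (1, 1, 4, 4).
logits = np.random.rand(1, 2, 4, 4).astype(np.float32)
print(to_class_ids(logits).shape)

# An output that is already class ids passes through unchanged -> (1, 4, 4).
ids = np.zeros((1, 4, 4), dtype=np.int64)
print(to_class_ids(ids).shape)
```
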
diff --git a/paddleslim/utils/__init__.py b/paddleslim/utils/__init__.py
new file mode 100644
index 00000000..af05cf20
--- /dev/null
+++ b/paddleslim/utils/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import download
diff --git a/paddleslim/utils/download.py b/paddleslim/utils/download.py
new file mode 100644
index 00000000..98d9738d
--- /dev/null
+++ b/paddleslim/utils/download.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import os
+import shutil
+import sys
+import tarfile
+import time
+import zipfile
+
+import requests
+
+lasttime = time.time()
+FLUSH_INTERVAL = 0.1
+
+
+def progress(msg, end=False):
+    """Print a one-line progress message, flushed at most every FLUSH_INTERVAL seconds."""
+    global lasttime
+    if end:
+        msg += "\n"
+        lasttime = 0
+    if time.time() - lasttime >= FLUSH_INTERVAL:
+        sys.stdout.write("\r%s" % msg)
+        lasttime = time.time()
+        sys.stdout.flush()
+
+
+def _download_file(url, savepath, print_progress):
+    if print_progress:
+        print("Connecting to {}".format(url))
+    r = requests.get(url, stream=True, timeout=15)
+    total_length = r.headers.get('content-length')
+
+    if total_length is None:
+        # No content-length header, so stream without a progress bar.
+        with open(savepath, 'wb') as f:
+            shutil.copyfileobj(r.raw, f)
+    else:
+        with open(savepath, 'wb') as f:
+            dl = 0
+            total_length = int(total_length)
+            if print_progress:
+                print("Downloading %s" % os.path.basename(savepath))
+            for data in r.iter_content(chunk_size=4096):
+                dl += len(data)
+                f.write(data)
+                if print_progress:
+                    done = int(50 * dl / total_length)
+                    progress("[%-50s] %.2f%%" %
+                             ('=' * done, float(100 * dl) / total_length))
+        if print_progress:
+            progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
+
+
+def _uncompress_file_zip(filepath, extrapath):
+    files = zipfile.ZipFile(filepath, 'r')
+    filelist = files.namelist()
+    rootpath = filelist[0]
+    total_num = len(filelist)
+    for index, file in enumerate(filelist):
+        files.extract(file, extrapath)
+        yield total_num, index, rootpath
+    files.close()
+    # Yield once more after closing so the caller can draw a final tick.
+    yield total_num, index, rootpath
+
+
+def _uncompress_file_tar(filepath, extrapath, mode="r:gz"):
+    files = tarfile.open(filepath, mode)
+    filelist = files.getnames()
+    total_num = len(filelist)
+    rootpath = filelist[0]
+    for index, file in enumerate(filelist):
+        files.extract(file, extrapath)
+        yield total_num, index, rootpath
+    files.close()
+    yield total_num, index, rootpath
+
+
+def _uncompress_file(filepath, extrapath, delete_file, print_progress):
+    if print_progress:
+        print("Uncompress %s" % os.path.basename(filepath))
+
+    if filepath.endswith("zip"):
+        handler = _uncompress_file_zip
+    elif filepath.endswith("tgz"):
+        handler = functools.partial(_uncompress_file_tar, mode="r:*")
+    else:
+        handler = functools.partial(_uncompress_file_tar, mode="r")
+
+    for total_num, index, rootpath in handler(filepath, extrapath):
+        if print_progress:
+            done = int(50 * float(index) / total_num)
+            progress("[%-50s] %.2f%%" %
+                     ('=' * done, float(100 * index) / total_num))
+    if print_progress:
+        progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
+
+    if delete_file:
+        os.remove(filepath)
+
+    return rootpath
+
+
+def download_file_and_uncompress(url,
+                                 savepath=None,
+                                 extrapath=None,
+                                 extraname=None,
+                                 print_progress=True,
+                                 cover=False,
+                                 delete_file=True):
+    if savepath is None:
+        savepath = "."
+
+    if extrapath is None:
+        extrapath = "."
+
+    savename = url.split("/")[-1]
+    if not os.path.exists(savepath):
+        os.makedirs(savepath)
+
+    # `savepath` becomes the archive file path, `savename` the directory the
+    # archive extracts to, and `extraname` the directory returned to callers.
+    savepath = os.path.join(savepath, savename)
+    savename = ".".join(savename.split(".")[:-1])
+    savename = os.path.join(extrapath, savename)
+    extraname = savename if extraname is None else os.path.join(extrapath,
+                                                                extraname)
+
+    if cover:
+        # Remove any leftovers from a previous run before re-downloading.
+        if os.path.exists(savepath):
+            os.remove(savepath)
+        if os.path.exists(savename):
+            shutil.rmtree(savename)
+        if os.path.exists(extraname):
+            shutil.rmtree(extraname)
+
+    if not os.path.exists(extraname):
+        if not os.path.exists(savename):
+            if not os.path.exists(savepath):
+                _download_file(url, savepath, print_progress)
+
+            if (not tarfile.is_tarfile(savepath)) and (
+                    not zipfile.is_zipfile(savepath)):
+                # Not an archive: move the downloaded file into place as-is.
+                if not os.path.exists(extraname):
+                    os.makedirs(extraname)
+                shutil.move(savepath, extraname)
+                return extraname
+
+            savename = _uncompress_file(savepath, extrapath, delete_file,
+                                        print_progress)
+            savename = os.path.join(extrapath, savename)
+        shutil.move(savename, extraname)
+    return extraname
--
GitLab
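
The new `paddleslim.utils.download` module is the kind of utility the `download_data.py` command in the README can build on. A hedged sketch of such a script follows; the name-to-URL registry and the script itself are illustrative, not the script shipped in the repo (only the two mini-dataset URLs are taken from this patch):

```python
# download_data.py -- illustrative sketch, not the actual helper script.
import sys

from paddleslim.utils.download import download_file_and_uncompress

# Hypothetical registry mapping dataset names to archive URLs; both URLs
# appear in the README diff above.
DATASETS = {
    'mini_humanseg':
    'https://paddleseg.bj.bcebos.com/humanseg/data/mini_supervisely.zip',
    'mini_cityscapes':
    'https://bj.bcebos.com/v1/paddle-slim-models/data/mini_cityscapes/mini_cityscapes.tar',
}

if __name__ == '__main__':
    name = sys.argv[1] if len(sys.argv) > 1 else 'mini_humanseg'
    # Download the archive into the current directory, extract it, and
    # rename the extracted directory to `name`.
    path = download_file_and_uncompress(
        DATASETS[name], savepath='.', extrapath='.', extraname=name)
    print('Dataset available at {}'.format(path))
```

Because `download_file_and_uncompress` checks for the extracted directory before downloading, rerunning the script is a no-op unless `cover=True` is passed to force a fresh download.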