+git clone https://github.com/jerrywgz/PaddleModelPipeline.git
+
+# install the other dependencies
+cd PaddleModelPipeline
+
+# build and install paddlecv
+python setup.py install
+```
+
+Installation as a whl package is also supported; see the [documentation](whl.md) for detailed steps.
+
+After installation, verify that the tests pass:
+
+```
+python tests/test_pipeline.py
+```
+
+When the test passes, it prints the following message:
+
+```
+.
+----------------------------------------------------------------------
+Ran 1 test in 2.967s
+
+OK
+```
+
+## Quick Start
+
+**Congratulations!** You have successfully installed PaddleCV. Next, take object detection for a quick spin:
+
+```
+# run inference on a single image on GPU
+export CUDA_VISIBLE_DEVICES=0
+python -u tools/predict.py --config=configs/single_op/PP-YOLOE+.yml --input=demo/000000014439.jpg
+```
+
+An image with the same name, annotated with the prediction results, is written to the `output` folder.
+
+The result looks like the following:
+
+
diff --git a/paddlecv/docs/config_anno.md b/paddlecv/docs/config_anno.md
new file mode 100644
index 0000000000000000000000000000000000000000..6df948f679a6cc2af2e25213d955cb39c6d07093
--- /dev/null
+++ b/paddlecv/docs/config_anno.md
@@ -0,0 +1,55 @@
+# Configuration File Reference
+
+This document uses the object detection model [PP-YOLOE+](../configs/single_op/PP-YOLOE+.yml) as an example and explains each field of the configuration file.
+
+## Environment Configuration Section
+
+```
+ENV:
+  min_subgraph_size: 3  # minimum TensorRT subgraph size
+  shape_info_filename: ./  # path for the collected TensorRT shape file
+  trt_calib_mode: False  # set to True when using TensorRT offline quantization calibration
+  cpu_threads: 1  # number of threads for CPU deployment
+  trt_use_static: False  # whether TensorRT loads a pre-generated engine file
+  save_img: True  # whether to save visualized images; default location is the output folder
+  save_res: True  # whether to save structured results; default location is the output folder
+  return_res: True  # whether to return the full structured results
+```
+
+## Model Configuration Section
+
+```
+MODEL:
+  - DetectionOp:  # model op class name; its output fields are fixed, i.e. ["dt_bboxes", "dt_scores", "dt_class_ids", "dt_cls_names"]
+      name: det  # op name; names must not repeat within one config file
+      param_path: paddlecv://models/ppyoloe_plus_crn_l_80e_coco/model.pdiparams  # inference params file; a local path or a remote URL (downloaded automatically)
+      model_path: paddlecv://models/ppyoloe_plus_crn_l_80e_coco/model.pdmodel  # inference model file; a local path or a remote URL (downloaded automatically)
+      batch_size: 1  # batch size
+      image_shape: [3, *image_shape, *image_shape]  # network input shape
+      PreProcess:  # preprocess module, implemented in ppcv/ops/models/detection/preprocess.py
+        - Resize:
+            interp: 2
+            keep_ratio: false
+            target_size: [*image_shape, *image_shape]
+        - NormalizeImage:
+            is_scale: true
+            mean: [0., 0., 0.]
+            std: [1., 1., 1.]
+            norm_type: null
+        - Permute:
+      PostProcess:  # postprocess module, implemented in ppcv/ops/models/detection/postprocess.py
+        - ParserDetResults:
+            label_list: paddlecv://dict/detection/coco_label_list.json
+            threshold: 0.5
+      Inputs:  # input fields required by DetectionOp, in the format {last_op_name}.{last_op_output_name}; the first op's upstream is named input
+        - input.image
+
+  - DetOutput:  # output op class name
+      name: vis  # op name; names must not repeat within one config file
+      Inputs:  # input fields required by DetOutput, in the format {last_op_name}.{last_op_output_name}
+        - input.fn
+        - input.image
+        - det.dt_bboxes
+        - det.dt_scores
+        - det.dt_cls_names
+```
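+
+As a quick sanity check of the config above, here is a minimal sketch of running it through the `PaddleCV` wrapper (assuming paddlecv is installed and the repository's config file and demo image are available locally):
+
+```python
+from paddlecv import PaddleCV
+
+pcv = PaddleCV(config_path="configs/single_op/PP-YOLOE+.yml")
+res = pcv("demo/000000014439.jpg")  # structured results from the det/vis ops
+```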
diff --git a/paddlecv/docs/custom_ops.md b/paddlecv/docs/custom_ops.md
new file mode 100644
index 0000000000000000000000000000000000000000..f4a53e944eaec7be2bb47ec2bd7da7a88f5c3395
--- /dev/null
+++ b/paddlecv/docs/custom_ops.md
@@ -0,0 +1,60 @@
+# Developing External Operators
+
+- [Introduction](#1)
+- [External Operator Dependencies](#2)
+- [How to Implement an External Operator](#3)
+
+
+
+
+## 1. Introduction
+
+This tutorial describes how to add external operators on top of paddlecv for customized operator development. Before developing an external operator, first set up the paddlecv environment; installation via pip is recommended:
+
+```bash
+pip install paddlecv
+```
+
+
+
+## 2. External Operator Dependencies
+
+External operators mainly rely on the following interfaces:
+
+#### 1) `ppcv.ops.base.create_operators(params, mod)`
+
+ - Purpose: create the pre/postprocessing operators
+ - Inputs:
+   - params: the list of pre/postprocessing operator configs
+   - mod: the module of the current operator
+ - Output: a list of instantiated pre/postprocessing operators
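+
+A minimal usage sketch (the `my_preprocess` module and the `Resize`/`Permute` classes in it are hypothetical stand-ins for your own op's preprocessing module):
+
+```python
+from ppcv.ops.base import create_operators
+
+import my_preprocess  # hypothetical module defining Resize and Permute classes
+
+# same structure as the PreProcess section of a config file
+params = [{"Resize": {"target_size": [640, 640]}}, "Permute"]
+ops = create_operators(params, my_preprocess)  # -> [Resize(...), Permute()]
+```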
+
+
+
+#### 2) Operator base classes
+
+External operator types match paddlecv's built-in ones: model operators, connector operators, and output operators. A new external operator must inherit from the BaseOp of its category, as follows:
+
+ ```txt
+ model operator:     ppcv.ops.models.base.ModelBaseOp
+ connector operator: ppcv.ops.connector.base.ConnectorBaseOp
+ output operator:    ppcv.ops.output.base.OutputBaseOp
+ ```
+
+#### 3) ppcv.core.workspace.register
+
+Every external operator class must be decorated with @register, for example:
+
+```python
+from ppcv.ops.models.base import ModelBaseOp
+from ppcv.core.workspace import register
+
+@register
+class DetectionCustomOp(ModelBaseOp):
+```
+
+
+
+## 3. How to Implement an External Operator
+
+Refer directly to the [documentation on adding new operators](how_to_add_new_op.md); once implemented, external operators are used the same way as paddlecv's built-in operators. paddlecv provides a detection external-operator [example](../custom_op).
diff --git a/paddlecv/docs/how_to_add_new_op.md b/paddlecv/docs/how_to_add_new_op.md
new file mode 100644
index 0000000000000000000000000000000000000000..3ed7edec693edcec8108e5723ba53468d983e99f
--- /dev/null
+++ b/paddlecv/docs/how_to_add_new_op.md
@@ -0,0 +1,111 @@
+# Adding a New Operator
+
+## 1. Introduction
+
+This tutorial describes how to add an inference operator to PaddleCV.
+
+In this project, operators fall into 3 categories.
+
+- Model inference operators: given the inputs, load a model, run preprocessing, inference, and postprocessing, and return the outputs.
+- Connector operators: given the inputs, compute the outputs. Typically used to turn one model's output into the next model's input, e.g. cropping boxes produced by object/text detection, rotating images after the orientation-correction module, composing text fragments, and so on.
+- Output operators: store, visualize, and export model results.
+
+Below we refer to operators as ops.
+
+
+## 2. Input/output format of a single op
+
+PaddleCV takes images or videos as input.
+
+For every op, the system normalizes its data into `a list of dict`. Each list element is one object to be inferred, together with its intermediate results. For image classification, for example, the input carries only image information and looks as follows.
+
+
+```json
+[
+ {"image": img1},
+ {"image": img2},
+]
+```
+
+
+The output format is:
+
+```json
+[
+ {"image": img1, "class_ids": class_id1, "scores": scores1, "label_names": label_names1},
+ {"image": img2, "class_ids": class_id2, "scores": scores2, "label_names": label_names2},
+]
+```
+
+Similarly, for a connector op (taking BboxCropOp as an example), the input is as follows.
+
+```json
+[
+ {"image": img1, "bbox": bboxes1},
+ {"image": img2, "bbox": bboxes2},
+]
+```
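+
+By analogy with that op's declared output field (`crop_image` in [op_connector.py](../ppcv/ops/connector/op_connector.py)), each sample of its output carries the cropped images, simplified here like the examples above (without the `{op_name}.` prefix):
+
+```json
+[
+    {"crop_image": crop_images1},
+    {"crop_image": crop_images2},
+]
+```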
+
+
+## 3. Adding an operator
+
+### 3.1 Model inference operators
+
+A model inference operator inherits from the [ModelBaseOp class](../ppcv/ops/models/base.py). For an example, see the image classification op: the [ClassificationOp class](../ppcv/ops/models/classification/inference.py). Concretely, we need to implement the following.
+
+(1) The class must inherit from `ModelBaseOp` and be registered with `@register` so its name is globally unique.
+
+(2) Implement the following methods (a skeleton sketch follows this list):
+
+- Constructor `__init__`
+  - Input: model_cfg and env_cfg
+  - Output: none
+- Preprocessing `preprocess`
+  - Input: the model inputs filtered by input_keys
+  - Output: the preprocessed results
+- Postprocessing `postprocess`
+  - Input: the raw model inference results
+  - Output: the postprocessed results
+- Prediction `__call__`
+  - Input: the inputs this op depends on
+  - Output: the results produced by this op
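+
+A minimal skeleton sketch (the class name `MyDetectionOp` and its output key are illustrative, and the actual inference call is left abstract):
+
+```python
+from ppcv.core.workspace import register
+from ppcv.ops.models.base import ModelBaseOp
+
+
+@register
+class MyDetectionOp(ModelBaseOp):  # hypothetical op name
+    @classmethod
+    def get_output_keys(cls):
+        return ["my_output"]  # the fixed output fields of this op
+
+    def __init__(self, model_cfg, env_cfg):
+        super().__init__(model_cfg, env_cfg)
+
+    def preprocess(self, inputs):
+        # turn the inputs filtered by input_keys into model-ready data
+        return inputs
+
+    def postprocess(self, results):
+        # map raw inference results onto the declared output fields
+        return results
+
+    def __call__(self, inputs):
+        data = self.preprocess(inputs)
+        # ... run model inference on `data` here ...
+        return self.postprocess(data)
+```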
+
+
+
+### 3.2 Connector operators
+
+
+A connector operator inherits from [ConnectorBaseOp](../ppcv/ops/connector/base.py). For an example, see the orientation-correction op: the [ClsCorrectionOp class](../ppcv/ops/connector/op_connector.py). Concretely, we need to implement the following.
+
+(1) The class must inherit from `ConnectorBaseOp` and be registered with `@register` so its name is globally unique.
+
+(2) Implement the following methods (a sketch follows this list):
+
+- Constructor `__init__`
+  - Input: model_cfg, env_cfg (usually None)
+  - Output: none
+- Call `__call__`
+  - Input: the inputs this op depends on
+  - Output: the results produced by this op
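+
+The same pattern, minus the model-specific steps, as a sketch (class and key names are illustrative):
+
+```python
+from ppcv.core.workspace import register
+from ppcv.ops.connector.base import ConnectorBaseOp
+
+
+@register
+class MyGlueOp(ConnectorBaseOp):  # hypothetical op name
+    @classmethod
+    def get_output_keys(cls):
+        return ["glued"]
+
+    def __init__(self, model_cfg, env_cfg=None):
+        super().__init__(model_cfg, env_cfg)
+
+    def __call__(self, inputs):
+        # one dict per input sample, keyed by this op's output fields
+        return [{self.output_keys[0]: item} for item in inputs]
+```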
+
+
+### 3.3 Output operators
+
+
+An output operator inherits from [OutputBaseOp](../ppcv/ops/output/base.py). For an example, see the classification output op: the [ClasOutput class](../ppcv/ops/output/classification.py). Concretely, we need to implement the following.
+
+(1) The class must inherit from `OutputBaseOp` and be registered with `@register` so its name is globally unique.
+
+(2) Implement the following methods:
+
+- Constructor `__init__`
+  - Input: model_cfg, env_cfg (usually None)
+  - Output: none
+- Call `__call__`
+  - Input: the model outputs
+  - Output: the returned results
+
+
+## 4. Adding unit tests
+
+After adding an op, add a unit test that exercises it; see [test_classification.py](../tests/test_classification.py) for reference.
diff --git a/paddlecv/docs/images/pipeline.png b/paddlecv/docs/images/pipeline.png
new file mode 100644
index 0000000000000000000000000000000000000000..e0c0720352ea28a412b19b129d37346ee2c790c8
Binary files /dev/null and b/paddlecv/docs/images/pipeline.png differ
diff --git a/paddlecv/docs/system_design.md b/paddlecv/docs/system_design.md
new file mode 100644
index 0000000000000000000000000000000000000000..cf71f4c929f0cf55017220201f2e1e33e97d146b
--- /dev/null
+++ b/paddlecv/docs/system_design.md
@@ -0,0 +1,62 @@
+# System Design
+
+- [Goals](#1)
+- [Framework Design](#2)
+  - [2.1 Configuration Module](#2.1)
+  - [2.2 Input Module](#2.2)
+  - [2.3 Operator Design](#2.3)
+  - [2.4 System Composition](#2.4)
+
+
+
+
+
+## Goals
+To solve the deployment problem for single deep-learning models as well as chained multi-model systems, the PaddlePaddle model team designed a unified, general-purpose deployment system. Its core features:
+
+1. Generality: supports single-model deployment as well as complex multi-model topologies
+2. High usability: supports multiple input types; a full pipeline can be reproduced efficiently from a configuration file alone
+3. High flexibility: custom operators plug in easily, enabling customized deployment
+
+
+
+## Framework Design
+
+The overall system architecture is shown below.
+
+
+

+
+
+
+
+**1. Configuration module**
+
+The configuration module parses the configuration file into an environment section and a model section, and checks that every entry is valid. The environment config manages deployment-environment settings such as `run_mode` and `device`. The model config manages the settings of each model op, including model paths and pre/postprocessing. Complex connections between models are expressed through the `Inputs` section. See [PP-PicoDet.yml](../configs/single_op/PP-PicoDet.yml) for an example.
+
+Any configuration entry can also be overridden from the command line, letting developers quickly change environments, swap models, or tune hyperparameters.
+
+For configuration management, the system recommends a dedicated configuration file per task and provides the `get_config_file` interface for automatic download, e.g.:
+
+```python
+import paddlecv
+paddlecv.get_config_file('detection')
+```
+
+
+
+**2. Input module**
+
+The input module parses the input format, supporting single images, image folders, videos, and numpy data. The `input` field is the uniform input interface. See the implementation [here](../ppcv/engine/pipeline.py#L45).
+
+
+
+**3. Operator design**
+
+Operators are divided into model operators (MODEL), connector operators (CONNECTOR), and output operators (OUTPUT), each with fixed output formats and output fields. A model operator encapsulates a model's preprocessing, forward inference, and postprocessing end to end; connector operators glue the inputs and outputs of model operators together, e.g. cropping or filtering; output operators handle the output side of a single model or a whole system, e.g. visualization and result saving. See the [documentation](how_to_add_new_op.md) for the detailed operator workflow.
+
+
+
+**4. System composition**
+
+The system chains operators into a directed acyclic graph (DAG) and executes them. Each operator must declare an `Inputs` field in the format `{last_op_name}.{last_op_output_name}`, i.e. the upstream operator's name plus the name of its output field. This establishes the topology between operators, and a topological sort determines the execution order. During execution the system maintains the full set of results and filters them against each operator's declared `Inputs`, keeping each operator's computation independent. See the executor's core implementation [here](../ppcv/core/framework.py#L92).
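+
+As a toy illustration of how `Inputs` induces the execution order (a standalone sketch mirroring the executor's logic, not the paddlecv API; the op names are made up):
+
+```python
+from collections import defaultdict
+
+# each op declares where its inputs come from, like the YAML `Inputs` field
+ops = {
+    "det": ["input.image"],
+    "crop": ["input.image", "det.dt_bboxes"],
+    "vis": ["crop.crop_image"],
+}
+
+children, in_deg = defaultdict(list), defaultdict(int)
+for name, deps in ops.items():
+    for dep in deps:
+        children[dep.split(".")[0]].append(name)  # edge: upstream -> op
+        in_deg[name] += 1
+
+order, queue = [], ["input"]  # `input` is the implicit source node
+while queue:
+    u = queue.pop()
+    order.append(u)
+    for v in children[u]:
+        in_deg[v] -= 1  # one incoming edge satisfied
+        if in_deg[v] == 0:
+            queue.append(v)
+print(order)  # ['input', 'det', 'crop', 'vis']
+```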
diff --git a/paddlecv/docs/whl.md b/paddlecv/docs/whl.md
new file mode 100644
index 0000000000000000000000000000000000000000..87ec313bc5e5bd1d04d5fa59a015054960f5dfe1
--- /dev/null
+++ b/paddlecv/docs/whl.md
@@ -0,0 +1,45 @@
+# Using the Whl Package
+
+## 1. Installation and Introduction
+
+The whl package has not yet been published to PyPI, so for now install it as follows.
+
+```shell
+python setup.py bdist_wheel
+pip install dist/paddlecv-0.1.0-py3-none-any.whl
+```
+
+## 2. Basic Usage
+
+
+Usage is shown below.
+
+* Specify either task_name or config_path to obtain the system to run. With `task_name`, a model or pipeline bundled with the PaddleCV project is fetched and used for prediction; with `config_path`, the given configuration file is loaded to initialize the model or pipeline.
+
+```py
+from paddlecv import PaddleCV
+paddlecv = PaddleCV(task_name="PP-OCRv3")
+res = paddlecv("../demo/00056221.jpg")
+```
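+
+* Equivalently, pass a configuration file directly (the path below is illustrative):
+
+```py
+from paddlecv import PaddleCV
+paddlecv = PaddleCV(config_path="configs/system/PP-OCRv3.yml")
+res = paddlecv("../demo/00056221.jpg")
+```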
+
+* To list the pipelines bundled with the system, use the following.
+
+```py
+from paddlecv import PaddleCV
+PaddleCV.list_all_supported_tasks()
+```
+
+The output looks like the following.
+
+```
+[11/17 06:17:20] ppcv INFO: Tasks and recommended configs that paddlecv supports are:
+PP-Human: paddlecv://configs/system/PP-Human.yml
+PP-OCRv2: paddlecv://configs/system/PP-OCRv2.yml
+PP-OCRv3: paddlecv://configs/system/PP-OCRv3.yml
+...
+```
+
+
+## 3. Advanced Development
+
+To extend or optimize the paddlecv whl interface, modify `paddlecv.py` and rebuild the whl package.
diff --git a/paddlecv/paddlecv.py b/paddlecv/paddlecv.py
new file mode 100644
index 0000000000000000000000000000000000000000..49ba871bcb785e723181da11ca6d85faa542a1ca
--- /dev/null
+++ b/paddlecv/paddlecv.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import importlib
+import argparse
+
+__dir__ = os.path.dirname(__file__)
+
+sys.path.insert(0, os.path.join(__dir__, ''))
+
+import cv2
+import logging
+import numpy as np
+from pathlib import Path
+
+ppcv = importlib.import_module('.', 'ppcv')
+tools = importlib.import_module('.', 'tools')
+tests = importlib.import_module('.', 'tests')
+
+VERSION = '0.1.0'
+
+import yaml
+from ppcv.model_zoo.model_zoo import TASK_DICT, list_model, get_config_file
+from ppcv.engine.pipeline import Pipeline
+from ppcv.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+class PaddleCV(object):
+ def __init__(self,
+ task_name=None,
+ config_path=None,
+ output_dir=None,
+ run_mode='paddle',
+ device='CPU'):
+
+ if task_name is not None:
+ assert task_name in TASK_DICT, f"task_name must be one of {list(TASK_DICT.keys())} but got {task_name}"
+ config_path = get_config_file(task_name)
+ else:
+            assert config_path is not None, "task_name and config_path cannot both be None"
+
+ self.cfg_dict = dict(
+ config=config_path,
+ output_dir=output_dir,
+ run_mode=run_mode,
+ device=device)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ self.pipeline = Pipeline(cfg)
+
+    @classmethod
+    def list_all_supported_tasks(cls):
+        logger.info("Tasks and recommended configs that paddlecv supports are:")
+        buffer = yaml.dump(TASK_DICT)
+        print(buffer)
+        return
+
+    @classmethod
+    def list_all_supported_models(cls, filters=[]):
+        list_model(filters)
+        return
+
+ def __call__(self, input):
+ res = self.pipeline.run(input)
+ return res
diff --git a/paddlecv/ppcv/__init__.py b/paddlecv/ppcv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d25805c5daaf7ff045f7dcff16bdea18dbd09ef2
--- /dev/null
+++ b/paddlecv/ppcv/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import (core, engine, ops, utils, model_zoo)
diff --git a/paddlecv/ppcv/core/__init__.py b/paddlecv/ppcv/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd903e54904d3f1f33b1abc0ffb99d4b3057cb9f
--- /dev/null
+++ b/paddlecv/ppcv/core/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import workspace
+from .workspace import *
+
+__all__ = workspace.__all__
diff --git a/paddlecv/ppcv/core/config.py b/paddlecv/ppcv/core/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed301fa1f3696eb702a750c7c16c7d65222fb8f1
--- /dev/null
+++ b/paddlecv/ppcv/core/config.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import math
+import paddle
+import collections
+from collections import defaultdict
+from collections.abc import Sequence, Mapping
+import yaml
+import copy
+from argparse import ArgumentParser, RawDescriptionHelpFormatter
+from ppcv.utils.logger import setup_logger
+import ppcv
+from ppcv.ops import *
+
+logger = setup_logger('config')
+
+
+class ArgsParser(ArgumentParser):
+ def __init__(self):
+ super(ArgsParser, self).__init__(
+ formatter_class=RawDescriptionHelpFormatter)
+ self.add_argument(
+ "-o", "--opt", nargs='*', help="set configuration options")
+
+ def parse_args(self, argv=None):
+ args = super(ArgsParser, self).parse_args(argv)
+ assert args.config is not None, \
+ "Please specify --config=configure_file_path."
+ args.opt = self._parse_opt(args.opt)
+ return args
+
+ def _parse_opt(self, opts):
+ config = {}
+ if not opts:
+ return config
+ for s in opts:
+ s = s.strip()
+ k, v = s.split('=', 1)
+ if '.' not in k:
+ config[k] = yaml.load(v, Loader=yaml.Loader)
+ else:
+ keys = k.split('.')
+ if keys[0] not in config:
+ config[keys[0]] = {}
+ cur = config[keys[0]]
+ for idx, key in enumerate(keys[1:]):
+ if idx == len(keys) - 2:
+ cur[key] = yaml.load(v, Loader=yaml.Loader)
+ else:
+ cur[key] = {}
+ cur = cur[key]
+ return config
+
+
+class ConfigParser(object):
+ def __init__(self, args):
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+ self.model_cfg, self.env_cfg = self.merge_cfg(args, cfg)
+ self.check_cfg()
+
+ def merge_cfg(self, args, cfg):
+ env_cfg = cfg['ENV']
+ model_cfg = cfg['MODEL']
+
+ def merge(cfg, arg):
+ merge_cfg = copy.deepcopy(cfg)
+ for k, v in cfg.items():
+ if k in arg:
+ merge_cfg[k] = arg[k]
+ else:
+ if isinstance(v, dict):
+ merge_cfg[k] = merge(v, arg)
+ return merge_cfg
+
+ def merge_opt(cfg, arg):
+ for k, v in arg.items():
+ if isinstance(cfg, Sequence):
+ k = eval(k)
+ cfg[k] = merge_opt(cfg[k], v)
+ else:
+ if (k in cfg and (isinstance(cfg[k], Sequence) or
+ isinstance(cfg[k], Mapping)) and
+ isinstance(arg[k], Mapping)):
+ merge_opt(cfg[k], arg[k])
+ else:
+ cfg[k] = arg[k]
+ return cfg
+
+ args_dict = vars(args)
+ for k, v in args_dict.items():
+ if k not in env_cfg:
+ env_cfg[k] = v
+ env_cfg = merge(env_cfg, args_dict)
+ if 'opt' in args_dict.keys() and args_dict['opt']:
+ opt_dict = args_dict['opt']
+ if opt_dict.get('ENV', None):
+ env_cfg = merge_opt(env_cfg, opt_dict['ENV'])
+ if opt_dict.get('MODEL', None):
+ model_cfg = merge_opt(model_cfg, opt_dict['MODEL'])
+
+ return model_cfg, env_cfg
+
+ def check_cfg(self):
+ unique_name = set()
+ unique_name.add('input')
+ op_list = ppcv.ops.__all__
+ for model in self.model_cfg:
+ model_name = list(model.keys())[0]
+ model_dict = list(model.values())[0]
+ # check the name and last_ops is legal
+ if 'name' not in model_dict:
+ raise ValueError(
+ 'Missing name field in {} model config'.format(model_name))
+ inputs = model_dict['Inputs']
+ for input in inputs:
+ input_str = input.split('.')
+                assert len(
+                    input_str
+                ) > 1, 'The Inputs name should be in the format {{last_op_name}}.{{last_op_output_name}}, but received {} in the {} model config'.format(
+                    input, model_name)
+                last_op = input.split('.')[0]
+                assert last_op in unique_name, 'The last_op {} in the {} model config does not exist.'.format(
+                    last_op, model_name)
+ unique_name.add(model_dict['name'])
+
+ device = self.env_cfg.get("device", "CPU")
+ assert device.upper() in ['CPU', 'GPU', 'XPU'
+ ], "device should be CPU, GPU or XPU"
+
+ def parse(self):
+ return self.model_cfg, self.env_cfg
+
+ def print_cfg(self):
+ print('----------- Environment Arguments -----------')
+ buffer = yaml.dump(self.env_cfg)
+ print(buffer)
+ print('------------- Model Arguments ---------------')
+ buffer = yaml.dump(self.model_cfg)
+ print(buffer)
+ print('---------------------------------------------')
diff --git a/paddlecv/ppcv/core/framework.py b/paddlecv/ppcv/core/framework.py
new file mode 100644
index 0000000000000000000000000000000000000000..9675ab1b0029a7d61efe7c2456124b16192e94a0
--- /dev/null
+++ b/paddlecv/ppcv/core/framework.py
@@ -0,0 +1,175 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import math
+import paddle
+from collections import defaultdict
+import ppcv
+from ppcv.ops import *
+from ppcv.utils.helper import get_output_keys, gen_input_name
+from ppcv.core.workspace import create
+
+
+class DAG(object):
+ """
+    Directed Acyclic Graph (DAG) engine; builds one DAG topology.
+ """
+
+ def __init__(self, cfg):
+ self.graph, self.rev_graph, self.in_degrees = self.build_dag(cfg)
+ self.num = len(self.in_degrees)
+
+ def build_dag(self, cfg):
+ graph = defaultdict(list) # op -> next_op
+ unique_name = set()
+ unique_name.add('input')
+ rev_graph = defaultdict(list) # op -> last_op
+ for op in cfg:
+ op_dict = list(op.values())[0]
+ unique_name.add(op_dict['name'])
+
+ in_degrees = dict((u, 0) for u in unique_name)
+ for op in cfg:
+ op_cfg = list(op.values())[0]
+ inputs = op_cfg['Inputs']
+ for input in inputs:
+ last_op = input.split('.')[0]
+ graph[last_op].append(op_cfg['name'])
+ rev_graph[op_cfg['name']].append(last_op)
+ in_degrees[op_cfg['name']] += 1
+ return graph, rev_graph, in_degrees
+
+ def get_graph(self):
+ return self.graph
+
+ def get_reverse_graph(self):
+ return self.rev_graph
+
+    def topo_sort(self):
+        """
+        Topological sort of the DAG; produces the flat execution order.
+        Args:
+            graph (dict): the DAG structure
+            in_degrees (dict): the in-degree (number of upstream ops) of each op
+        Returns:
+            sort_result: the topologically sorted op list. Example:
+                DAG: A -> B -> C -> E
+                          \-> D -/
+                sort_result: [A, B, C, D, E]
+        """
+
+ # Select vertices with in_degree = 0
+ Q = [u for u in self.in_degrees if self.in_degrees[u] == 0]
+ sort_result = []
+ while Q:
+ u = Q.pop()
+ sort_result.append(u)
+ for v in self.graph[u]:
+                # decrement the in-degree of the downstream op
+ self.in_degrees[v] -= 1
+ # re-select vertices with in_degree = 0
+ if self.in_degrees[v] == 0:
+ Q.append(v)
+ if len(sort_result) == self.num:
+ return sort_result
+ else:
+ return None
+
+
+class Executor(object):
+    """
+    The executor that runs the chained model pipeline.
+
+    Args:
+        env_cfg: the environment configuration
+        model_cfg: the model configurations
+    """
+
+ def __init__(self, model_cfg, env_cfg):
+ dag = DAG(model_cfg)
+ self.order = dag.topo_sort()
+ self.model_cfg = model_cfg
+
+ self.op_name2op = {}
+ self.has_output_op = False
+ for op in model_cfg:
+ op_arch = list(op.keys())[0]
+ op_cfg = list(op.values())[0]
+ op_name = op_cfg['name']
+ op = create(op_arch, op_cfg, env_cfg)
+ self.op_name2op[op_name] = op
+ if op.type() == 'OUTPUT':
+ self.has_output_op = True
+
+ self.output_keys = get_output_keys(model_cfg)
+ self.last_ops_dict = dag.get_reverse_graph()
+ self.input_dep = self.reset_dep()
+
+ def reset_dep(self, ):
+ return self.build_dep(self.model_cfg, self.output_keys)
+
+ def build_dep(self, cfg, output_keys):
+ # compute the output degree for each input name
+ dep = dict()
+ for op in cfg:
+ inputs = list(op.values())[0]['Inputs']
+ for name in inputs:
+ if name in dep:
+ dep[name] += 1
+ else:
+ dep.update({name: 1})
+ return dep
+
+ def update_res(self, results, op_outputs, input_name):
+ # step1: remove the result when keys not used in later input
+ for res, out in zip(results, op_outputs):
+ if self.has_output_op:
+ del_name = []
+ for k in out.keys():
+ if k not in self.input_dep:
+ del_name.append(k)
+ # remove the result when keys not used in later input
+ for name in del_name:
+ del out[name]
+ res.update(out)
+
+ # step2: if the input name is no longer used, then result will be deleted
+ if self.has_output_op:
+ for name in input_name:
+ self.input_dep[name] -= 1
+ if self.input_dep[name] == 0:
+ for res in results:
+ del res[name]
+
+ def run(self, input, frame_id=-1):
+ self.input_dep = self.reset_dep()
+ # execute each operator according to toposort order
+ results = input
+ for i, op_name in enumerate(self.order[1:]):
+ op = self.op_name2op[op_name]
+ op.set_frame(frame_id)
+ last_ops = self.last_ops_dict[op_name]
+ input_keys = op.get_input_keys()
+ output_keys = list(results[0].keys())
+ input = op.filter_input(results, input_keys)
+ last_op_output = op(input)
+ if op.type() != 'OUTPUT':
+ op.check_output(last_op_output, op_name)
+ self.update_res(results, last_op_output, input_keys)
+ else:
+ results = last_op_output
+
+ return results
diff --git a/paddlecv/ppcv/core/workspace.py b/paddlecv/ppcv/core/workspace.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fa41a869ad6bc4a486846baca1057ccfdbc5ee4
--- /dev/null
+++ b/paddlecv/ppcv/core/workspace.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+__all__ = ['register', 'create']
+
+global_config = dict()
+
+
+def register(cls):
+ """
+ Register a given module class.
+ Args:
+ cls (type): Module class to be registered.
+ Returns: cls
+ """
+ if cls.__name__ in global_config:
+ raise ValueError("Module class already registered: {}".format(
+ cls.__name__))
+ global_config[cls.__name__] = cls
+ return cls
+
+
+def create(cls_name, op_cfg, env_cfg):
+    """
+    Create an instance of the given module class.
+
+    Args:
+        cls_name (str): name of the class to instantiate.
+
+    Return: instance of type `cls_name`
+    """
+ assert type(cls_name) == str, "should be a name of class"
+ if cls_name not in global_config:
+ raise ValueError("The module {} is not registered".format(cls_name))
+
+ cls = global_config[cls_name]
+ return cls(op_cfg, env_cfg)
+
+
+def get_global_op():
+ return global_config
diff --git a/paddlecv/ppcv/engine/__init__.py b/paddlecv/ppcv/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9186260e3fb083c8bb3cb733f9fc55a79bd3671
--- /dev/null
+++ b/paddlecv/ppcv/engine/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import pipeline
+from .pipeline import *
+
+__all__ = pipeline.__all__
diff --git a/paddlecv/ppcv/engine/pipeline.py b/paddlecv/ppcv/engine/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..35f3bb4c416b2c0b4885617feb6ce2ea6209c8b4
--- /dev/null
+++ b/paddlecv/ppcv/engine/pipeline.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import numpy as np
+import math
+import glob
+import paddle
+import cv2
+from collections import defaultdict
+try:
+ from collections.abc import Sequence
+except Exception:
+ from collections import Sequence
+
+from ppcv.core.framework import Executor
+from ppcv.utils.logger import setup_logger
+from ppcv.core.config import ConfigParser
+
+logger = setup_logger('pipeline')
+
+__all__ = ['Pipeline']
+
+
+class Pipeline(object):
+ def __init__(self, cfg):
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+ self.exe = Executor(self.model_cfg, self.env_cfg)
+ self.output_dir = self.env_cfg.get('output_dir', 'output')
+
+ def _parse_input(self, input):
+ if isinstance(input, np.ndarray):
+ return [input], 'data'
+ if isinstance(input, Sequence) and isinstance(input[0], np.ndarray):
+ return input, 'data'
+ im_exts = ['jpg', 'jpeg', 'png', 'bmp']
+ im_exts += [ext.upper() for ext in im_exts]
+ video_exts = ['mp4', 'avi', 'wmv', 'mov', 'mpg', 'mpeg', 'flv']
+ video_exts += [ext.upper() for ext in video_exts]
+
+ if isinstance(input, (list, tuple)) and isinstance(input[0], str):
+ input_type = "image"
+ images = [
+ image for image in input
+ if any([image.endswith(ext) for ext in im_exts])
+ ]
+ return images, input_type
+
+ if os.path.isdir(input):
+ input_type = "image"
+ logger.info(
+ 'Input path is directory, search the images automatically')
+ images = set()
+ infer_dir = os.path.abspath(input)
+ for ext in im_exts:
+ images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
+ images = list(images)
+ return images, input_type
+
+ logger.info('Input path is {}'.format(input))
+ input_ext = os.path.splitext(input)[-1][1:]
+ if input_ext in im_exts:
+ input_type = "image"
+ return [input], input_type
+
+ if input_ext in video_exts:
+ input_type = "video"
+ return input, input_type
+
+        raise ValueError("Unsupported input format: {}".format(input_ext))
+
+ def run(self, input):
+ input, input_type = self._parse_input(input)
+ if input_type == "image" or input_type == 'data':
+ results = self.predict_images(input)
+ elif input_type == "video":
+ results = self.predict_video(input)
+ else:
+ raise ValueError("Unexpected input type: {}".format(input_type))
+ return results
+
+ def decode_image(self, input):
+ if isinstance(input, str):
+ with open(input, 'rb') as f:
+ im_read = f.read()
+ data = np.frombuffer(im_read, dtype='uint8')
+ im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
+ im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+ else:
+ im = input
+ return im
+
+ def predict_images(self, input):
+ batch_input = [{
+ 'input.image': self.decode_image(f),
+ 'input.fn': 'tmp.jpg' if isinstance(f, np.ndarray) else f
+ } for f in input]
+ results = self.exe.run(batch_input)
+ return results
+
+ def predict_video(self, input):
+ capture = cv2.VideoCapture(input)
+ file_name = input.split('/')[-1]
+ # Get Video info : resolution, fps, frame count
+ width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ fps = int(capture.get(cv2.CAP_PROP_FPS))
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
+ logger.info("video fps: %d, frame_count: %d" % (fps, frame_count))
+
+ if not os.path.exists(self.output_dir):
+ os.makedirs(self.output_dir)
+ out_path = os.path.join(self.output_dir, file_name)
+        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
+ frame_id = 0
+
+ results = None
+ while (1):
+ if frame_id % 10 == 0:
+ logger.info('frame id: {}'.format(frame_id))
+ ret, frame = capture.read()
+ if not ret:
+ break
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frame_input = [{'input.image': frame_rgb, 'input.fn': input}]
+ results = self.exe.run(frame_input, frame_id)
+ writer.write(results[0]['output'])
+ frame_id += 1
+ writer.release()
+ logger.info('save result to {}'.format(out_path))
+ return results
diff --git a/paddlecv/ppcv/model_zoo/MODEL_ZOO b/paddlecv/ppcv/model_zoo/MODEL_ZOO
new file mode 100644
index 0000000000000000000000000000000000000000..04a82d26ad483da825ad0b576885f8a6743c1faa
--- /dev/null
+++ b/paddlecv/ppcv/model_zoo/MODEL_ZOO
@@ -0,0 +1,23 @@
+single_op/PP-YOLOv2
+single_op/PP-PicoDet
+single_op/PP-LiteSeg
+single_op/PP-YOLOE+
+single_op/PP-MattingV1
+single_op/PP-YOLO
+single_op/PP-LCNetV2
+single_op/PP-HGNet
+single_op/PP-LCNet
+single_op/PP-HumanSegV2
+single_op/PP-YOLOE
+system/PP-Structure-layout-table
+system/PP-Structure-re
+system/PP-Structure
+system/PP-OCRv2
+system/PP-Vehicle
+system/PP-ShiTuV2
+system/PP-Structure-table
+system/PP-Human
+system/PP-TinyPose
+system/PP-ShiTu
+system/PP-OCRv3
+system/PP-Structure-ser
diff --git a/paddlecv/ppcv/model_zoo/__init__.py b/paddlecv/ppcv/model_zoo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c3b37d8eb0f75aa50ca279ca58ea2261f3210b4
--- /dev/null
+++ b/paddlecv/ppcv/model_zoo/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import model_zoo
+from .model_zoo import *
+
+__all__ = model_zoo.__all__
diff --git a/paddlecv/ppcv/model_zoo/model_zoo.py b/paddlecv/ppcv/model_zoo/model_zoo.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5a11803e14fb09d4cf08a7c6ed4694e4011a138
--- /dev/null
+++ b/paddlecv/ppcv/model_zoo/model_zoo.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import pkg_resources
+
+try:
+ from collections.abc import Sequence
+except ImportError:
+ from collections import Sequence
+
+from ppcv.utils.download import get_config_path, get_model_path
+from ppcv.utils.logger import setup_logger
+logger = setup_logger(__name__)
+
+__all__ = [
+ 'list_model', 'get_config_file', 'get_model_file', 'MODEL_ZOO_FILENAME'
+]
+
+MODEL_ZOO_FILENAME = 'MODEL_ZOO'
+TASK_DICT = {
+ # single model
+    'classification': 'paddlecv://configs/single_op/PP-HGNet.yml',
+ 'detection': 'paddlecv://configs/single_op/PP-YOLOE+.yml',
+ 'segmentation': 'paddlecv://configs/single_op/PP-LiteSeg.yml',
+ # system
+ 'PP-OCRv2': 'paddlecv://configs/system/PP-OCRv2.yml',
+ 'PP-OCRv3': 'paddlecv://configs/system/PP-OCRv3.yml',
+ 'PP-StructureV2': 'paddlecv://configs/system/PP-Structure.yml',
+ 'PP-StructureV2-layout-table':
+ 'paddlecv://configs/system/PP-Structure-layout-table.yml',
+ 'PP-StructureV2-table': 'paddlecv://configs/system/PP-Structure-table.yml',
+ 'PP-StructureV2-ser': 'paddlecv://configs/system/PP-Structure-ser.yml',
+ 'PP-StructureV2-re': 'paddlecv://configs/system/PP-Structure-re.yml',
+ 'PP-Human': 'paddlecv://configs/system/PP-Human.yml',
+ 'PP-Vehicle': 'paddlecv://configs/system/PP-Vehicle.yml',
+ 'PP-TinyPose': 'paddlecv://configs/system/PP-TinyPose.yml',
+}
+
+
+def list_model(filters=[]):
+ model_zoo_file = pkg_resources.resource_filename('ppcv.model_zoo',
+ MODEL_ZOO_FILENAME)
+ with open(model_zoo_file) as f:
+ model_names = f.read().splitlines()
+
+ # filter model_name
+ def filt(name):
+ for f in filters:
+ if name.find(f) < 0:
+ return False
+ return True
+
+ if isinstance(filters, str) or not isinstance(filters, Sequence):
+ filters = [filters]
+ model_names = [name for name in model_names if filt(name)]
+ if len(model_names) == 0 and len(filters) > 0:
+ raise ValueError("no model found, please check filters seeting, "
+ "filters can be set as following kinds:\n"
+ "\tTask: single_op, system\n"
+ "\tArchitecture: PPLCNet, PPYOLOE ...\n")
+
+ model_str = "Available Models:\n"
+ for model_name in model_names:
+ model_str += "\t{}\n".format(model_name)
+ logger.info(model_str)
+
+
+# models and configs save on bcebos under dygraph directory
+def get_config_file(task):
+    """Get config path from task.
+    """
+    if task not in TASK_DICT:
+        tasks = list(TASK_DICT.keys())
+        raise ValueError("Illegal task: {}, please use one of {}".format(
+            task, tasks))
+    path = TASK_DICT[task]
+    return get_config_path(path)
+
+
+def get_model_file(path):
+ return get_model_path(path)
diff --git a/paddlecv/ppcv/ops/__init__.py b/paddlecv/ppcv/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1641e9636dd6944e5d3f7b566e45babee2508cb6
--- /dev/null
+++ b/paddlecv/ppcv/ops/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import models
+from . import output
+from . import connector
+
+from .models import *
+from .output import *
+from .connector import *
+
+__all__ = models.__all__ + output.__all__ + connector.__all__
diff --git a/paddlecv/ppcv/ops/base.py b/paddlecv/ppcv/ops/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f04d7b22fc5d0ebda03d05a2db78f551cfe526a
--- /dev/null
+++ b/paddlecv/ppcv/ops/base.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import importlib
+import math
+import numpy as np
+try:
+ from collections.abc import Sequence
+except Exception:
+ from collections import Sequence
+
+import paddle
+from paddle.inference import Config
+from paddle.inference import create_predictor
+
+from ppcv.ops.predictor import PaddlePredictor
+from ppcv.utils.download import get_model_path
+
+__all__ = ["BaseOp", ]
+
+
+def create_operators(params, mod):
+ """
+ create operators based on the config
+
+ Args:
+ params(list): a dict list, used to create some operators
+ mod(module) : a module that can import single ops
+ """
+ assert isinstance(params, list), ('operator config should be a list')
+ if mod is None:
+ mod = importlib.import_module(__name__)
+ ops = []
+ for operator in params:
+ if isinstance(operator, str):
+ op_name = operator
+ param = {}
+ else:
+ assert isinstance(operator,
+ dict) and len(operator) == 1, "yaml format error"
+ op_name = list(operator)[0]
+ param = {} if operator[op_name] is None else operator[op_name]
+
+ op = getattr(mod, op_name)(**param)
+ ops.append(op)
+
+ return ops
+
+
+class BaseOp(object):
+ """
+ Base Operator, implement of prediction process
+ Args
+ """
+
+ def __init__(self, model_cfg, env_cfg):
+ self.model_cfg = model_cfg
+ self.env_cfg = env_cfg
+ self.input_keys = model_cfg["Inputs"]
+
+    @classmethod
+    def type(cls):
+        raise NotImplementedError
+
+ @classmethod
+ def get_output_keys(cls):
+ raise NotImplementedError
+
+ def get_input_keys(self):
+ return self.input_keys
+
+ def filter_input(self, last_outputs, input_keys):
+ f_inputs = [{k: last[k] for k in input_keys} for last in last_outputs]
+ return f_inputs
+
+    def check_output(self, output, name):
+        if not isinstance(output, Sequence):
+            raise ValueError(
+                'The output of op: {} must be Sequence'.format(name))
+        output = output[0]
+        if not isinstance(output, dict):
+            raise ValueError(
+                'The element of output in op: {} must be dict'.format(name))
+ out_keys = list(output.keys())
+ for out, define in zip(out_keys, self.output_keys):
+ if out != define:
+ raise ValueError(
+ 'The output key in op: {} is inconsistent, expect {}, but received {}'.
+ format(name, define, out))
+
+ def set_frame(self, frame_id):
+ self.frame_id = frame_id
+
+ def __call__(self, image_list):
+ raise NotImplementedError
diff --git a/paddlecv/ppcv/ops/connector/__init__.py b/paddlecv/ppcv/ops/connector/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfbc5e1ff1d9fa3e334a2cef3bf09d47486b7fde
--- /dev/null
+++ b/paddlecv/ppcv/ops/connector/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .op_connector import *
+
+__all__ = op_connector.__all__
diff --git a/paddlecv/ppcv/ops/connector/base.py b/paddlecv/ppcv/ops/connector/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d315823ec24a76a0e34664c97122662ff637792
--- /dev/null
+++ b/paddlecv/ppcv/ops/connector/base.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+
+from ppcv.ops.base import BaseOp
+
+
+class ConnectorBaseOp(BaseOp):
+ def __init__(self, model_cfg, env_cfg=None):
+ super(ConnectorBaseOp, self).__init__(model_cfg, env_cfg)
+ self.name = model_cfg["name"]
+ keys = self.get_output_keys()
+ self.output_keys = [self.name + '.' + key for key in keys]
+
+    @classmethod
+    def type(cls):
+        return 'CONNECTOR'
diff --git a/paddlecv/ppcv/ops/connector/keyframes_extract_helper.py b/paddlecv/ppcv/ops/connector/keyframes_extract_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..193f8a030a386196a7239c1d3d3c5a8dd952d33c
--- /dev/null
+++ b/paddlecv/ppcv/ops/connector/keyframes_extract_helper.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# code reference: https://github.com/huangjun12/KeyFramesExtraction/blob/master/scene_div.py
+
+import cv2
+import operator
+import numpy as np
+from scipy.signal import argrelextrema
+
+
+def smooth(x, window_len=13, window='hanning'):
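+    # mirror-pad the signal at both ends so the window does not shrink at the
+    # borders, convolve with the normalized window, then trim to input length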
+ s = np.r_[2 * x[0] - x[window_len:1:-1], x, 2 * x[-1] - x[-1:-window_len:
+ -1]]
+
+ if window == 'flat': # moving average
+ w = np.ones(window_len, 'd')
+ else:
+ w = getattr(np, window)(window_len)
+ y = np.convolve(w / w.sum(), s, mode='same')
+ return y[window_len - 1:-window_len + 1]
+
+
+class Frame:
+ """class to hold information about each frame
+ """
+
+ def __init__(self, id, diff):
+ self.id = id
+ self.diff = diff
+
+    def __lt__(self, other):
+        return self.id < other.id
+
+    def __gt__(self, other):
+        return other.__lt__(self)
+
+    def __eq__(self, other):
+        return self.id == other.id
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+def rel_change(a, b):
+ x = (b - a) / max(a, b)
+ return x
+
+
+class KeyFrameExtractor(object):
+ def __init__(self, config):
+ pass
+
+ def extract_by_video_path(self, video_path):
+ raise NotImplementedError
+
+    def extract_by_image_list(self, image_list):
+        raise NotImplementedError
+
+ def __call__(self, video_obj):
+ assert isinstance(video_obj, (list, tuple, str))
+ if isinstance(video_obj, str):
+ output = self.extract_by_video_path(video_obj)
+ elif isinstance(video_obj, (list, tuple)):
+            output = self.extract_by_image_list(video_obj)
+ return output
+
+
+class LUVAbsDiffKeyFrameExtractor(KeyFrameExtractor):
+ """
+ extract key frames based on sum of absolute differences in LUV colorspace.
+ """
+
+ def __init__(self, config):
+ self.thresh = config.get("thresh", None)
+ self.use_top_order = config.get("use_top_order", False)
+ self.use_local_maxima = config.get("use_local_maxima", None)
+ self.num_top_frames = config.get("num_top_frames", None)
+ self.window_len = config.get("window_len", None)
+
+ def extract_by_video_path(self, video_path):
+ cap = cv2.VideoCapture(video_path)
+ curr_frame = None
+ prev_frame = None
+ frame_diffs = []
+ frames = []
+ success, frame = cap.read()
+ i = 0
+ while (success):
+ luv = cv2.cvtColor(frame, cv2.COLOR_BGR2LUV)
+ curr_frame = luv
+ if curr_frame is not None and prev_frame is not None:
+ diff = cv2.absdiff(curr_frame, prev_frame)
+ diff_sum = np.sum(diff)
+ diff_sum_mean = diff_sum / (diff.shape[0] * diff.shape[1])
+ frame_diffs.append(diff_sum_mean)
+ frame = Frame(i, diff_sum_mean)
+ frames.append(frame)
+ prev_frame = curr_frame
+ i = i + 1
+ success, frame = cap.read()
+ cap.release()
+
+ # compute keyframe
+ keyframe_id_set = set()
+ if self.use_top_order:
+ # sort the list in descending order
+ frames.sort(key=operator.attrgetter("diff"), reverse=True)
+ for keyframe in frames[:self.num_top_frames]:
+ keyframe_id_set.add(keyframe.id)
+ if self.thresh is not None:
+ for i in range(1, len(frames)):
+                if (rel_change(
+                        float(frames[i - 1].diff), float(frames[i].diff))
+                        >= self.thresh):
+ keyframe_id_set.add(frames[i].id)
+ if self.use_local_maxima:
+ diff_array = np.array(frame_diffs)
+ sm_diff_array = smooth(diff_array, self.window_len)
+ frame_indexes = np.asarray(
+ argrelextrema(sm_diff_array, np.greater))[0]
+ for i in frame_indexes:
+ keyframe_id_set.add(frames[i - 1].id)
+
+ keyframe_id_set = sorted(list(keyframe_id_set))
+ # save all keyframes as image
+ cap = cv2.VideoCapture(str(video_path))
+ curr_frame = None
+ keyframes = []
+ success, frame = cap.read()
+ idx = 0
+ while (success):
+ if idx in keyframe_id_set:
+ keyframes.append(frame)
+ idx = idx + 1
+ success, frame = cap.read()
+ cap.release()
+ return keyframes, keyframe_id_set
diff --git a/paddlecv/ppcv/ops/connector/op_connector.py b/paddlecv/ppcv/ops/connector/op_connector.py
new file mode 100644
index 0000000000000000000000000000000000000000..c639564a7b4021c2aab3ff3c90c08faf5118d26a
--- /dev/null
+++ b/paddlecv/ppcv/ops/connector/op_connector.py
@@ -0,0 +1,623 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = [
+ "ClsCorrectionOp", "BboxCropOp", "PolyCropOp", "FragmentCompositionOp",
+ "KeyFrameExtractionOp", "TableMatcherOp", 'TrackerOP', 'BboxExpandCropOp'
+]
+
+import cv2
+import numpy as np
+import importlib
+from collections import defaultdict
+
+from ppcv.core.workspace import register
+from ppcv.ops.base import create_operators
+from .base import ConnectorBaseOp
+from .keyframes_extract_helper import LUVAbsDiffKeyFrameExtractor
+from .table_matcher import TableMatcher
+from .tracker import OCSORTTracker, ParserTrackerResults
+
+
+@register
+class ClsCorrectionOp(ConnectorBaseOp):
+ """
+ rotate
+ """
+
+ def __init__(self, model_cfg, env_cfg=None):
+ super().__init__(model_cfg, env_cfg)
+ self.class_num = model_cfg["class_num"]
+ assert self.class_num in [
+ 2, 4
+ ], f"just [2, 4] are supported but got {self.class_num}"
+ if self.class_num == 2:
+ self.rotate_code = {1: cv2.ROTATE_180, }
+ else:
+ self.rotate_code = {
+ 1: cv2.ROTATE_90_COUNTERCLOCKWISE,
+ 2: cv2.ROTATE_180,
+ 3: cv2.ROTATE_90_CLOCKWISE,
+ }
+
+ self.threshold = model_cfg["threshold"]
+ self.check_input_keys()
+ return
+
+ @classmethod
+    def get_output_keys(cls):
+ return ["corr_image"]
+
+ def check_input_keys(self, ):
+ # image, cls_id, prob is needed.
+ assert len(
+ self.input_keys
+ ) == 3, f"input key of {self} must be 3 but got {len(self.input_keys)}"
+
+ def __call__(self, inputs):
+ outputs = []
+ for idx in range(len(inputs)):
+ images = inputs[idx][self.input_keys[0]]
+ cls_ids = inputs[idx][self.input_keys[1]]
+ probs = inputs[idx][self.input_keys[2]]
+ is_image_list = isinstance(images, (list, tuple))
+ if is_image_list is not True:
+ images = [images]
+ cls_ids = [cls_ids]
+ probs = [probs]
+ output = []
+ for image, cls_id, prob in zip(images, cls_ids, probs):
+ cls_id = cls_id[0]
+ prob = prob[0]
+ corr_image = image.copy()
+ if prob >= self.threshold and cls_id in self.rotate_code:
+ corr_image = cv2.rotate(corr_image,
+ self.rotate_code[cls_id])
+ output.append(corr_image)
+
+            if is_image_list is not True:
+                output = output[0]
+            outputs.append({self.output_keys[0]: output})
+ return outputs
+
+
+@register
+class BboxCropOp(ConnectorBaseOp):
+ """
+ BboxCropOp
+ """
+
+ def __init__(self, model_cfg, env_cfg=None):
+ super().__init__(model_cfg, env_cfg)
+ self.check_input_keys()
+ return
+
+ @classmethod
+    def get_output_keys(cls):
+ return ["crop_image"]
+
+ def check_input_keys(self, ):
+ # image, bbox is needed.
+ assert len(
+ self.input_keys
+ ) == 2, f"input key of {self} must be 2 but got {len(self.input_keys)}"
+
+ def __call__(self, inputs):
+ outputs = []
+ for idx in range(len(inputs)):
+ images = inputs[idx][self.input_keys[0]]
+ bboxes = inputs[idx][self.input_keys[1]]
+ is_image_list = isinstance(images, (list, tuple))
+ if is_image_list is not True:
+ images = [images]
+ bboxes = [bboxes]
+ output = []
+ # bbox: N x 4, x1, y1, x2, y2
+ for image, bbox, in zip(images, bboxes):
+ crop_imgs = []
+ for single_bbox in np.array(bbox):
+ xmin, ymin, xmax, ymax = single_bbox.astype("int")
+ crop_img = image[ymin:ymax, xmin:xmax, :].copy()
+ crop_imgs.append(crop_img)
+ output.append(crop_imgs)
+
+ if is_image_list is not True:
+ output = output[0]
+ outputs.append({self.output_keys[0]: output})
+ return outputs
+
+
+@register
+class PolyCropOp(ConnectorBaseOp):
+ """
+ PolyCropOp
+ """
+
+ def __init__(self, model_cfg, env_cfg=None):
+ super().__init__(model_cfg, env_cfg)
+ self.check_input_keys()
+ return
+
+ @classmethod
+    def get_output_keys(cls):
+ return ["crop_image"]
+
+ def check_input_keys(self, ):
+ # image, bbox is needed.
+ assert len(
+ self.input_keys
+ ) == 2, f"input key of {self} must be 2 but got {len(self.input_keys)}"
+
+ def get_rotate_crop_image(self, img, points):
+ '''
+ img_height, img_width = img.shape[0:2]
+ left = int(np.min(points[:, 0]))
+ right = int(np.max(points[:, 0]))
+ top = int(np.min(points[:, 1]))
+ bottom = int(np.max(points[:, 1]))
+ img_crop = img[top:bottom, left:right, :].copy()
+ points[:, 0] = points[:, 0] - left
+ points[:, 1] = points[:, 1] - top
+ '''
+ assert len(points) == 4, "shape of points must be 4*2"
+ img_crop_width = int(
+ max(
+ np.linalg.norm(points[0] - points[1]),
+ np.linalg.norm(points[2] - points[3])))
+ img_crop_height = int(
+ max(
+ np.linalg.norm(points[0] - points[3]),
+ np.linalg.norm(points[1] - points[2])))
+ pts_std = np.float32([[0, 0], [img_crop_width, 0],
+ [img_crop_width, img_crop_height],
+ [0, img_crop_height]])
+ M = cv2.getPerspectiveTransform(points.astype(np.float32), pts_std)
+ dst_img = cv2.warpPerspective(
+ img,
+ M, (img_crop_width, img_crop_height),
+ borderMode=cv2.BORDER_REPLICATE,
+ flags=cv2.INTER_CUBIC)
+ dst_img_height, dst_img_width = dst_img.shape[0:2]
+ if dst_img_height * 1.0 / dst_img_width >= 1.5:
+ dst_img = np.rot90(dst_img)
+ return dst_img
+
+ def sorted_boxes(self, dt_boxes):
+ """
+ Sort text boxes in order from top to bottom, left to right
+ args:
+ dt_boxes(array):detected text boxes with shape [4, 2]
+ return:
+ sorted boxes(array) with shape [4, 2]
+ """
+ num_boxes = dt_boxes.shape[0]
+ sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+ _boxes = list(sorted_boxes)
+
+ for i in range(num_boxes - 1):
+ for j in range(i, 0, -1):
+ if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
+ (_boxes[j + 1][0][0] < _boxes[j][0][0]):
+ tmp = _boxes[j]
+ _boxes[j] = _boxes[j + 1]
+ _boxes[j + 1] = tmp
+ else:
+ break
+ return _boxes
+
+ def __call__(self, inputs):
+ outputs = []
+ for idx in range(len(inputs)):
+ images = inputs[idx][self.input_keys[0]]
+ polys = inputs[idx][self.input_keys[1]]
+ is_image_list = isinstance(images, (list, tuple))
+ if is_image_list is not True:
+ images = [images]
+ polys = [polys]
+ output = []
+ # bbox: N x 4 x 2, x1,y1, x2,y2, x3,y3, x4,y4
+ for image, poly, in zip(images, polys):
+ crop_imgs = []
+ for single_poly in poly:
+ crop_img = self.get_rotate_crop_image(image, single_poly)
+ crop_imgs.append(crop_img)
+ output.append(crop_imgs)
+
+ if is_image_list is not True:
+ output = output[0]
+
+ outputs.append({self.output_keys[0]: output, })
+ return outputs
+
+
+@register
+class FragmentCompositionOp(ConnectorBaseOp):
+ """
+ FragmentCompositionOp
+ """
+
+ def __init__(self, model_cfg, env_cfg=None):
+ super().__init__(model_cfg, env_cfg)
+ self.split = model_cfg.get("split", " ")
+ self.check_input_keys()
+ return
+
+ @classmethod
+    def get_output_keys(cls):
+ return ["merged_text"]
+
+ def check_input_keys(self, ):
+ # list of string is needed
+ assert len(
+ self.input_keys
+ ) == 1, f"input key of {self} must be 1 but got {len(self.input_keys)}"
+
+ def __call__(self, inputs):
+ outputs = []
+ for idx in range(len(inputs)):
+ strs = inputs[idx][self.input_keys[0]]
+ output = self.split.join(strs)
+            outputs.append({self.output_keys[0]: output})
+ return outputs
+
+
+@register
+class KeyFrameExtractionOp(ConnectorBaseOp):
+ """
+ KeyFrameExtractionOp
+ """
+
+ def __init__(self, model_cfg, env_cfg=None):
+ super().__init__(model_cfg, env_cfg)
+ self.check_input_keys()
+ assert model_cfg["algo"] in ["luv_diff", ]
+ if model_cfg["algo"] == "luv_diff":
+ self.extractor = LUVAbsDiffKeyFrameExtractor(model_cfg["params"])
+
+ @classmethod
+    def get_output_keys(cls):
+ return ["key_frames", "key_frames_id"]
+
+ def check_input_keys(self, ):
+ # video is needed
+ assert len(
+ self.input_keys
+ ) == 1, f"input key of {self} must be 1 but got {len(self.input_keys)}"
+
+ def __call__(self, inputs):
+ outputs = []
+ for idx in range(len(inputs)):
+ input = inputs[idx][self.input_keys[0]]
+ key_frames, key_frames_id = self.extractor(input)
+            outputs.append({
+                self.output_keys[0]: key_frames,
+                self.output_keys[1]: key_frames_id
+            })
+ return outputs
+
+
+@register
+class TableMatcherOp(ConnectorBaseOp):
+ """
+ TableMatcherOp
+ """
+
+ def __init__(self, model_cfg, env_cfg=None):
+ super().__init__(model_cfg, env_cfg)
+ self.check_input_keys()
+ filter_ocr_result = model_cfg.get("filter_ocr_result", False)
+ self.matcher = TableMatcher(filter_ocr_result=filter_ocr_result)
+
+ @classmethod
+    def get_output_keys(cls):
+ return ["html"]
+
+ def check_input_keys(self, ):
+ # pred_structure, pred_bboxes, dt_boxes, res_res are needed
+ assert len(
+ self.input_keys
+ ) == 4, f"input key of {self} must be 4 but got {len(self.input_keys)}"
+
+ def __call__(self, inputs):
+ outputs = []
+ for idx in range(len(inputs)):
+ structure_bboxes = inputs[idx][self.input_keys[0]]
+ structure_strs = inputs[idx][self.input_keys[1]]
+ dt_boxes = inputs[idx][self.input_keys[2]]
+ rec_res = inputs[idx][self.input_keys[3]]
+
+ if len(structure_strs) == 0:
+ outputs.append({self.output_keys[0]: ['']})
+ continue
+ is_list = isinstance(structure_strs[0], (list, tuple))
+ if not is_list:
+ structure_strs = [structure_strs]
+ structure_bboxes = [structure_bboxes]
+ dt_boxes = [dt_boxes]
+ rec_res = [rec_res]
+
+ output = []
+ for single_structure_strs, single_structure_bboxes, single_dt_boxes, single_rec_res in zip(
+ structure_strs, structure_bboxes, dt_boxes, rec_res):
+ pred_html = self.matcher(single_structure_strs,
+ np.array(single_structure_bboxes),
+ single_dt_boxes.reshape([-1, 8]),
+ single_rec_res)
+ output.append({self.output_keys[0]: pred_html})
+ if not is_list:
+ output = output[0]
+ else:
+ d = defaultdict(list)
+ for item in output:
+ for k in self.output_keys:
+ d[k].append(item[k])
+ output = d
+ outputs.append(output)
+ return outputs
+
+
+@register
+class PPStructureFilterOp(ConnectorBaseOp):
+ """
+ PPStructureFilterOp
+ """
+
+ def __init__(self, model_cfg, env_cfg=None):
+ super().__init__(model_cfg, env_cfg)
+ self.keep_keys = model_cfg.get("keep_keys", [])
+ self.check_input_keys()
+ return
+
+ @classmethod
+ def get_output_keys(cls):
+ return ["image", "dt_polys", "rec_text"]
+
+ def check_input_keys(self, ):
+ # layout class names, images, dt_polys and rec_text are needed
+ assert len(
+ self.input_keys
+ ) == 4, f"input key of {self} must be 4 but got {len(self.input_keys)}"
+
+ def __call__(self, inputs):
+ outputs = []
+ for idx, input in enumerate(inputs):
+ images, dt_polys, rec_text = [], [], []
+ for i in range(len(input[self.input_keys[0]])):
+ if input[self.input_keys[0]][i] in self.keep_keys:
+ images.append(input[self.input_keys[1]][i])
+ dt_polys.append(input[self.input_keys[2]][i])
+ rec_text.append(input[self.input_keys[3]][i])
+ outputs.append({
+ self.output_keys[0]: images,
+ self.output_keys[1]: dt_polys,
+ self.output_keys[2]: rec_text,
+ })
+ return outputs
+
+
+@register
+class PPStructureResultConcatOp(ConnectorBaseOp):
+ """
+ PPStructureResultConcatOp
+ """
+
+ def __init__(self, model_cfg, env_cfg=None):
+ super().__init__(model_cfg, env_cfg)
+ self.keep_keys = model_cfg.get("keep_keys", [])
+ self.check_input_keys()
+ return
+
+ @classmethod
+ def get_output_keys(cls):
+ return [
+ "dt_polys", "rec_text", "dt_bboxes", "html", "cell_bbox",
+ "structures"
+ ]
+
+ def check_input_keys(self, ):
+ # table structure, html, layout and ocr results (8 fields in total) are needed
+ assert len(
+ self.input_keys
+ ) == 8, f"input key of {self} must be 8 but got {len(self.input_keys)}"
+
+ def __call__(self, inputs):
+ outputs = []
+ for idx, input in enumerate(inputs):
+ dt_polys, rec_text = [], []
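+ # note: the unpack below relies on the insertion order of the 8 input
+ # fields: structures, html, layout bboxes, table bboxes, table polys,
+ # table texts, text polys, text texts (matching check_input_keys)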
+ structures, html, layout_dt_bboxes, table_dt_bboxes, table_dt_polys, table_rec_text, txts_dt_polys, txts_rec_text = input.values(
+ )
+ dt_polys.extend(txts_dt_polys)
+ rec_text.extend(txts_rec_text)
+ dt_polys.extend(table_dt_polys)
+ rec_text.extend(table_rec_text)
+ input_html = [''] * len(txts_dt_polys) + html
+ input_structures = [[] for _ in range(len(txts_dt_polys))
+ ] + structures
+ cell_bbox = [[]
+ for _ in range(len(txts_dt_polys))] + table_dt_bboxes
+ outputs.append({
+ self.output_keys[0]: dt_polys,
+ self.output_keys[1]: rec_text,
+ self.output_keys[2]: layout_dt_bboxes,
+ self.output_keys[3]: input_html,
+ self.output_keys[4]: cell_bbox,
+ self.output_keys[5]: input_structures,
+ })
+ return outputs
+
+
+@register
+class OCRRotateOp(ConnectorBaseOp):
+ """
+ OCRRotateOp
+ """
+
+ def __init__(self, model_cfg, env_cfg=None):
+ super().__init__(model_cfg, env_cfg)
+ self.thresh = model_cfg.get("thresh", 0)
+ self.cv_rotate_code = model_cfg.get('rotate_map', {
+ '90': cv2.ROTATE_90_COUNTERCLOCKWISE,
+ '180': cv2.ROTATE_180,
+ '270': cv2.ROTATE_90_CLOCKWISE
+ })
+ self.check_input_keys()
+
+ @classmethod
+ def get_output_keys(cls):
+ return ["image"]
+
+ def check_input_keys(self, ):
+ # image, class label and score are needed
+ assert len(
+ self.input_keys
+ ) == 3, f"input key of {self} must be 3 but got {len(self.input_keys)}"
+
+ def __call__(self, inputs):
+ outputs = []
+ for idx, input in enumerate(inputs):
+ image = input[self.input_keys[0]]
+ label_name = input[self.input_keys[1]][0]
+ score = input[self.input_keys[2]][0]
+ if score > self.thresh and label_name in self.cv_rotate_code:
+ image = cv2.rotate(image, self.cv_rotate_code[label_name])
+ outputs.append({self.output_keys[0]: image, })
+ return outputs
+
+
+@register
+class TrackerOP(ConnectorBaseOp):
+ """
+ tracker
+ """
+
+ def __init__(self, model_cfg, env_cfg=None):
+ super().__init__(model_cfg, env_cfg)
+ self.tracker_type = model_cfg['type']
+ assert self.tracker_type in ['OCSORTTracker'
+ ], "Only OCSORTTracker is supported now"
+ tracker_kwargs = model_cfg['tracker_configs']
+ self.tracker = eval(self.tracker_type)(**tracker_kwargs)
+ self.check_input_keys()
+ mod = importlib.import_module(__name__)
+ self.postprocessor = create_operators(model_cfg["PostProcess"], mod)
+
+ @classmethod
+ def get_output_keys(cls):
+ return [
+ "tk_bboxes", "tk_scores", "tk_ids", "tk_cls_ids", "tk_cls_names"
+ ]
+
+ def check_input_keys(self):
+ # "dt_bboxes", "dt_scores", "dt_class_ids" or plus reid feature
+ assert len(self.input_keys) in [
+ 3, 4
+ ], 'for OCSORTTracker, now only supported det ouputs and reid outputs'
+
+ def create_inputs(self, det_result):
+ dt_bboxes = np.array(det_result[self.input_keys[0]])
+ dt_scores = np.array(det_result[self.input_keys[1]])
+ dt_class_ids = np.array(det_result[self.input_keys[2]])
+ dt_preds = np.concatenate(
+ [dt_class_ids[:, None], dt_scores[:, None], dt_bboxes], axis=-1)
+ if len(self.input_keys) > 3:
+ dt_embs = np.array(det_result[self.input_keys[3]])
+ else:
+ dt_embs = None
+
+ return dt_preds, dt_embs
+
+ def postprocess(self, outputs):
+ for idx, ops in enumerate(self.postprocessor):
+ if idx == len(self.postprocessor) - 1:
+ outputs = ops(outputs, self.output_keys)
+ else:
+ outputs = ops(outputs)
+ return outputs
+
+ def __call__(self, inputs):
+ pipe_outputs = []
+ for input in inputs:
+ dt_preds, dt_embs = self.create_inputs(input)
+ tracking_outs = self.tracker.tracking(dt_preds, dt_embs,
+ self.output_keys)
+ tracking_outs = self.postprocess(tracking_outs)
+ pipe_outputs.append(tracking_outs)
+
+ return pipe_outputs
+
+
+@register
+class BboxExpandCropOp(ConnectorBaseOp):
+ """
+ BboxExpandCropOp
+ """
+
+ def __init__(self, model_cfg, env_cfg=None):
+ super().__init__(model_cfg, env_cfg)
+ self.expand_ratio = model_cfg.get('expand_ratio', 0.3)
+ self.check_input_keys()
+
+ @classmethod
+ def get_output_keys(cls):
+ return ['crop_image', 'tl_point']
+
+ def check_input_keys(self, ):
+ # image, bbox is needed.
+ assert len(
+ self.input_keys
+ ) == 2, f"input key of {self} must be 2 but got {len(self.input_keys)}"
+
+ def expand_crop(self, image, box):
+ imgh, imgw, c = image.shape
+ xmin, ymin, xmax, ymax = [int(x) for x in box]
+ h_half = (ymax - ymin) * (1 + self.expand_ratio) / 2.
+ w_half = (xmax - xmin) * (1 + self.expand_ratio) / 2.
+ if h_half > w_half * 4 / 3:
+ w_half = h_half * 0.75
+ center = [(ymin + ymax) / 2., (xmin + xmax) / 2.]
+ ymin = max(0, int(center[0] - h_half))
+ ymax = min(imgh - 1, int(center[0] + h_half))
+ xmin = max(0, int(center[1] - w_half))
+ xmax = min(imgw - 1, int(center[1] + w_half))
+ return image[ymin:ymax, xmin:xmax, :], [xmin, ymin]
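+ # Illustrative example: for a 100x100 image, box [10, 10, 30, 30] and
+ # expand_ratio 0.3, h_half = w_half = 13, so the crop is
+ # image[7:33, 7:33] and the returned top-left point is [7, 7].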
+
+ def __call__(self, inputs):
+ outputs = []
+ for idx in range(len(inputs)):
+ images = inputs[idx][self.input_keys[0]]
+ bboxes = inputs[idx][self.input_keys[1]]
+ is_image_list = isinstance(images, (list, tuple))
+ if not is_image_list:
+ images = [images]
+ bboxes = [bboxes]
+
+ output_images = []
+ output_points = []
+ # bbox: N x 4, x1, y1, x2, y2
+ for image, bbox in zip(images, bboxes):
+ crop_imgs = []
+ tl_points = []
+ for single_bbox in bbox:
+ crop_img, tl_point = self.expand_crop(image, single_bbox)
+ crop_imgs.append(crop_img)
+ tl_points.append(tl_point)
+ output_images.append(crop_imgs)
+ output_points.append(tl_points)
+
+ if not is_image_list:
+ output_images = output_images[0]
+ output_points = output_points[0]
+ outputs.append({
+ self.output_keys[0]: output_images,
+ self.output_keys[1]: output_points
+ })
+ return outputs
diff --git a/paddlecv/ppcv/ops/connector/table_matcher.py b/paddlecv/ppcv/ops/connector/table_matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..de764139fa218725cca9972e33d3c7c10f9520a9
--- /dev/null
+++ b/paddlecv/ppcv/ops/connector/table_matcher.py
@@ -0,0 +1,151 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# code reference: https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.6/ppstructure/table/matcher.py
+
+import numpy as np
+
+
+def distance(box_1, box_2):
+ x1, y1, x2, y2 = box_1
+ x3, y3, x4, y4 = box_2
+ dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
+ dis_2 = abs(x3 - x1) + abs(y3 - y1)
+ dis_3 = abs(x4 - x2) + abs(y4 - y2)
+ return dis + min(dis_2, dis_3)
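+# Illustrative example: distance((0, 0, 10, 10), (1, 1, 11, 11)) returns
+# 4 + min(2, 2) = 6: the summed corner-wise L1 distance plus the L1
+# distance of the closer corner.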
+
+
+def compute_iou(rec1, rec2):
+ """
+ computing IoU
+ :param rec1: (y0, x0, y1, x1), which reflects
+ (top, left, bottom, right)
+ :param rec2: (y0, x0, y1, x1)
+ :return: scalar value of IoU
+ """
+ # computing area of each rectangles
+ S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
+ S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
+
+ # computing the sum_area
+ sum_area = S_rec1 + S_rec2
+
+ # find the each edge of intersect rectangle
+ left_line = max(rec1[1], rec2[1])
+ right_line = min(rec1[3], rec2[3])
+ top_line = max(rec1[0], rec2[0])
+ bottom_line = min(rec1[2], rec2[2])
+
+ # judge if there is an intersect
+ if left_line >= right_line or top_line >= bottom_line:
+ return 0.0
+ else:
+ intersect = (right_line - left_line) * (bottom_line - top_line)
+ return (intersect / (sum_area - intersect)) * 1.0
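+# Illustrative example: compute_iou((0, 0, 2, 2), (1, 1, 3, 3)) gives an
+# intersection of 1 and a union of 7, so the IoU is 1/7 (about 0.143).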
+
+
+class TableMatcher:
+ def __init__(self, filter_ocr_result=False):
+ self.filter_ocr_result = filter_ocr_result
+
+ def __call__(self, structure_strs, structure_bboxes, dt_boxes, rec_res):
+ if self.filter_ocr_result:
+ dt_boxes, rec_res = self._filter_ocr_result(structure_bboxes,
+ dt_boxes, rec_res)
+ matched_index = self.match_result(dt_boxes, structure_bboxes)
+
+ pred_html, pred = self.get_pred_html(structure_strs, matched_index,
+ rec_res)
+ return pred_html
+
+ def match_result(self, dt_boxes, structure_bboxes):
+ matched = {}
+ for i, gt_box in enumerate(dt_boxes):
+ distances = []
+ gt_box = [
+ np.min(gt_box[0::2]), np.min(gt_box[1::2]),
+ np.max(gt_box[0::2]), np.max(gt_box[1::2])
+ ]
+ for j, pred_box in enumerate(structure_bboxes):
+ if len(pred_box) == 8:
+ pred_box = [
+ np.min(pred_box[0::2]), np.min(pred_box[1::2]),
+ np.max(pred_box[0::2]), np.max(pred_box[1::2])
+ ]
+ distances.append((distance(gt_box, pred_box),
+ 1. - compute_iou(gt_box, pred_box)
+ )) # compute iou and l1 distance
+ sorted_distances = distances.copy()
+ # select det box by iou and l1 distance
+ sorted_distances = sorted(
+ sorted_distances, key=lambda item: (item[1], item[0]))
+ if distances.index(sorted_distances[0]) not in matched.keys():
+ matched[distances.index(sorted_distances[0])] = [i]
+ else:
+ matched[distances.index(sorted_distances[0])].append(i)
+ return matched
+
+ def get_pred_html(self, structure_strs, matched_index, ocr_contents):
+ end_html = []
+ td_index = 0
+ for tag in structure_strs:
+ if '</td>' in tag:
+ if '<td></td>' == tag:
+ end_html.extend('<td>')
+ if td_index in matched_index.keys():
+ b_with = False
+ if '<b>' in ocr_contents[matched_index[td_index][
+ 0]] and len(matched_index[td_index]) > 1:
+ b_with = True
+ end_html.extend('<b>')
+ for i, td_index_index in enumerate(matched_index[
+ td_index]):
+ content = ocr_contents[td_index_index]
+ if len(matched_index[td_index]) > 1:
+ if len(content) == 0:
+ continue
+ if content[0] == ' ':
+ content = content[1:]
+ if '<b>' in content:
+ content = content[3:]
+ if '</b>' in content:
+ content = content[:-4]
+ if len(content) == 0:
+ continue
+ if i != len(matched_index[
+ td_index]) - 1 and ' ' != content[-1]:
+ content += ' '
+ end_html.extend(content)
+ if b_with:
+ end_html.extend('</b>')
+ if '<td></td>' == tag:
+ end_html.extend('</td>')
+ else:
+ end_html.append(tag)
+ td_index += 1
+ else:
+ end_html.append(tag)
+ return ''.join(end_html), end_html
+
+ def _filter_ocr_result(self, structure_bboxes, dt_boxes, rec_res):
+ y1 = structure_bboxes[:, 1::2].min()
+ new_dt_boxes = []
+ new_rec_res = []
+
+ for box, rec in zip(dt_boxes, rec_res):
+ if np.max(box[1::2]) < y1:
+ continue
+ new_dt_boxes.append(box)
+ new_rec_res.append(rec)
+ return new_dt_boxes, new_rec_res
\ No newline at end of file
diff --git a/paddlecv/ppcv/ops/connector/tracker/__init__.py b/paddlecv/ppcv/ops/connector/tracker/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb96d46e5e635161a803c6702386b063b6017be6
--- /dev/null
+++ b/paddlecv/ppcv/ops/connector/tracker/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .tracker import *
+from .postprocess import *
+
+__all__ = tracker.__all__ + postprocess.__all__
\ No newline at end of file
diff --git a/paddlecv/ppcv/ops/connector/tracker/matching/__init__.py b/paddlecv/ppcv/ops/connector/tracker/matching/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e36a3dbc4a28f81601d195c23dcc4c8797e52366
--- /dev/null
+++ b/paddlecv/ppcv/ops/connector/tracker/matching/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import ocsort_matching
+
+from .ocsort_matching import *
diff --git a/paddlecv/ppcv/ops/connector/tracker/matching/ocsort_matching.py b/paddlecv/ppcv/ops/connector/tracker/matching/ocsort_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..edd286b1f612efa86f1efab80e9c96723bd7a112
--- /dev/null
+++ b/paddlecv/ppcv/ops/connector/tracker/matching/ocsort_matching.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is based on https://github.com/noahcao/OC_SORT/blob/master/trackers/ocsort_tracker/association.py
+"""
+
+import os
+import numpy as np
+
+
+def iou_batch(bboxes1, bboxes2):
+ bboxes2 = np.expand_dims(bboxes2, 0)
+ bboxes1 = np.expand_dims(bboxes1, 1)
+
+ xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
+ yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
+ xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
+ yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
+ w = np.maximum(0., xx2 - xx1)
+ h = np.maximum(0., yy2 - yy1)
+ area = w * h
+ iou_matrix = area / ((bboxes1[..., 2] - bboxes1[..., 0]) *
+ (bboxes1[..., 3] - bboxes1[..., 1]) +
+ (bboxes2[..., 2] - bboxes2[..., 0]) *
+ (bboxes2[..., 3] - bboxes2[..., 1]) - area)
+ return iou_matrix
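+# Illustrative example: with bboxes1 of shape (N, 4) and bboxes2 of shape
+# (M, 4) in [x1, y1, x2, y2] format, the result is an (N, M) IoU matrix;
+# identical boxes give 1.0 and disjoint boxes give 0.0.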
+
+
+def speed_direction_batch(dets, tracks):
+ tracks = tracks[..., np.newaxis]
+ CX1, CY1 = (dets[:, 0] + dets[:, 2]) / 2.0, (dets[:, 1] + dets[:, 3]) / 2.0
+ CX2, CY2 = (tracks[:, 0] + tracks[:, 2]) / 2.0, (
+ tracks[:, 1] + tracks[:, 3]) / 2.0
+ dx = CX1 - CX2
+ dy = CY1 - CY2
+ norm = np.sqrt(dx**2 + dy**2) + 1e-6
+ dx = dx / norm
+ dy = dy / norm
+ return dy, dx
+
+
+def linear_assignment(cost_matrix):
+ try:
+ import lap
+ _, x, y = lap.lapjv(cost_matrix, extend_cost=True)
+ return np.array([[y[i], i] for i in x if i >= 0])
+ except ImportError:
+ from scipy.optimize import linear_sum_assignment
+ x, y = linear_sum_assignment(cost_matrix)
+ return np.array(list(zip(x, y)))
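+# Illustrative example: for cost_matrix = np.array([[0.2, 1.0], [1.0, 0.1]])
+# both backends return [[0, 0], [1, 1]]: row 0 is assigned to column 0 and
+# row 1 to column 1, minimising the total cost (0.3).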
+
+
+def associate(detections, trackers, iou_threshold, velocities, previous_obs,
+ vdc_weight):
+ if (len(trackers) == 0):
+ return np.empty(
+ (0, 2), dtype=int), np.arange(len(detections)), np.empty(
+ (0, 5), dtype=int)
+
+ Y, X = speed_direction_batch(detections, previous_obs)
+ inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
+ inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
+ inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
+ diff_angle_cos = inertia_X * X + inertia_Y * Y
+ diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
+ diff_angle = np.arccos(diff_angle_cos)
+ diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi
+
+ valid_mask = np.ones(previous_obs.shape[0])
+ valid_mask[np.where(previous_obs[:, 4] < 0)] = 0
+
+ iou_matrix = iou_batch(detections, trackers)
+ scores = np.repeat(
+ detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
+ # iou_matrix = iou_matrix * scores # a trick that sometimes works; we don't encourage it
+ valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)
+
+ angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
+ angle_diff_cost = angle_diff_cost.T
+ angle_diff_cost = angle_diff_cost * scores
+
+ if min(iou_matrix.shape) > 0:
+ a = (iou_matrix > iou_threshold).astype(np.int32)
+ if a.sum(1).max() == 1 and a.sum(0).max() == 1:
+ matched_indices = np.stack(np.where(a), axis=1)
+ else:
+ matched_indices = linear_assignment(-(iou_matrix + angle_diff_cost
+ ))
+ else:
+ matched_indices = np.empty(shape=(0, 2))
+
+ unmatched_detections = []
+ for d, det in enumerate(detections):
+ if (d not in matched_indices[:, 0]):
+ unmatched_detections.append(d)
+ unmatched_trackers = []
+ for t, trk in enumerate(trackers):
+ if (t not in matched_indices[:, 1]):
+ unmatched_trackers.append(t)
+
+ # filter out matched with low IOU
+ matches = []
+ for m in matched_indices:
+ if (iou_matrix[m[0], m[1]] < iou_threshold):
+ unmatched_detections.append(m[0])
+ unmatched_trackers.append(m[1])
+ else:
+ matches.append(m.reshape(1, 2))
+ if (len(matches) == 0):
+ matches = np.empty((0, 2), dtype=int)
+ else:
+ matches = np.concatenate(matches, axis=0)
+
+ return matches, np.array(unmatched_detections), np.array(
+ unmatched_trackers)
diff --git a/paddlecv/ppcv/ops/connector/tracker/postprocess.py b/paddlecv/ppcv/ops/connector/tracker/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..61d624bf83c6103b049e404b6433b7aa60d54c7a
--- /dev/null
+++ b/paddlecv/ppcv/ops/connector/tracker/postprocess.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+from scipy.special import softmax
+
+from ppcv.utils.download import get_dict_path
+
+__all__ = ['ParserTrackerResults']
+
+
+class ParserTrackerResults(object):
+ def __init__(self, label_list):
+ self.clsid2catid, self.catid2name = self.get_categories(label_list)
+
+ def get_categories(self, label_list):
+ if isinstance(label_list, list):
+ clsid2catid = {i: i for i in range(len(label_list))}
+ catid2name = {i: label_list[i] for i in range(len(label_list))}
+ return clsid2catid, catid2name
+
+ label_list = get_dict_path(label_list)
+ if label_list.endswith('json'):
+ # lazy import pycocotools here
+ from pycocotools.coco import COCO
+ coco = COCO(label_list)
+ cats = coco.loadCats(coco.getCatIds())
+ clsid2catid = {i: cat['id'] for i, cat in enumerate(cats)}
+ catid2name = {cat['id']: cat['name'] for cat in cats}
+ elif label_list.endswith('txt'):
+ cats = []
+ with open(label_list) as f:
+ for line in f.readlines():
+ cats.append(line.strip())
+ if cats[0] == 'background': cats = cats[1:]
+
+ clsid2catid = {i: i for i in range(len(cats))}
+ catid2name = {i: name for i, name in enumerate(cats)}
+
+ else:
+ raise ValueError("label_list {} should be json or txt.".format(
+ label_list))
+ return clsid2catid, catid2name
+
+ def __call__(self, tracking_outputs, output_keys):
+ tk_cls_ids = tracking_outputs[output_keys[3]]
+ tk_cls_names = [
+ self.catid2name[self.clsid2catid[cls_id]] for cls_id in tk_cls_ids
+ ]
+ tracking_outputs[output_keys[4]] = tk_cls_names
+ return tracking_outputs
diff --git a/paddlecv/ppcv/ops/connector/tracker/tracker.py b/paddlecv/ppcv/ops/connector/tracker/tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..22563f019f850a4fd0887271107df717c35b47a1
--- /dev/null
+++ b/paddlecv/ppcv/ops/connector/tracker/tracker.py
@@ -0,0 +1,394 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is based on https://github.com/noahcao/OC_SORT/blob/master/trackers/ocsort_tracker/ocsort.py
+"""
+
+import numpy as np
+try:
+ from filterpy.kalman import KalmanFilter
+except ImportError:
+ print(
+ 'Warning: Unable to use OC-SORT, please install filterpy, for example: `pip install filterpy`, see https://github.com/rlabbe/filterpy'
+ )
+
+from .matching.ocsort_matching import associate, linear_assignment, iou_batch
+
+__all__ = ['OCSORTTracker']
+
+
+def k_previous_obs(observations, cur_age, k):
+ if len(observations) == 0:
+ return [-1, -1, -1, -1, -1]
+ for i in range(k):
+ dt = k - i
+ if cur_age - dt in observations:
+ return observations[cur_age - dt]
+ max_age = max(observations.keys())
+ return observations[max_age]
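+# Illustrative example: with observations = {1: b1, 4: b4}, cur_age = 5 and
+# k = 3, ages 2, 3 and 4 are probed in that order and b4 is returned; if
+# none of them existed, the newest observation would be used as fallback.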
+
+
+def convert_bbox_to_z(bbox):
+ """
+ Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
+ [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
+ the aspect ratio
+ """
+ w = bbox[2] - bbox[0]
+ h = bbox[3] - bbox[1]
+ x = bbox[0] + w / 2.
+ y = bbox[1] + h / 2.
+ s = w * h # scale is just area
+ r = w / float(h + 1e-6)
+ return np.array([x, y, s, r]).reshape((4, 1))
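+# Illustrative example: convert_bbox_to_z([0, 0, 4, 2, 0.9]) returns the
+# column vector [[2], [1], [8], [2]]: centre (2, 1), area 8, aspect ratio
+# w/h ~= 2 (the score entry is ignored).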
+
+
+def convert_x_to_bbox(x, score=None):
+ """
+ Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
+ [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
+ """
+ w = np.sqrt(x[2] * x[3])
+ h = x[2] / w
+ if score is None:
+ return np.array(
+ [x[0] - w / 2., x[1] - h / 2., x[0] + w / 2.,
+ x[1] + h / 2.]).reshape((1, 4))
+ else:
+ score = np.array([score])
+ return np.array([
+ x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score
+ ]).reshape((1, 5))
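+# Illustrative example: convert_x_to_bbox([2, 1, 8, 2]) inverts the mapping
+# above: w = sqrt(8 * 2) = 4, h = 8 / 4 = 2, so it returns [[0, 0, 4, 2]].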
+
+
+def speed_direction(bbox1, bbox2):
+ cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
+ cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
+ speed = np.array([cy2 - cy1, cx2 - cx1])
+ norm = np.sqrt((cy2 - cy1)**2 + (cx2 - cx1)**2) + 1e-6
+ return speed / norm
+
+
+class KalmanBoxTracker(object):
+ """
+ This class represents the internal state of individual tracked objects observed as bbox.
+
+ Args:
+ bbox (np.array): bbox in [x1,y1,x2,y2,score] format.
+ delta_t (int): delta_t of previous observation
+ """
+ count = 0
+
+ def __init__(self, bbox, delta_t=3):
+ try:
+ from filterpy.kalman import KalmanFilter
+ except Exception as e:
+ raise RuntimeError(
+ 'Unable to use OC-SORT, please install filterpy, for example: `pip install filterpy`, see https://github.com/rlabbe/filterpy'
+ )
+ self.kf = KalmanFilter(dim_x=7, dim_z=4)
+ self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0],
+ [0, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 0, 1]])
+ self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]])
+ self.kf.R[2:, 2:] *= 10.
+ self.kf.P[4:, 4:] *= 1000.
+ # give high uncertainty to the unobservable initial velocities
+ self.kf.P *= 10.
+ self.kf.Q[-1, -1] *= 0.01
+ self.kf.Q[4:, 4:] *= 0.01
+
+ self.score = bbox[4]
+ self.kf.x[:4] = convert_bbox_to_z(bbox)
+ self.time_since_update = 0
+ self.id = KalmanBoxTracker.count
+ KalmanBoxTracker.count += 1
+ self.history = []
+ self.hits = 0
+ self.hit_streak = 0
+ self.age = 0
+ """
+ NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of
+ function k_previous_obs. It is ugly and I do not like it. But to support generating the observation array
+ in a fast and unified way (see k_observations = np.array([k_previous_obs(...)]) below), let's bear it for now.
+ """
+ self.last_observation = np.array([-1, -1, -1, -1, -1]) # placeholder
+ self.observations = dict()
+ self.history_observations = []
+ self.velocity = None
+ self.delta_t = delta_t
+
+ def update(self, bbox):
+ """
+ Updates the state vector with observed bbox.
+ """
+ if bbox is not None:
+ if self.last_observation.sum() >= 0: # a previous observation exists
+ previous_box = None
+ for i in range(self.delta_t):
+ dt = self.delta_t - i
+ if self.age - dt in self.observations:
+ previous_box = self.observations[self.age - dt]
+ break
+ if previous_box is None:
+ previous_box = self.last_observation
+ """
+ Estimate the track speed direction with observations \Delta t steps away
+ """
+ self.velocity = speed_direction(previous_box, bbox)
+ """
+ Insert new observations. This is an ugly way to maintain both self.observations
+ and self.history_observations. Bear it for the moment.
+ """
+ self.last_observation = bbox
+ self.observations[self.age] = bbox
+ self.history_observations.append(bbox)
+
+ self.time_since_update = 0
+ self.history = []
+ self.hits += 1
+ self.hit_streak += 1
+ self.kf.update(convert_bbox_to_z(bbox))
+ else:
+ self.kf.update(bbox)
+
+ def predict(self):
+ """
+ Advances the state vector and returns the predicted bounding box estimate.
+ """
+ if ((self.kf.x[6] + self.kf.x[2]) <= 0):
+ self.kf.x[6] *= 0.0
+
+ self.kf.predict()
+ self.age += 1
+ if (self.time_since_update > 0):
+ self.hit_streak = 0
+ self.time_since_update += 1
+ self.history.append(convert_x_to_bbox(self.kf.x, score=self.score))
+ return self.history[-1]
+
+ def get_state(self):
+ return convert_x_to_bbox(self.kf.x, score=self.score)
+
+
+class OCSORTTracker(object):
+ """
+ OCSORT tracker; supports a single class
+
+ Args:
+ det_thresh (float): threshold of detection score
+ max_age (int): maximum number of missed misses before a track is deleted
+ min_hits (int): minimum hits for associate
+ iou_threshold (float): iou threshold for associate
+ delta_t (int): delta_t of previous observation
+ inertia (float): vdc_weight of angle_diff_cost for associate
+ vertical_ratio (float): w/h, the vertical ratio of the bbox used to
+ filter bad results. A value <= 0 disables this filter; 1.6 is
+ common for pedestrian tracking.
+ min_box_area (int): min box area to filter out low quality boxes
+ use_byte (bool): Whether use ByteTracker, default False
+ """
+
+ def __init__(self,
+ det_thresh=0.6,
+ max_age=30,
+ min_hits=3,
+ iou_threshold=0.3,
+ delta_t=3,
+ inertia=0.2,
+ vertical_ratio=-1,
+ min_box_area=0,
+ use_byte=False):
+ self.det_thresh = det_thresh
+ self.max_age = max_age
+ self.min_hits = min_hits
+ self.iou_threshold = iou_threshold
+ self.delta_t = delta_t
+ self.inertia = inertia
+ self.vertical_ratio = vertical_ratio
+ self.min_box_area = min_box_area
+ self.use_byte = use_byte
+
+ self.trackers = []
+ self.frame_count = 0
+ KalmanBoxTracker.count = 0
+
+ def update(self, pred_dets, pred_embs=None):
+ """
+ Args:
+ pred_dets (np.array): Detection results of the image, the shape is
+ [N, 6], means 'cls_id, score, x0, y0, x1, y1'.
+ pred_embs (np.array): Embedding results of the image, the shape is
+ [N, 128] or [N, 512], default as None.
+
+ Return:
+ tracking boxes (np.array): [M, 6], means 'x0, y0, x1, y1, score, id'.
+ """
+ if pred_dets is None:
+ return np.empty((0, 6))
+
+ self.frame_count += 1
+
+ bboxes = pred_dets[:, 2:]
+ scores = pred_dets[:, 1:2]
+ dets = np.concatenate((bboxes, scores), axis=1)
+ scores = scores.squeeze(-1)
+
+ inds_low = scores > 0.1
+ inds_high = scores < self.det_thresh
+ inds_second = np.logical_and(inds_low, inds_high)
+ # self.det_thresh > score > 0.1, for second matching
+ dets_second = dets[inds_second] # detections for second matching
+ remain_inds = scores > self.det_thresh
+ dets = dets[remain_inds]
+
+ # get predicted locations from existing trackers.
+ trks = np.zeros((len(self.trackers), 5))
+ to_del = []
+ ret = []
+ for t, trk in enumerate(trks):
+ pos = self.trackers[t].predict()[0]
+ trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
+ if np.any(np.isnan(pos)):
+ to_del.append(t)
+ trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
+ for t in reversed(to_del):
+ self.trackers.pop(t)
+
+ velocities = np.array([
+ trk.velocity if trk.velocity is not None else np.array((0, 0))
+ for trk in self.trackers
+ ])
+ last_boxes = np.array([trk.last_observation for trk in self.trackers])
+ k_observations = np.array([
+ k_previous_obs(trk.observations, trk.age, self.delta_t)
+ for trk in self.trackers
+ ])
+ """
+ First round of association
+ """
+ matched, unmatched_dets, unmatched_trks = associate(
+ dets, trks, self.iou_threshold, velocities, k_observations,
+ self.inertia)
+ for m in matched:
+ self.trackers[m[1]].update(dets[m[0], :])
+ """
+ Second round of association by OCR (observation-centric recovery)
+ """
+ # BYTE association
+ if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[
+ 0] > 0:
+ u_trks = trks[unmatched_trks]
+ iou_left = iou_batch(
+ dets_second,
+ u_trks) # iou between low score detections and unmatched tracks
+ iou_left = np.array(iou_left)
+ if iou_left.max() > self.iou_threshold:
+ """
+ NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
+ get a higher performance especially on MOT17/MOT20 datasets. But we keep it
+ uniform here for simplicity
+ """
+ matched_indices = linear_assignment(-iou_left)
+ to_remove_trk_indices = []
+ for m in matched_indices:
+ det_ind, trk_ind = m[0], unmatched_trks[m[1]]
+ if iou_left[m[0], m[1]] < self.iou_threshold:
+ continue
+ self.trackers[trk_ind].update(dets_second[det_ind, :])
+ to_remove_trk_indices.append(trk_ind)
+ unmatched_trks = np.setdiff1d(unmatched_trks,
+ np.array(to_remove_trk_indices))
+
+ if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
+ left_dets = dets[unmatched_dets]
+ left_trks = last_boxes[unmatched_trks]
+ iou_left = iou_batch(left_dets, left_trks)
+ iou_left = np.array(iou_left)
+ if iou_left.max() > self.iou_threshold:
+ """
+ NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
+ get a higher performance especially on MOT17/MOT20 datasets. But we keep it
+ uniform here for simplicity
+ """
+ rematched_indices = linear_assignment(-iou_left)
+ to_remove_det_indices = []
+ to_remove_trk_indices = []
+ for m in rematched_indices:
+ det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[
+ 1]]
+ if iou_left[m[0], m[1]] < self.iou_threshold:
+ continue
+ self.trackers[trk_ind].update(dets[det_ind, :])
+ to_remove_det_indices.append(det_ind)
+ to_remove_trk_indices.append(trk_ind)
+ unmatched_dets = np.setdiff1d(unmatched_dets,
+ np.array(to_remove_det_indices))
+ unmatched_trks = np.setdiff1d(unmatched_trks,
+ np.array(to_remove_trk_indices))
+
+ for m in unmatched_trks:
+ self.trackers[m].update(None)
+
+ # create and initialise new trackers for unmatched detections
+ for i in unmatched_dets:
+ trk = KalmanBoxTracker(dets[i, :], delta_t=self.delta_t)
+ self.trackers.append(trk)
+ i = len(self.trackers)
+ for trk in reversed(self.trackers):
+ if trk.last_observation.sum() < 0:
+ d = trk.get_state()[0]
+ else:
+ d = trk.last_observation # tlbr + score
+ if (trk.time_since_update < 1) and (
+ trk.hit_streak >= self.min_hits or
+ self.frame_count <= self.min_hits):
+ # +1 as MOT benchmark requires positive
+ ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1))
+ i -= 1
+ # remove dead tracklet
+ if (trk.time_since_update > self.max_age):
+ self.trackers.pop(i)
+ if (len(ret) > 0):
+ return np.concatenate(ret)
+ return np.empty((0, 6))
+
+ def tracking(self, pred_dets, pred_embs, output_keys):
+ online_targets = self.update(pred_dets, pred_embs)
+ tracking_bboxes, tracking_scores = [], []
+ tracking_ids, tracking_cls_ids = [], []
+ for t in online_targets:
+ x1, y1, x2, y2 = t[:4]
+ w, h = x2 - x1, y2 - y1
+ tscore = float(t[4])
+ tid = int(t[5])
+ if w * h <= self.min_box_area: continue
+ if self.vertical_ratio > 0 and w / h > self.vertical_ratio:
+ continue
+ if w * h > 0:
+ tracking_bboxes.append([x1, y1, x2, y2])
+ tracking_scores.append(tscore)
+ tracking_ids.append(tid)
+ # only supports 1 class now
+ tracking_cls_ids.append(0)
+ tracking_outs = {
+ output_keys[0]: tracking_bboxes,
+ output_keys[1]: tracking_scores,
+ output_keys[2]: tracking_ids,
+ output_keys[3]: tracking_cls_ids
+ }
+ return tracking_outs
diff --git a/paddlecv/ppcv/ops/general_data_obj.py b/paddlecv/ppcv/ops/general_data_obj.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f07101f48a3791d29461b2c693933b680f3d708
--- /dev/null
+++ b/paddlecv/ppcv/ops/general_data_obj.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+class GeneralDataObj(object):
+ def __init__(self, data):
+ assert isinstance(data, (dict, ))
+ self.data_dict = data
+
+ def get(self, key):
+ """
+ key can be one of [list, tuple, str]
+ """
+ if isinstance(key, (list, tuple)):
+ return [self.data_dict[k] for k in key]
+ elif isinstance(key, (str)):
+ return self.data_dict[key]
+ else:
+ assert False, f"key({key}) type must be in on of [list, tuple, str] but got {type(key)}"
+
+ def set(self, key, value):
+ """
+ key: str
+ value: an object
+ """
+ self.data_dict[key] = value
+
+ def keys(self, ):
+ """
+ get all keys of the data
+ """
+ return list(self.data_dict.keys())
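+
+
+# Usage sketch (illustrative; keys assumed):
+# obj = GeneralDataObj({'input.fn': 'demo.jpg', 'input.image': img})
+# obj.get('input.fn') # 'demo.jpg'
+# obj.get(['input.fn', 'input.image']) # ['demo.jpg', img]
+# obj.set('det.dt_bboxes', bboxes)
+# obj.keys() # ['input.fn', 'input.image', 'det.dt_bboxes']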
diff --git a/paddlecv/ppcv/ops/models/__init__.py b/paddlecv/ppcv/ops/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ae990b711991aab8e3ee445ffd235e1628100a2
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/__init__.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import classification
+from . import detection
+from . import keypoint
+from . import ocr
+from . import segmentation
+from . import feature_extraction
+
+from .classification import *
+from .feature_extraction import *
+from .detection import *
+from .keypoint import *
+from .segmentation import *
+from .ocr import *
+
+__all__ = classification.__all__ + detection.__all__ + keypoint.__all__
+__all__ += segmentation.__all__
+__all__ += ocr.__all__
+__all__ += feature_extraction.__all__
diff --git a/paddlecv/ppcv/ops/models/base.py b/paddlecv/ppcv/ops/models/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..779f03906ef62653187516bfbb1e78e605effe37
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/base.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import importlib
+import math
+import numpy as np
+import paddle
+from paddle.inference import Config
+from paddle.inference import create_predictor
+
+from ppcv.ops.base import BaseOp
+from ppcv.ops.predictor import PaddlePredictor
+from ppcv.utils.download import get_model_path
+
+
+class ModelBaseOp(BaseOp):
+ """
+ Base Operator, implement of prediction process
+ Args
+ """
+
+ def __init__(self, model_cfg, env_cfg):
+ super(ModelBaseOp, self).__init__(model_cfg, env_cfg)
+ param_path = get_model_path(model_cfg['param_path'])
+ model_path = get_model_path(model_cfg['model_path'])
+ env_cfg["batch_size"] = model_cfg.get("batch_size", 1)
+ delete_pass = model_cfg.get("delete_pass", [])
+ self.batch_size = env_cfg["batch_size"]
+ self.name = model_cfg["name"]
+ self.frame = -1
+ self.predictor = PaddlePredictor(param_path, model_path, env_cfg,
+ delete_pass)
+ self.input_names = self.predictor.get_input_names()
+
+ keys = self.get_output_keys()
+ self.output_keys = [self.name + '.' + key for key in keys]
+
+ @classmethod
+ def type(self):
+ return 'MODEL'
+
+ def preprocess(self, inputs):
+ raise NotImplementedError
+
+ def postprocess(self, inputs):
+ raise NotImplementedError
diff --git a/paddlecv/ppcv/ops/models/classification/__init__.py b/paddlecv/ppcv/ops/models/classification/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..848884c74308f22fab8e3f27b803a4837652e52f
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/classification/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import reduce
+import os
+import numpy as np
+import math
+import paddle
+
+import importlib
+
+from .inference import ClassificationOp
+
+__all__ = ['ClassificationOp']
diff --git a/paddlecv/ppcv/ops/models/classification/inference.py b/paddlecv/ppcv/ops/models/classification/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..863f293c9db6f6ae198022c580b6e3fe556e9592
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/classification/inference.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+from functools import reduce
+import os
+import numpy as np
+import math
+import paddle
+from ..base import ModelBaseOp
+
+from ppcv.ops.base import create_operators
+from ppcv.core.workspace import register
+
+from .preprocess import *
+from .postprocess import *
+
+
+@register
+class ClassificationOp(ModelBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(ClassificationOp, self).__init__(model_cfg, env_cfg)
+ mod = importlib.import_module(__name__)
+ self.preprocessor = create_operators(model_cfg["PreProcess"], mod)
+ self.postprocessor = create_operators(model_cfg["PostProcess"], mod)
+
+ @classmethod
+ def get_output_keys(cls):
+ return ["class_ids", "scores", "label_names"]
+
+ def preprocess(self, inputs):
+ outputs = inputs
+ for ops in self.preprocessor:
+ outputs = ops(outputs)
+ return outputs
+
+ def postprocess(self, inputs, result):
+ outputs = result
+ for idx, ops in enumerate(self.postprocessor):
+ if idx == len(self.postprocessor) - 1:
+ outputs = ops(outputs, self.output_keys)
+ else:
+ outputs = ops(outputs)
+ return outputs
+
+ def infer(self, image_list):
+ inputs = []
+ batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
+ results = []
+ for i in range(batch_loop_cnt):
+ start_index = i * self.batch_size
+ end_index = min((i + 1) * self.batch_size, len(image_list))
+ batch_image_list = image_list[start_index:end_index]
+ # preprocess
+ inputs = [self.preprocess(img) for img in batch_image_list]
+ inputs = np.concatenate(inputs, axis=0)
+ # model inference
+ result = self.predictor.run(inputs)[0]
+ # postprocess
+ result = self.postprocess(inputs, result)
+ results.extend(result)
+ # results = self.merge_batch_result(results)
+ return results
+
+ def __call__(self, inputs):
+ """
+ step1: parse inputs
+ step2: run
+ step3: merge results
+ input: a list of dict
+ """
+ key = self.input_keys[0]
+ is_list = False
+ if isinstance(inputs[0][key], (list, tuple)):
+ inputs = [input[key] for input in inputs]
+ is_list = True
+ else:
+ inputs = [[input[key]] for input in inputs]
+ sub_index_list = [len(input) for input in inputs]
+ inputs = reduce(lambda x, y: x.extend(y) or x, inputs)
+
+ # step2: run
+ outputs = self.infer(inputs)
+
+ # step3: merge
+ curr_offset_id = 0
+ pipe_outputs = []
+ for idx in range(len(sub_index_list)):
+ sub_start_idx = curr_offset_id
+ sub_end_idx = curr_offset_id + sub_index_list[idx]
+ output = outputs[sub_start_idx:sub_end_idx]
+ output = {k: [o[k] for o in output] for k in output[0]}
+ if not is_list:
+ output = {k: output[k][0] for k in output}
+ pipe_outputs.append(output)
+
+ curr_offset_id = sub_end_idx
+ return pipe_outputs
diff --git a/paddlecv/ppcv/ops/models/classification/postprocess.py b/paddlecv/ppcv/ops/models/classification/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecb9e225434759c18811532fc6dc9cb430c8c5ff
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/classification/postprocess.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+from ppcv.utils.download import get_dict_path
+
+from ppcv.utils.logger import setup_logger
+
+logger = setup_logger('Classification')
+
+
+class Topk(object):
+ def __init__(self, topk=1, class_id_map_file=None):
+ assert isinstance(topk, (int, ))
+ # parse_class_id_map resolves the dict path itself and handles None
+ self.class_id_map = self.parse_class_id_map(class_id_map_file)
+ self.topk = topk
+
+ def parse_class_id_map(self, class_id_map_file):
+ if class_id_map_file is None:
+ return None
+ file_path = get_dict_path(class_id_map_file)
+
+ try:
+ class_id_map = {}
+ with open(file_path, "r") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ partition = line.split("\n")[0].partition(" ")
+ class_id_map[int(partition[0])] = str(partition[-1])
+ except Exception as ex:
+ msg = f"Error encountered while loading the class_id_map_file. The related setting has been ignored. The detailed error info: {ex}"
+ logger.warning(msg)
+ class_id_map = None
+ return class_id_map
+
+ def __call__(self, x, output_keys):
+ y = []
+ for idx, probs in enumerate(x):
+ index = probs.argsort(axis=0)[-self.topk:][::-1].astype("int32")
+ clas_id_list = []
+ score_list = []
+ label_name_list = []
+ for i in index:
+ clas_id_list.append(i.item())
+ score_list.append(probs[i].item())
+ if self.class_id_map is not None:
+ label_name_list.append(self.class_id_map[i.item()])
+ result = {
+ output_keys[0]: clas_id_list,
+ output_keys[1]: np.around(
+ score_list, decimals=5).tolist(),
+ }
+ if label_name_list:
+ result[output_keys[2]] = label_name_list
+ y.append(result)
+ return y
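+
+
+# Usage sketch (illustrative): for a batch with one 3-class probability
+# vector and no class_id_map_file,
+# Topk(topk=2)(np.array([[0.1, 0.7, 0.2]]), ['class_ids', 'scores', 'label_names'])
+# returns [{'class_ids': [1, 2], 'scores': [0.7, 0.2]}].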
diff --git a/paddlecv/ppcv/ops/models/classification/preprocess.py b/paddlecv/ppcv/ops/models/classification/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae1562e9ce156dd627d035d58683ddf617c4fbd7
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/classification/preprocess.py
@@ -0,0 +1,343 @@
+"""
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from functools import partial
+import six
+import math
+import random
+import cv2
+import numpy as np
+import importlib
+from PIL import Image
+
+from ppcv.utils.logger import setup_logger
+
+logger = setup_logger('ClassificationPreprocess')
+
+__all__ = [
+ "UnifiedResize", "DecodeImage", "ResizeImage", "CropImage",
+ "RandCropImage", "RandFlipImage", "NormalizeImage", "ToCHWImage",
+ "ExpandDim"
+]
+
+
+class UnifiedResize(object):
+ def __init__(self, interpolation=None, backend="cv2", return_numpy=True):
+ _cv2_interp_from_str = {
+ 'nearest': cv2.INTER_NEAREST,
+ 'bilinear': cv2.INTER_LINEAR,
+ 'area': cv2.INTER_AREA,
+ 'bicubic': cv2.INTER_CUBIC,
+ 'lanczos': cv2.INTER_LANCZOS4,
+ 'random': (cv2.INTER_LINEAR, cv2.INTER_CUBIC)
+ }
+ _pil_interp_from_str = {
+ 'nearest': Image.NEAREST,
+ 'bilinear': Image.BILINEAR,
+ 'bicubic': Image.BICUBIC,
+ 'box': Image.BOX,
+ 'lanczos': Image.LANCZOS,
+ 'hamming': Image.HAMMING,
+ 'random': (Image.BILINEAR, Image.BICUBIC)
+ }
+
+ def _cv2_resize(src, size, resample):
+ if isinstance(resample, tuple):
+ resample = random.choice(resample)
+ return cv2.resize(src, size, interpolation=resample)
+
+ def _pil_resize(src, size, resample, return_numpy=True):
+ if isinstance(resample, tuple):
+ resample = random.choice(resample)
+ if isinstance(src, np.ndarray):
+ pil_img = Image.fromarray(src)
+ else:
+ pil_img = src
+ pil_img = pil_img.resize(size, resample)
+ if return_numpy:
+ return np.asarray(pil_img)
+ return pil_img
+
+ if backend.lower() == "cv2":
+ if isinstance(interpolation, str):
+ interpolation = _cv2_interp_from_str[interpolation.lower()]
+ # compatible with opencv < version 4.4.0
+ elif interpolation is None:
+ interpolation = cv2.INTER_LINEAR
+ self.resize_func = partial(_cv2_resize, resample=interpolation)
+ elif backend.lower() == "pil":
+ if isinstance(interpolation, str):
+ interpolation = _pil_interp_from_str[interpolation.lower()]
+ self.resize_func = partial(
+ _pil_resize, resample=interpolation, return_numpy=return_numpy)
+ else:
+ logger.warning(
+ f"The backend of Resize only supports \"cv2\" or \"PIL\". \"{backend}\" is unavailable. Use \"cv2\" instead."
+ )
+ self.resize_func = cv2.resize
+
+ def __call__(self, src, size):
+ if isinstance(size, list):
+ size = tuple(size)
+ return self.resize_func(src, size)
+
+
+class OperatorParamError(ValueError):
+ """ OperatorParamError
+ """
+ pass
+
+
+class DecodeImage(object):
+ """ decode image """
+
+ def __init__(self, to_rgb=True, to_np=False, channel_first=False):
+ self.to_rgb = to_rgb
+ self.to_np = to_np # to numpy
+ self.channel_first = channel_first # only enabled when to_np is True
+
+ def __call__(self, img):
+ if six.PY2:
+ assert type(img) is str and len(
+ img) > 0, "invalid input 'img' in DecodeImage"
+ else:
+ assert type(img) is bytes and len(
+ img) > 0, "invalid input 'img' in DecodeImage"
+ data = np.frombuffer(img, dtype='uint8')
+ img = cv2.imdecode(data, 1)
+ if self.to_rgb:
+ assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
+ img.shape)
+ img = img[:, :, ::-1]
+
+ if self.channel_first:
+ img = img.transpose((2, 0, 1))
+
+ return img
+
+
+class ResizeImage(object):
+ """ resize image """
+
+ def __init__(self,
+ size=None,
+ resize_short=None,
+ interpolation=None,
+ backend="cv2",
+ return_numpy=True):
+ if resize_short is not None and resize_short > 0:
+ self.resize_short = resize_short
+ self.w = None
+ self.h = None
+ elif size is not None:
+ self.resize_short = None
+ self.w = size if type(size) is int else size[0]
+ self.h = size if type(size) is int else size[1]
+ else:
+ raise OperatorParamError("invalid params for ReisizeImage for '\
+ 'both 'size' and 'resize_short' are None")
+
+ self._resize_func = UnifiedResize(
+ interpolation=interpolation,
+ backend=backend,
+ return_numpy=return_numpy)
+
+ def __call__(self, img):
+ if isinstance(img, np.ndarray):
+ # numpy input
+ img_h, img_w = img.shape[:2]
+ else:
+ # PIL image input
+ img_w, img_h = img.size
+
+ if self.resize_short is not None:
+ percent = float(self.resize_short) / min(img_w, img_h)
+ w = int(round(img_w * percent))
+ h = int(round(img_h * percent))
+ else:
+ w = self.w
+ h = self.h
+ return self._resize_func(img, (w, h))
+
+
+class CropImage(object):
+ """ crop image """
+
+ def __init__(self, size):
+ if type(size) is int:
+ self.size = (size, size)
+ else:
+ self.size = size # (h, w)
+
+ def __call__(self, img):
+ w, h = self.size
+ img_h, img_w = img.shape[:2]
+
+ if img_h < h or img_w < w:
+ raise Exception(
+ f"The size({h}, {w}) of CropImage must not be greater than the image size({img_h}, {img_w}). Please check the original image size and the size of ResizeImage if used."
+ )
+
+ w_start = (img_w - w) // 2
+ h_start = (img_h - h) // 2
+
+ w_end = w_start + w
+ h_end = h_start + h
+ return img[h_start:h_end, w_start:w_end, :]
+
+
+class RandCropImage(object):
+ """ random crop image """
+
+ def __init__(self,
+ size,
+ scale=None,
+ ratio=None,
+ interpolation=None,
+ backend="cv2"):
+ if type(size) is int:
+ self.size = (size, size) # (h, w)
+ else:
+ self.size = size
+
+ self.scale = [0.08, 1.0] if scale is None else scale
+ self.ratio = [3. / 4., 4. / 3.] if ratio is None else ratio
+
+ self._resize_func = UnifiedResize(
+ interpolation=interpolation, backend=backend)
+
+ def __call__(self, img):
+ size = self.size
+ scale = self.scale
+ ratio = self.ratio
+
+ aspect_ratio = math.sqrt(random.uniform(*ratio))
+ w = 1. * aspect_ratio
+ h = 1. / aspect_ratio
+
+ img_h, img_w = img.shape[:2]
+
+ bound = min((float(img_w) / img_h) / (w**2),
+ (float(img_h) / img_w) / (h**2))
+ scale_max = min(scale[1], bound)
+ scale_min = min(scale[0], bound)
+
+ target_area = img_w * img_h * random.uniform(scale_min, scale_max)
+ target_size = math.sqrt(target_area)
+ w = int(target_size * w)
+ h = int(target_size * h)
+
+ i = random.randint(0, img_w - w)
+ j = random.randint(0, img_h - h)
+
+ img = img[j:j + h, i:i + w, :]
+
+ return self._resize_func(img, size)
+
+
+class RandFlipImage(object):
+ """ random flip image
+ flip_code:
+ 1: Flipped Horizontally
+ 0: Flipped Vertically
+ -1: Flipped Horizontally & Vertically
+ """
+
+ def __init__(self, flip_code=1):
+ assert flip_code in [-1, 0, 1
+ ], "flip_code should be a value in [-1, 0, 1]"
+ self.flip_code = flip_code
+
+ def __call__(self, img):
+ if random.randint(0, 1) == 1:
+ return cv2.flip(img, self.flip_code)
+ else:
+ return img
+
+
+class NormalizeImage(object):
+ """ normalize image such as substract mean, divide std
+ """
+
+ def __init__(self,
+ scale=None,
+ mean=None,
+ std=None,
+ order='chw',
+ output_fp16=False,
+ channel_num=3):
+ if isinstance(scale, str):
+ scale = eval(scale)
+ assert channel_num in [
+ 3, 4
+ ], "channel number of input image should be set to 3 or 4."
+ self.channel_num = channel_num
+ self.output_dtype = 'float16' if output_fp16 else 'float32'
+ self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+ self.order = order
+ mean = mean if mean is not None else [0.485, 0.456, 0.406]
+ std = std if std is not None else [0.229, 0.224, 0.225]
+
+ shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3)
+ self.mean = np.array(mean).reshape(shape).astype('float32')
+ self.std = np.array(std).reshape(shape).astype('float32')
+
+ def __call__(self, img):
+ from PIL import Image
+ if isinstance(img, Image.Image):
+ img = np.array(img)
+
+ assert isinstance(img,
+ np.ndarray), "invalid input 'img' in NormalizeImage"
+
+ img = (img.astype('float32') * self.scale - self.mean) / self.std
+
+ if self.channel_num == 4:
+ img_h = img.shape[1] if self.order == 'chw' else img.shape[0]
+ img_w = img.shape[2] if self.order == 'chw' else img.shape[1]
+ pad_zeros = np.zeros(
+ (1, img_h, img_w)) if self.order == 'chw' else np.zeros(
+ (img_h, img_w, 1))
+ img = (np.concatenate(
+ (img, pad_zeros), axis=0)
+ if self.order == 'chw' else np.concatenate(
+ (img, pad_zeros), axis=2))
+ return img.astype(self.output_dtype)
+
+
+class ToCHWImage(object):
+ """ convert hwc image to chw image
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, img):
+ from PIL import Image
+ if isinstance(img, Image.Image):
+ img = np.array(img)
+
+ return img.transpose((2, 0, 1))
+
+
+class ExpandDim(object):
+ def __init__(self, axis=0):
+ self.axis = axis
+
+ def __call__(self, img):
+ img = np.expand_dims(img, axis=self.axis)
+ return img
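+
+
+# Usage sketch (illustrative): a typical inference-time chain built from the
+# ops above, mirroring a PreProcess config; img is an HWC uint8 numpy image.
+# ops = [ResizeImage(resize_short=256), CropImage(224),
+# NormalizeImage(), ToCHWImage(), ExpandDim(axis=0)]
+# for op in ops:
+# img = op(img) # final shape (1, 3, 224, 224), dtype float32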
diff --git a/paddlecv/ppcv/ops/models/detection/__init__.py b/paddlecv/ppcv/ops/models/detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..969d1f2a615b41ad5e52ac282c314ac7f63c22da
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/detection/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import reduce
+import os
+import numpy as np
+import math
+import paddle
+
+import importlib
+
+from .inference import DetectionOp
+
+__all__ = ['DetectionOp']
diff --git a/paddlecv/ppcv/ops/models/detection/inference.py b/paddlecv/ppcv/ops/models/detection/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2e987aec1410d7fca5cb5d7feb28fabbdb31fad
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/detection/inference.py
@@ -0,0 +1,159 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import reduce
+import os
+import importlib
+import numpy as np
+import math
+import paddle
+from ..base import ModelBaseOp
+from ppcv.ops.base import create_operators
+from ppcv.core.workspace import register
+from .preprocess import *
+from .postprocess import *
+
+
+@register
+class DetectionOp(ModelBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(DetectionOp, self).__init__(model_cfg, env_cfg)
+ self.model_cfg = model_cfg
+ mod = importlib.import_module(__name__)
+ self.preprocessor = create_operators(model_cfg["PreProcess"], mod)
+ self.postprocessor = create_operators(model_cfg["PostProcess"], mod)
+
+ @classmethod
+ def get_output_keys(cls):
+ return ["dt_bboxes", "dt_scores", "dt_class_ids", "dt_cls_names"]
+
+ def preprocess(self, image):
+ im_info = {
+ 'scale_factor': np.array(
+ [1., 1.], dtype=np.float32),
+ 'im_shape': np.array(
+ image.shape[:2], dtype=np.float32),
+ 'input_shape': self.model_cfg["image_shape"],
+ }
+ for ops in self.preprocessor:
+ image, im_info = ops(image, im_info)
+ return image, im_info
+
+ def postprocess(self, inputs, result, bbox_num):
+ outputs = result
+ for idx, ops in enumerate(self.postprocessor):
+ if idx == len(self.postprocessor) - 1:
+ outputs, bbox_num = ops(outputs, bbox_num, self.output_keys)
+ else:
+ outputs, bbox_num = ops(outputs, bbox_num)
+ return outputs, bbox_num
+
+ def create_inputs(self, imgs, im_info):
+ inputs = {}
+ im_shape = []
+ scale_factor = []
+ if len(imgs) == 1:
+ image = np.array((imgs[0], )).astype('float32')
+ im_shape = np.array((im_info[0]['im_shape'], )).astype('float32')
+ scale_factor = np.array(
+ (im_info[0]['scale_factor'], )).astype('float32')
+ inputs = dict(
+ im_shape=im_shape, image=image, scale_factor=scale_factor)
+ outputs = [inputs[key] for key in self.input_names]
+ return outputs
+
+ for e in im_info:
+ im_shape.append(np.array((e['im_shape'], )).astype('float32'))
+ scale_factor.append(
+ np.array((e['scale_factor'], )).astype('float32'))
+
+ inputs['im_shape'] = np.concatenate(im_shape, axis=0)
+ inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)
+
+ imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
+ max_shape_h = max([e[0] for e in imgs_shape])
+ max_shape_w = max([e[1] for e in imgs_shape])
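+        # zero-pad every image to the batch-wide max height/width so they can be stacked into one tensor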
+ padding_imgs = []
+ for img in imgs:
+ im_c, im_h, im_w = img.shape[:]
+ padding_im = np.zeros(
+ (im_c, max_shape_h, max_shape_w), dtype=np.float32)
+ padding_im[:, :im_h, :im_w] = img
+ padding_imgs.append(padding_im)
+ inputs['image'] = np.stack(padding_imgs, axis=0)
+ outputs = [inputs[key] for key in self.input_names]
+ return outputs
+
+ def infer(self, image_list):
+ inputs = []
+ batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
+ results = []
+ bbox_nums = []
+ for i in range(batch_loop_cnt):
+ start_index = i * self.batch_size
+ end_index = min((i + 1) * self.batch_size, len(image_list))
+ batch_image_list = image_list[start_index:end_index]
+ # preprocess
+ output_list = []
+ info_list = []
+ for img in batch_image_list:
+ output, info = self.preprocess(img)
+ output_list.append(output)
+ info_list.append(info)
+ inputs = self.create_inputs(output_list, info_list)
+
+ # model inference
+ result = self.predictor.run(inputs)
+ res = result[0]
+ bbox_num = result[1]
+ # postprocess
+ res, bbox_num = self.postprocess(inputs, res, bbox_num)
+ results.append(res)
+ bbox_nums.append(bbox_num)
+ # results = self.merge_batch_result(results)
+ return results, bbox_nums
+
+ def __call__(self, inputs):
+ """
+ step1: parser inputs
+ step2: run
+ step3: merge results
+ input: a list of dict
+ """
+ # for the input_keys as list
+ # inputs = [pipe_input[key] for pipe_input in pipe_inputs for key in self.input_keys]
+
+ key = self.input_keys[0]
+ if isinstance(inputs[0][key], (list, tuple)):
+ inputs = [input[key] for input in inputs]
+ else:
+ inputs = [[input[key]] for input in inputs]
+ sub_index_list = [len(input) for input in inputs]
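+        # flatten the nested per-sample lists; list.extend returns None, so "or x" keeps returning the accumulator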
+ inputs = reduce(lambda x, y: x.extend(y) or x, inputs)
+
+ # step2: run
+ outputs, bbox_nums = self.infer(inputs)
+
+ # step3: merge
+        curr_offset_id = 0
+ pipe_outputs = []
+ for i, bbox_num in enumerate(bbox_nums):
+ output = outputs[i]
+ start_id = 0
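+            # each image's detections occupy a contiguous [start_id, end_id) slice of the batched output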
+ for num in bbox_num:
+ end_id = start_id + num
+ out = {k: v[start_id:end_id] for k, v in output.items()}
+ pipe_outputs.append(out)
+ start_id = end_id
+ return pipe_outputs
\ No newline at end of file
diff --git a/paddlecv/ppcv/ops/models/detection/postprocess.py b/paddlecv/ppcv/ops/models/detection/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc7390768cc6876f38f6230b0afea6782e57def9
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/detection/postprocess.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+from scipy.special import softmax
+
+from ppcv.utils.download import get_dict_path
+
+
+class ParserDetResults(object):
+ def __init__(self,
+ label_list,
+ threshold=0.5,
+ max_det_results=100,
+ keep_cls_ids=None):
+ self.threshold = threshold
+ self.max_det_results = max_det_results
+ self.clsid2catid, self.catid2name = self.get_categories(label_list)
+ self.keep_cls_ids = keep_cls_ids if keep_cls_ids else list(
+ self.clsid2catid.keys())
+
+ def get_categories(self, label_list):
+ if isinstance(label_list, list):
+ clsid2catid = {i: i for i in range(len(label_list))}
+ catid2name = {i: label_list[i] for i in range(len(label_list))}
+ return clsid2catid, catid2name
+
+ label_list = get_dict_path(label_list)
+ if label_list.endswith('json'):
+ # lazy import pycocotools here
+ from pycocotools.coco import COCO
+ coco = COCO(label_list)
+ cats = coco.loadCats(coco.getCatIds())
+ clsid2catid = {i: cat['id'] for i, cat in enumerate(cats)}
+ catid2name = {cat['id']: cat['name'] for cat in cats}
+ elif label_list.endswith('txt'):
+ cats = []
+ with open(label_list) as f:
+ for line in f.readlines():
+ cats.append(line.strip())
+ if cats[0] == 'background': cats = cats[1:]
+
+ clsid2catid = {i: i for i in range(len(cats))}
+ catid2name = {i: name for i, name in enumerate(cats)}
+
+ else:
+ raise ValueError("label_list {} should be json or txt.".format(
+ label_list))
+ return clsid2catid, catid2name
+
+ def __call__(self, preds, bbox_num, output_keys):
+ start_id = 0
+ dt_bboxes = []
+ scores = []
+ class_ids = []
+ cls_names = []
+ new_bbox_num = []
+
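+        # bbox_num holds the detection count per image; slice the flat preds accordingly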
+ for num in bbox_num:
+ end_id = start_id + num
+ pred = preds[start_id:end_id]
+ start_id = end_id
+ max_det_results = min(self.max_det_results, pred.shape[0])
+ keep_indexes = pred[:, 1].argsort()[::-1][:max_det_results]
+
+ select_num = 0
+ for idx in keep_indexes:
+ single_res = pred[idx].tolist()
+ class_id = int(single_res[0])
+ score = single_res[1]
+ bbox = single_res[2:]
+ if score < self.threshold:
+ continue
+ if class_id not in self.keep_cls_ids:
+ continue
+ select_num += 1
+ dt_bboxes.append(bbox)
+ scores.append(score)
+ class_ids.append(class_id)
+ cls_names.append(self.catid2name[self.clsid2catid[class_id]])
+ new_bbox_num.append(select_num)
+ result = {
+ output_keys[0]: dt_bboxes,
+ output_keys[1]: scores,
+ output_keys[2]: class_ids,
+ output_keys[3]: cls_names,
+ }
+ new_bbox_num = np.array(new_bbox_num).astype('int32')
+ return result, new_bbox_num
\ No newline at end of file
diff --git a/paddlecv/ppcv/ops/models/detection/preprocess.py b/paddlecv/ppcv/ops/models/detection/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f9b7de8966988910bbab9e95b9b00cadf7ce8bc
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/detection/preprocess.py
@@ -0,0 +1,225 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+
+
+def decode_image(im_file, im_info):
+ """read rgb image
+ Args:
+ im_file (str|np.ndarray): input can be image path or np.ndarray
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ if isinstance(im_file, str):
+ with open(im_file, 'rb') as f:
+ im_read = f.read()
+ data = np.frombuffer(im_read, dtype='uint8')
+ im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
+ im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+ else:
+ im = im_file
+ im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
+ im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
+ return im, im_info
+
+
+class Resize(object):
+ """resize image by target_size and max_size
+ Args:
+ target_size (int): the target size of image
+ keep_ratio (bool): whether keep_ratio or not, default true
+ interp (int): method of resize
+ """
+
+ def __init__(
+ self,
+ target_size,
+ keep_ratio=True,
+ interp=cv2.INTER_LINEAR, ):
+ if isinstance(target_size, int):
+ target_size = [target_size, target_size]
+ self.target_size = target_size
+ self.keep_ratio = keep_ratio
+ self.interp = interp
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ assert len(self.target_size) == 2
+ assert self.target_size[0] > 0 and self.target_size[1] > 0
+ im_channel = im.shape[2]
+ im_scale_y, im_scale_x = self.generate_scale(im)
+ # set image_shape
+ im_info['input_shape'][1] = int(im_scale_y * im.shape[0])
+ im_info['input_shape'][2] = int(im_scale_x * im.shape[1])
+ im = cv2.resize(
+ im,
+ None,
+ None,
+ fx=im_scale_x,
+ fy=im_scale_y,
+ interpolation=self.interp)
+ im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
+ im_info['scale_factor'] = np.array(
+ [im_scale_y, im_scale_x]).astype('float32')
+ return im, im_info
+
+ def generate_scale(self, im):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ Returns:
+ im_scale_x: the resize ratio of X
+ im_scale_y: the resize ratio of Y
+ """
+ origin_shape = im.shape[:2]
+ im_c = im.shape[2]
+ if self.keep_ratio:
+ im_size_min = np.min(origin_shape)
+ im_size_max = np.max(origin_shape)
+ target_size_min = np.min(self.target_size)
+ target_size_max = np.max(self.target_size)
+ im_scale = float(target_size_min) / float(im_size_min)
+ if np.round(im_scale * im_size_max) > target_size_max:
+ im_scale = float(target_size_max) / float(im_size_max)
+ im_scale_x = im_scale
+ im_scale_y = im_scale
+ else:
+ resize_h, resize_w = self.target_size
+ im_scale_y = resize_h / float(origin_shape[0])
+ im_scale_x = resize_w / float(origin_shape[1])
+ return im_scale_y, im_scale_x
+
+
+class NormalizeImage(object):
+ """normalize image
+ Args:
+        mean (list): per-channel mean subtracted from the image
+        std (list): per-channel std the image is divided by
+        is_scale (bool): whether to scale the image by 1/255 first
+        norm_type (str): 'mean_std' applies the mean/std normalization; other values skip it
+ """
+
+ def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
+ self.mean = mean
+ self.std = std
+ self.is_scale = is_scale
+ self.norm_type = norm_type
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ im = im.astype(np.float32, copy=False)
+ if self.is_scale:
+ scale = 1.0 / 255.0
+ im *= scale
+
+ if self.norm_type == 'mean_std':
+ mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
+ std = np.array(self.std)[np.newaxis, np.newaxis, :]
+ im -= mean
+ im /= std
+
+ return im, im_info
+
+
+class Permute(object):
+ """permute image
+ Args:
+ to_bgr (bool): whether convert RGB to BGR
+ channel_first (bool): whether convert HWC to CHW
+ """
+
+ def __init__(self, ):
+ super().__init__()
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ im = im.transpose((2, 0, 1)).copy()
+ return im, im_info
+
+
+class PadStride(object):
+ """ padding image for model with FPN , instead PadBatch(pad_to_stride, pad_gt) in original config
+ Args:
+ stride (bool): model with FPN need image shape % stride == 0
+ """
+
+ def __init__(self, stride=0):
+ self.coarsest_stride = stride
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ coarsest_stride = self.coarsest_stride
+ if coarsest_stride <= 0:
+ return im, im_info
+ im_c, im_h, im_w = im.shape
+ pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
+ pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
+ padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
+ padding_im[:, :im_h, :im_w] = im
+ return padding_im, im_info
+
+
+class RGB2BGR(object):
+ """permute image
+ Args:
+ to_bgr (bool): whether convert RGB to BGR
+ channel_first (bool): whether convert HWC to CHW
+ """
+
+ def __init__(self, ):
+ super().__init__()
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ im = im[:, :, ::-1]
+ return im, im_info
diff --git a/paddlecv/ppcv/ops/models/feature_extraction/__init__.py b/paddlecv/ppcv/ops/models/feature_extraction/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..877cbd169051c7ea6378f2a8e71f072483309537
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/feature_extraction/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import reduce
+import os
+import numpy as np
+import math
+import paddle
+
+import importlib
+
+from .inference import FeatureExtractionOp
+
+__all__ = ['FeatureExtractionOp']
diff --git a/paddlecv/ppcv/ops/models/feature_extraction/inference.py b/paddlecv/ppcv/ops/models/feature_extraction/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..37f93b7cb9119b9a19c61a030ed04c9a2109cd99
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/feature_extraction/inference.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+from functools import reduce
+import os
+import numpy as np
+import math
+import paddle
+from ..base import ModelBaseOp
+
+from ppcv.ops.base import create_operators
+from ppcv.core.workspace import register
+
+from .preprocess import *
+from .postprocess import *
+
+
+@register
+class FeatureExtractionOp(ModelBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super().__init__(model_cfg, env_cfg)
+ mod = importlib.import_module(__name__)
+ self.preprocessor = create_operators(model_cfg["PreProcess"], mod)
+ self.postprocessor = create_operators(model_cfg["PostProcess"], mod)
+
+ @classmethod
+ def get_output_keys(cls):
+ return ["dt_bboxes", "feature", "rec_score", "rec_doc"]
+
+ def preprocess(self, inputs):
+ outputs = inputs
+ for ops in self.preprocessor:
+ outputs = ops(outputs)
+ return outputs
+
+ def postprocess(self, output_list, bbox_list):
+ assert len(output_list) == len(bbox_list)
+ if len(output_list) == 0:
+ return {k: None for k in self.output_keys}
+ output_dict = {
+ self.output_keys[0]: bbox_list,
+ self.output_keys[1]: output_list
+ }
+ for idx, ops in enumerate(self.postprocessor):
+ output_dict = ops(output_dict, self.output_keys)
+ return output_dict
+
+ def infer_img(self, input):
+ # predict the full input image
+ img = input[self.input_keys[0]]
+ h, w = img.shape[:2]
+ image_list = [img]
+ bbox_list = [[0, 0, w, h]]
+ # for cropped image from object detection
+ if len(self.input_keys) == 3:
+ image_list.extend(input[self.input_keys[1]])
+ bbox_list.extend(input[self.input_keys[2]])
+
+ batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
+ output_list = []
+ for i in range(batch_loop_cnt):
+ start_index = i * self.batch_size
+ end_index = min((i + 1) * self.batch_size, len(image_list))
+ batch_image_list = image_list[start_index:end_index]
+ # preprocess
+ inputs = [self.preprocess(img) for img in batch_image_list]
+ inputs = np.concatenate(inputs, axis=0)
+ # model inference
+ output = self.predictor.run(inputs)[0]
+ output_list.extend(output)
+ # postprocess
+ return self.postprocess(output_list, bbox_list)
+
+ def __call__(self, inputs):
+ outputs = []
+ for input in inputs:
+ output = self.infer_img(input)
+ outputs.append(output)
+ return outputs
diff --git a/paddlecv/ppcv/ops/models/feature_extraction/postprocess.py b/paddlecv/ppcv/ops/models/feature_extraction/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd687bbb257d356295caa635c6ef4855d4ffcaf2
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/feature_extraction/postprocess.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+import faiss
+import pickle
+
+from ppcv.utils.logger import setup_logger
+
+logger = setup_logger('FeatureExtraction')
+
+
+class NormalizeFeature(object):
+ def __init__(self):
+ super().__init__()
+
+ def __call__(self, outputs, output_keys=None):
+ features = outputs[output_keys[1]]
+ feas_norm = np.sqrt(np.sum(np.square(features), axis=1, keepdims=True))
+ features = np.divide(features, feas_norm)
+ outputs[output_keys[1]] = features
+ return outputs
+
+
+class Index(object):
+ def __init__(self,
+ index_method,
+ index_dir,
+ dist_type,
+ hamming_radius=None,
+ score_thres=None):
+ vector_path = os.path.join(index_dir, "vector.index")
+ id_map_path = os.path.join(index_dir, "id_map.pkl")
+ if not os.path.exists(vector_path) or not os.path.exists(id_map_path):
+ msg = "The directory \"index_dir\" must contain files \"vector.index\", and \"id_map.pkl\". Please check again!"
+ logger.error(msg)
+ raise Exception(msg)
+
+ if dist_type == "hamming":
+ self.searcher = faiss.read_index_binary(vector_path)
+ else:
+ self.searcher = faiss.read_index(vector_path)
+
+ with open(id_map_path, "rb") as fd:
+ self.id_map = pickle.load(fd)
+
+ self.dist_type = dist_type
+ self.hamming_radius = hamming_radius
+ self.score_thres = score_thres
+
+ def thresh_by_score(self, output_dict, scores):
+ threshed_outputs = {}
+ for key in output_dict:
+ threshed_outputs[key] = []
+ for idx, score in enumerate(scores):
+ if (self.dist_type == "hamming" and
+ score <= self.hamming_radius) or (
+ self.dist_type != "hamming" and
+ score >= self.score_thres):
+ for key in output_dict:
+ threshed_outputs[key].append(output_dict[key][idx])
+
+ return threshed_outputs
+
+ def __call__(self, outputs, output_keys):
+ features = outputs[output_keys[1]]
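+        # top-1 nearest-neighbor search in the gallery index for each query feature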
+ scores, doc_ids = self.searcher.search(features, 1)
+ docs = [self.id_map[id[0]].split()[1] for id in doc_ids]
+ outputs[output_keys[2]] = [score[0] for score in scores]
+ outputs[output_keys[3]] = docs
+
+ return self.thresh_by_score(outputs, scores)
+
+
+class NMS4Rec(object):
+ def __init__(self, thresh):
+ super().__init__()
+ self.thresh = thresh
+
+ def __call__(self, outputs, output_keys):
+ bbox_list = outputs[output_keys[0]]
+ x1 = np.array([bbox[0] for bbox in bbox_list])
+ y1 = np.array([bbox[1] for bbox in bbox_list])
+ x2 = np.array([bbox[2] for bbox in bbox_list])
+ y2 = np.array([bbox[3] for bbox in bbox_list])
+ scores = np.array(outputs[output_keys[2]])
+
+ filtered_outputs = {key: [] for key in output_keys}
+
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
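+        # greedy NMS: keep the highest-scoring box, drop boxes whose overlap with it exceeds thresh, repeat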
+ while order.size > 0:
+ i = order[0]
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+ inds = np.where(ovr <= self.thresh)[0]
+ order = order[inds + 1]
+
+ for key in output_keys:
+ filtered_outputs[key].append(outputs[key][i])
+
+ return filtered_outputs
diff --git a/paddlecv/ppcv/ops/models/feature_extraction/preprocess.py b/paddlecv/ppcv/ops/models/feature_extraction/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f44987750f7e7d06ca02c3e37eb142a55bc52cd
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/feature_extraction/preprocess.py
@@ -0,0 +1,17 @@
+"""
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from ..classification.preprocess import *
\ No newline at end of file
diff --git a/paddlecv/ppcv/ops/models/keypoint/__init__.py b/paddlecv/ppcv/ops/models/keypoint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd1712d967a9084865891a3b4010d65f59a73fc4
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/keypoint/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import reduce
+import os
+import numpy as np
+import math
+import paddle
+
+import importlib
+
+from .inference import KeypointOp
+
+__all__ = ['KeypointOp']
diff --git a/paddlecv/ppcv/ops/models/keypoint/inference.py b/paddlecv/ppcv/ops/models/keypoint/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..24e1cec9a176e10390c258c29f661a00409d1d2c
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/keypoint/inference.py
@@ -0,0 +1,146 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import reduce
+import os
+import importlib
+import numpy as np
+import math
+import paddle
+from ..base import ModelBaseOp
+
+from ppcv.ops.base import create_operators
+from ppcv.core.workspace import register
+
+from .preprocess import *
+from .postprocess import *
+
+
+@register
+class KeypointOp(ModelBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(KeypointOp, self).__init__(model_cfg, env_cfg)
+ self.model_cfg = model_cfg
+ mod = importlib.import_module(__name__)
+ self.preprocessor = create_operators(model_cfg["PreProcess"], mod)
+ self.postprocessor = create_operators(model_cfg["PostProcess"], mod)
+
+ @classmethod
+ def get_output_keys(cls):
+ return ["keypoints", "kpt_scores"]
+
+ def preprocess(self, image):
+ im_info = {
+ 'im_shape': np.array(
+ image.shape[:2], dtype=np.float32),
+ 'input_shape': self.model_cfg["image_shape"],
+ }
+ for ops in self.preprocessor:
+ image, im_info = ops(image, im_info)
+ return image, im_info
+
+ def postprocess(self, inputs, im_shape, result):
+ np_heatmap = result[0]
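+        # im_shape arrives as (h, w); reverse it to (w, h) before computing center and scale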
+ im_shape = im_shape[:, ::-1]
+ center = np.round(im_shape / 2.)
+ scale = im_shape / 200.
+ outputs = self.postprocessor[0](np_heatmap, center, scale,
+ self.output_keys)
+ return outputs
+
+ def create_inputs(self, imgs, im_info):
+ inputs = {}
+ inputs = np.stack(imgs, axis=0).astype('float32')
+ im_shape = []
+ for e in im_info:
+ im_shape.append(np.array((e['im_shape'])).astype('float32'))
+ im_shape = np.stack(im_shape, axis=0)
+ return inputs, im_shape
+
+ def infer(self, image_list, tl_points=None):
+ inputs = []
+ batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
+ results = []
+ for i in range(batch_loop_cnt):
+ start_index = i * self.batch_size
+ end_index = min((i + 1) * self.batch_size, len(image_list))
+ batch_image_list = image_list[start_index:end_index]
+ # preprocess
+ output_list = []
+ info_list = []
+ for img in batch_image_list:
+ output, info = self.preprocess(img)
+ output_list.append(output)
+ info_list.append(info)
+ inputs, im_shape = self.create_inputs(output_list, info_list)
+
+ # model inference
+ result = self.predictor.run(inputs)
+
+ # postprocess
+ res = self.postprocess(inputs, im_shape, result)
+ if tl_points:
+ res = self.translate_to_ori_images(
+ res, tl_points[start_index:end_index])
+ results.append(res)
+ return results
+
+ def translate_to_ori_images(self, results, tl_points):
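+        # shift keypoints from crop-local coordinates back to the full image via each crop's top-left point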
+ keypoints = []
+ for kpts, tl_pt in zip(results[self.output_keys[0]], tl_points):
+ kpts_np = np.array(kpts)
+ kpts_np[:, 0] += tl_pt[0]
+ kpts_np[:, 1] += tl_pt[1]
+ keypoints.append(kpts_np.tolist())
+ results[self.output_keys[0]] = keypoints
+ return results
+
+ def __call__(self, inputs):
+ """
+ step1: parser inputs
+ step2: run
+ step3: merge results
+ input: a list of dict
+ """
+ # for the input_keys as list
+ # inputs = [pipe_input[key] for pipe_input in pipe_inputs for key in self.input_keys]
+
+ # step1: for the input_keys as str
+ if len(self.input_keys) > 1:
+ tl_points = [input[self.input_keys[1]] for input in inputs]
+ tl_points = reduce(lambda x, y: x.extend(y) or x, tl_points)
+ else:
+ tl_points = None
+ key = self.input_keys[0]
+ if isinstance(inputs[0][key], (list, tuple)):
+ inputs = [input[key] for input in inputs]
+ else:
+ inputs = [[input[key]] for input in inputs]
+ sub_index_list = [len(input) for input in inputs]
+ inputs = reduce(lambda x, y: x.extend(y) or x, inputs)
+
+ # step2: run
+ outputs = self.infer(inputs, tl_points)
+
+ # step3: merge
+        curr_offset_id = 0
+ pipe_outputs = []
+ for idx in range(len(sub_index_list)):
+            sub_start_idx = curr_offset_id
+            sub_end_idx = curr_offset_id + sub_index_list[idx]
+ output = outputs[sub_start_idx:sub_end_idx]
+ output = {k: [o[k] for o in output] for k in output[0]}
+ pipe_outputs.append(output)
+            curr_offset_id = sub_end_idx
+ return pipe_outputs
diff --git a/paddlecv/ppcv/ops/models/keypoint/postprocess.py b/paddlecv/ppcv/ops/models/keypoint/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a3b44f09cc395042d0e701f449fdd26655f2a0e
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/keypoint/postprocess.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from scipy.optimize import linear_sum_assignment
+from collections import abc, defaultdict
+import cv2
+import numpy as np
+import math
+import paddle
+import paddle.nn as nn
+from .preprocess import get_affine_transform
+
+
+def warp_affine_joints(joints, mat):
+ """Apply affine transformation defined by the transform matrix on the
+ joints.
+
+ Args:
+        joints (np.ndarray[..., 2]): Original coordinates of the joints.
+        mat (np.ndarray[2, 3]): The affine transform matrix.
+
+    Returns:
+        joints (np.ndarray[..., 2]): Transformed coordinates of the joints.
+ """
+ joints = np.array(joints)
+ shape = joints.shape
+ joints = joints.reshape(-1, 2)
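+    # append a homogeneous 1 to each joint so the 2x3 affine matrix applies in a single dot product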
+ return np.dot(np.concatenate(
+ (joints, joints[:, 0:1] * 0 + 1), axis=1),
+ mat.T).reshape(shape)
+
+
+class HRNetPostProcess(object):
+ def __init__(self, use_dark=True):
+ self.use_dark = use_dark
+
+ def flip_back(self, output_flipped, matched_parts):
+ assert output_flipped.ndim == 4,\
+ 'output_flipped should be [batch_size, num_joints, height, width]'
+
+ output_flipped = output_flipped[:, :, :, ::-1]
+
+ for pair in matched_parts:
+ tmp = output_flipped[:, pair[0], :, :].copy()
+ output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
+ output_flipped[:, pair[1], :, :] = tmp
+
+ return output_flipped
+
+ def get_max_preds(self, heatmaps):
+ """get predictions from score maps
+
+ Args:
+ heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
+
+ Returns:
+ preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
+ maxvals: numpy.ndarray([batch_size, num_joints, 2]), the maximum confidence of the keypoints
+ """
+ assert isinstance(heatmaps,
+ np.ndarray), 'heatmaps should be numpy.ndarray'
+ assert heatmaps.ndim == 4, 'batch_images should be 4-ndim'
+
+ batch_size = heatmaps.shape[0]
+ num_joints = heatmaps.shape[1]
+ width = heatmaps.shape[3]
+ heatmaps_reshaped = heatmaps.reshape((batch_size, num_joints, -1))
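+        # argmax over the flattened heatmap; the flat index is converted back to (x, y) below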
+ idx = np.argmax(heatmaps_reshaped, 2)
+ maxvals = np.amax(heatmaps_reshaped, 2)
+
+ maxvals = maxvals.reshape((batch_size, num_joints, 1))
+ idx = idx.reshape((batch_size, num_joints, 1))
+
+ preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
+
+ preds[:, :, 0] = (preds[:, :, 0]) % width
+ preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)
+
+ pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
+ pred_mask = pred_mask.astype(np.float32)
+
+ preds *= pred_mask
+
+ return preds, maxvals
+
+ def gaussian_blur(self, heatmap, kernel):
+ border = (kernel - 1) // 2
+ batch_size = heatmap.shape[0]
+ num_joints = heatmap.shape[1]
+ height = heatmap.shape[2]
+ width = heatmap.shape[3]
+ for i in range(batch_size):
+ for j in range(num_joints):
+ origin_max = np.max(heatmap[i, j])
+ dr = np.zeros((height + 2 * border, width + 2 * border))
+ dr[border:-border, border:-border] = heatmap[i, j].copy()
+ dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
+ heatmap[i, j] = dr[border:-border, border:-border].copy()
+ heatmap[i, j] *= origin_max / np.max(heatmap[i, j])
+ return heatmap
+
+ def dark_parse(self, hm, coord):
+ heatmap_height = hm.shape[0]
+ heatmap_width = hm.shape[1]
+ px = int(coord[0])
+ py = int(coord[1])
+ if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2:
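+            # second-order Taylor expansion around the peak: offset = -Hessian^-1 * gradient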
+ dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1])
+ dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px])
+ dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2])
+ dxy = 0.25 * (hm[py+1][px+1] - hm[py-1][px+1] - hm[py+1][px-1] \
+ + hm[py-1][px-1])
+                dyy = 0.25 * (
+                    hm[py + 2][px] - 2 * hm[py][px] + hm[py - 2][px])
+ derivative = np.matrix([[dx], [dy]])
+ hessian = np.matrix([[dxx, dxy], [dxy, dyy]])
+ if dxx * dyy - dxy**2 != 0:
+ hessianinv = hessian.I
+ offset = -hessianinv * derivative
+ offset = np.squeeze(np.array(offset.T), axis=0)
+ coord += offset
+ return coord
+
+ def dark_postprocess(self, hm, coords, kernelsize):
+ """
+ refer to https://github.com/ilovepose/DarkPose/lib/core/inference.py
+
+ """
+ hm = self.gaussian_blur(hm, kernelsize)
+ hm = np.maximum(hm, 1e-10)
+ hm = np.log(hm)
+ for n in range(coords.shape[0]):
+ for p in range(coords.shape[1]):
+ coords[n, p] = self.dark_parse(hm[n][p], coords[n][p])
+ return coords
+
+ def get_final_preds(self, heatmaps, center, scale, kernelsize=3):
+ """the highest heatvalue location with a quarter offset in the
+ direction from the highest response to the second highest response.
+
+ Args:
+ heatmaps (numpy.ndarray): The predicted heatmaps
+ center (numpy.ndarray): The boxes center
+ scale (numpy.ndarray): The scale factor
+
+ Returns:
+ preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
+ maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
+ """
+
+ coords, maxvals = self.get_max_preds(heatmaps)
+
+ heatmap_height = heatmaps.shape[2]
+ heatmap_width = heatmaps.shape[3]
+
+ if self.use_dark:
+ coords = self.dark_postprocess(heatmaps, coords, kernelsize)
+ else:
+ for n in range(coords.shape[0]):
+ for p in range(coords.shape[1]):
+ hm = heatmaps[n][p]
+ px = int(math.floor(coords[n][p][0] + 0.5))
+ py = int(math.floor(coords[n][p][1] + 0.5))
+ if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
+ diff = np.array([
+ hm[py][px + 1] - hm[py][px - 1],
+ hm[py + 1][px] - hm[py - 1][px]
+ ])
+ coords[n][p] += np.sign(diff) * .25
+ preds = coords.copy()
+
+ # Transform back
+ for i in range(coords.shape[0]):
+ preds[i] = transform_preds(coords[i], center[i], scale[i],
+ [heatmap_width, heatmap_height])
+
+ return preds, maxvals
+
+ def __call__(self, output, center, scale, output_keys):
+ preds, maxvals = self.get_final_preds(output, center, scale)
+ keypoints = np.concatenate((preds, maxvals), axis=-1).tolist()
+ kpt_scores = np.mean(maxvals, axis=1).tolist()
+ return {output_keys[0]: keypoints, output_keys[1]: kpt_scores}
+
+
+def transform_preds(coords, center, scale, output_size):
+ target_coords = np.zeros(coords.shape)
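+    # invert the training-time affine transform to map heatmap coords back to the source image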
+ trans = get_affine_transform(center, scale * 200, 0, output_size, inv=1)
+ for p in range(coords.shape[0]):
+ target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
+ return target_coords
+
+
+def affine_transform(pt, t):
+ new_pt = np.array([pt[0], pt[1], 1.]).T
+ new_pt = np.dot(t, new_pt)
+ return new_pt[:2]
diff --git a/paddlecv/ppcv/ops/models/keypoint/preprocess.py b/paddlecv/ppcv/ops/models/keypoint/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec93a552cab6a84114fd59dd903cd925a7f24518
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/keypoint/preprocess.py
@@ -0,0 +1,202 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+this code is based on https://github.com/open-mmlab/mmpose/mmpose/core/post_processing/post_transforms.py
+"""
+import cv2
+import numpy as np
+
+
+def get_affine_transform(center,
+ input_size,
+ rot,
+ output_size,
+ shift=(0., 0.),
+ inv=False):
+ """Get the affine transform matrix, given the center/scale/rot/output_size.
+
+ Args:
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
+        input_size (np.ndarray[2, ] | float): Size of the input region
+            wrt [width, height].
+ rot (float): Rotation angle (degree).
+ output_size (np.ndarray[2, ]): Size of the destination heatmaps.
+ shift (0-100%): Shift translation ratio wrt the width/height.
+ Default (0., 0.).
+ inv (bool): Option to inverse the affine transform direction.
+ (inv=False: src->dst or inv=True: dst->src)
+
+ Returns:
+ np.ndarray: The transform matrix.
+ """
+ assert len(center) == 2
+ assert len(output_size) == 2
+ assert len(shift) == 2
+ if not isinstance(input_size, (np.ndarray, list)):
+ input_size = np.array([input_size, input_size], dtype=np.float32)
+ scale_tmp = input_size
+
+ shift = np.array(shift)
+ src_w = scale_tmp[0]
+ dst_w = output_size[0]
+ dst_h = output_size[1]
+
+ rot_rad = np.pi * rot / 180
+ src_dir = rotate_point([0., src_w * -0.5], rot_rad)
+ dst_dir = np.array([0., dst_w * -0.5])
+
+ src = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = center + scale_tmp * shift
+ src[1, :] = center + src_dir + scale_tmp * shift
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
+
+ dst = np.zeros((3, 2), dtype=np.float32)
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
+
+ if inv:
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+ else:
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+ return trans
+
+
+def rotate_point(pt, angle_rad):
+ """Rotate a point by an angle.
+
+ Args:
+ pt (list[float]): 2 dimensional point to be rotated
+ angle_rad (float): rotation angle by radian
+
+ Returns:
+ list[float]: Rotated point.
+ """
+ assert len(pt) == 2
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
+ new_x = pt[0] * cs - pt[1] * sn
+ new_y = pt[0] * sn + pt[1] * cs
+ rotated_pt = [new_x, new_y]
+
+ return rotated_pt
+
+
+def _get_3rd_point(a, b):
+ """To calculate the affine matrix, three pairs of points are required. This
+ function is used to get the 3rd point, given 2D points a & b.
+
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
+ anticlockwise, using b as the rotation center.
+
+ Args:
+ a (np.ndarray): point(x,y)
+ b (np.ndarray): point(x,y)
+
+ Returns:
+ np.ndarray: The 3rd point.
+ """
+ assert len(a) == 2
+ assert len(b) == 2
+ direction = a - b
+ third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
+
+ return third_pt
+
+
+class TopDownEvalAffine(object):
+ """apply affine transform to image and coords
+
+ Args:
+        trainsize (list): [w, h], the standard input size used in training
+        image (np.ndarray): the input image
+        im_info (dict): info of the input image
+
+    Returns:
+        image and im_info after the affine transform
+
+ """
+
+ def __init__(self, trainsize):
+ self.trainsize = trainsize
+
+ def __call__(self, image, im_info):
+ rot = 0
+ imshape = im_info['im_shape'][::-1]
+ center = im_info['center'] if 'center' in im_info else imshape / 2.
+ scale = im_info['scale'] if 'scale' in im_info else imshape
+ trans = get_affine_transform(center, scale, rot, self.trainsize)
+ image = cv2.warpAffine(
+ image,
+ trans, (int(self.trainsize[0]), int(self.trainsize[1])),
+ flags=cv2.INTER_LINEAR)
+
+ return image, im_info
+
+
+class NormalizeImage(object):
+ """normalize image
+ Args:
+        mean (list): per-channel mean subtracted from the image
+        std (list): per-channel std the image is divided by
+        is_scale (bool): whether to scale the image by 1/255 first
+ """
+
+ def __init__(self, mean, std, is_scale=True):
+ self.mean = mean
+ self.std = std
+ self.is_scale = is_scale
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ im = im.astype(np.float32, copy=False)
+ mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
+ std = np.array(self.std)[np.newaxis, np.newaxis, :]
+
+ if self.is_scale:
+ im = im / 255.0
+
+ im -= mean
+ im /= std
+ return im, im_info
+
+
+class Permute(object):
+ """permute image
+ Args:
+ to_bgr (bool): whether convert RGB to BGR
+ channel_first (bool): whether convert HWC to CHW
+ """
+
+ def __init__(self, ):
+ super().__init__()
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ im = im.transpose((2, 0, 1)).copy()
+ return im, im_info
diff --git a/paddlecv/ppcv/ops/models/ocr/__init__.py b/paddlecv/ppcv/ops/models/ocr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..acb84341c2f526d0dec86229261ca8b97dc030ca
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import ocr_db_detection
+from . import ocr_crnn_recognition
+from . import ocr_table_recognition
+from . import ocr_kie
+
+from .ocr_db_detection import *
+from .ocr_crnn_recognition import *
+from .ocr_table_recognition import *
+from .ocr_kie import *
+
+__all__ = ocr_db_detection.__all__
+__all__ += ocr_crnn_recognition.__all__
+__all__ += ocr_table_recognition.__all__
+__all__ += ocr_kie.__all__
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_crnn_recognition/__init__.py b/paddlecv/ppcv/ops/models/ocr/ocr_crnn_recognition/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaf5fab95f7180f4bc2e08fb5884f4456ed39c50
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_crnn_recognition/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import reduce
+import os
+import numpy as np
+import math
+import paddle
+
+import importlib
+
+from .inference import OcrCrnnRecOp
+
+__all__ = ['OcrCrnnRecOp']
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_crnn_recognition/inference.py b/paddlecv/ppcv/ops/models/ocr/ocr_crnn_recognition/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..43114ba7e63e348eb982ba8e5b6780b54b48a0a9
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_crnn_recognition/inference.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import reduce
+import importlib
+import os
+import numpy as np
+import math
+import paddle
+from collections import defaultdict
+
+from ppcv.ops.models.base import ModelBaseOp
+
+from ppcv.ops.base import create_operators
+from ppcv.core.workspace import register
+from ppcv.ops.models.ocr.ocr_db_detection.preprocess import RGB2BGR
+from .preprocess import *
+from .postprocess import *
+
+
+@register
+class OcrCrnnRecOp(ModelBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(OcrCrnnRecOp, self).__init__(model_cfg, env_cfg)
+ mod = importlib.import_module(__name__)
+ self.preprocessor = create_operators(model_cfg["PreProcess"], mod)
+ self.postprocessor = create_operators(model_cfg["PostProcess"], mod)
+ self.batch_size = model_cfg["batch_size"]
+ self.rec_image_shape = list(model_cfg["PreProcess"][-1].values())[0][
+ "rec_image_shape"]
+
+ @classmethod
+ def get_output_keys(cls):
+ return ["rec_text", "rec_score"]
+
+ def preprocess(self, inputs):
+ outputs = inputs
+ for ops in self.preprocessor:
+ outputs = ops(outputs)
+ return outputs
+
+ def postprocess(self, result):
+ outputs = result
+ for idx, ops in enumerate(self.postprocessor):
+ if idx == len(self.postprocessor) - 1:
+ outputs = ops(outputs, self.output_keys)
+ else:
+ outputs = ops(outputs)
+ return outputs
+
+ def infer(self, image_list):
+ width_list = [float(img.shape[1]) / img.shape[0] for img in image_list]
+ indices = np.argsort(np.array(width_list))
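+        # sort crops by aspect ratio so images inside a batch share a similar padded width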
+
+ inputs = []
+ results = [None] * len(image_list)
+ for beg_img_no in range(0, len(image_list), self.batch_size):
+ end_img_no = min(len(image_list), beg_img_no + self.batch_size)
+ imgC, imgH, imgW = self.rec_image_shape
+ max_wh_ratio = imgW / imgH
+
+ norm_img_batch = []
+ for ino in range(beg_img_no, end_img_no):
+ h, w = image_list[indices[ino]].shape[0:2]
+ wh_ratio = w * 1.0 / h
+ max_wh_ratio = max(max_wh_ratio, wh_ratio)
+
+ for ino in range(beg_img_no, end_img_no):
+ norm_img = self.preprocess({
+ 'image': image_list[indices[ino]],
+ 'max_wh_ratio': max_wh_ratio
+ })['image']
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
+
+ norm_img_batch = np.concatenate(norm_img_batch, axis=0)
+
+ # model inference
+ result = self.predictor.run(norm_img_batch)
+ # postprocess
+ result = self.postprocess(result)
+
+ for rno in range(len(result)):
+ results[indices[beg_img_no + rno]] = result[rno]
+ return results
+
+ def __call__(self, inputs):
+ """
+ step1: parser inputs
+ step2: run
+ step3: merge results
+ input: a list of dict
+ """
+ key = self.input_keys[0]
+ is_list = False
+ if isinstance(inputs[0][key], (list, tuple)):
+ inputs = [input[key] for input in inputs]
+ is_list = True
+ else:
+ inputs = [[input[key]] for input in inputs]
+        # expand one dim to match the [[image, image], [image, image]] format
+ expand_dim = False
+ if isinstance(inputs[0][0], np.ndarray):
+ inputs = [inputs]
+ expand_dim = True
+ pipe_outputs = []
+ for i, images in enumerate(inputs):
+ sub_index_list = [len(input) for input in images]
+ images = reduce(lambda x, y: x.extend(y) or x, images)
+
+ # step2: run
+ outputs = self.infer(images)
+ # step3: merge
+            curr_offset_id = 0
+ results = []
+ for idx in range(len(sub_index_list)):
+                sub_start_idx = curr_offset_id
+                sub_end_idx = curr_offset_id + sub_index_list[idx]
+ output = outputs[sub_start_idx:sub_end_idx]
+ if len(output) > 0:
+ output = {k: [o[k] for o in output] for k in output[0]}
+                    if not is_list:
+ output = {k: output[k][0] for k in output}
+ else:
+ output = {self.output_keys[0]: [], self.output_keys[1]: []}
+ results.append(output)
+
+                curr_offset_id = sub_end_idx
+ pipe_outputs.append(results)
+ if expand_dim:
+ pipe_outputs = pipe_outputs[0]
+ else:
+ outputs = []
+ for pipe_output in pipe_outputs:
+ d = defaultdict(list)
+ for item in pipe_output:
+ for k in self.output_keys:
+ d[k].append(item[k])
+ outputs.append(d)
+ pipe_outputs = outputs
+ return pipe_outputs
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_crnn_recognition/postprocess.py b/paddlecv/ppcv/ops/models/ocr/ocr_crnn_recognition/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1d13e7c706218a3769e5ec03d753804ecf539d1
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_crnn_recognition/postprocess.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import re
+
+from ppcv.utils.download import get_dict_path
+
+
+class BaseRecLabelDecode(object):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False):
+ self.reverse = False
+ self.character_str = []
+
+ if character_dict_path is None:
+ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
+ dict_character = list(self.character_str)
+ else:
+ character_dict_path = get_dict_path(character_dict_path)
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
+ self.character_str.append(line)
+ if use_space_char:
+ self.character_str.append(" ")
+ dict_character = list(self.character_str)
+ if 'arabic' in character_dict_path:
+ self.reverse = True
+
+ dict_character = self.add_special_char(dict_character)
+ self.dict = {}
+ for i, char in enumerate(dict_character):
+ self.dict[char] = i
+ self.character = dict_character
+
+ def pred_reverse(self, pred):
+ pred_re = []
+ c_current = ''
+ for c in pred:
+ if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
+ if c_current != '':
+ pred_re.append(c_current)
+ pred_re.append(c)
+ c_current = ''
+ else:
+ c_current += c
+ if c_current != '':
+ pred_re.append(c_current)
+
+ return ''.join(pred_re[::-1])
+
+ def add_special_char(self, dict_character):
+ return dict_character
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ ignored_tokens = self.get_ignored_tokens()
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ selection = np.ones(len(text_index[batch_idx]), dtype=bool)
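+            # CTC collapse: drop repeated consecutive indices, then remove blank/ignored tokens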
+ if is_remove_duplicate:
+ selection[1:] = text_index[batch_idx][1:] != text_index[
+ batch_idx][:-1]
+ for ignored_token in ignored_tokens:
+ selection &= text_index[batch_idx] != ignored_token
+
+ char_list = [
+ self.character[text_id]
+ for text_id in text_index[batch_idx][selection]
+ ]
+ if text_prob is not None:
+ conf_list = text_prob[batch_idx][selection]
+ else:
+ conf_list = [1] * len(selection)
+ if len(conf_list) == 0:
+ conf_list = [0]
+
+ text = ''.join(char_list)
+
+ if self.reverse: # for arabic rec
+ text = self.pred_reverse(text)
+
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
+
+ def get_ignored_tokens(self):
+ return [0] # for ctc blank
+
+
+class CTCLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=True,
+ **kwargs):
+ super().__init__(character_dict_path, use_space_char)
+
+ def __call__(self, preds, output_keys, *args, **kwargs):
+        if isinstance(preds, (tuple, list)):
+ preds = preds[-1]
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
+ result = [{
+ output_keys[0]: t[0],
+ output_keys[1]: t[1],
+ } for t in text]
+ return result
+
+ def add_special_char(self, dict_character):
+ dict_character = ['blank'] + dict_character
+ return dict_character
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_crnn_recognition/preprocess.py b/paddlecv/ppcv/ops/models/ocr/ocr_crnn_recognition/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..1145096d1454ca9920b60ac0cc26f7d4001e1870
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_crnn_recognition/preprocess.py
@@ -0,0 +1,58 @@
+"""
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import math
+import cv2
+import numpy as np
+from PIL import Image
+
+
+class ReisizeNormImg(object):
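+    """Resize a text-line crop to the recognizer input shape, normalize it
+    to [-1, 1] and right-pad the width with zeros."""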
+ def __init__(self, rec_image_shape=[3, 48, 320]):
+ super().__init__()
+ self.rec_image_shape = rec_image_shape
+
+ def resize_norm_img(self, img, max_wh_ratio):
+ imgC, imgH, imgW = self.rec_image_shape
+ assert imgC == img.shape[2]
+ imgW = int((imgH * max_wh_ratio))
+
+ h, w = img.shape[:2]
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ resized_image = resized_image.astype('float32')
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+ padding_im[:, :, 0:resized_w] = resized_image
+ return padding_im
+
+ def __call__(self, data):
+ """
+ """
+ data['image'] = self.resize_norm_img(data['image'],
+ data['max_wh_ratio'])
+ return data
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_db_detection/__init__.py b/paddlecv/ppcv/ops/models/ocr/ocr_db_detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2beef384eebc482f814a301f0c423be036cf25d1
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_db_detection/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .inference import OcrDbDetOp
+
+__all__ = ['OcrDbDetOp']
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_db_detection/inference.py b/paddlecv/ppcv/ops/models/ocr/ocr_db_detection/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..3205292a036db584d1bbe6f788959e1da4895018
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_db_detection/inference.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import reduce
+import importlib
+import math
+
+import numpy as np
+from ppcv.ops.models.base import ModelBaseOp
+
+from ppcv.ops.base import create_operators
+from ppcv.core.workspace import register
+
+from .preprocess import *
+from .postprocess import *
+
+
+@register
+class OcrDbDetOp(ModelBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(OcrDbDetOp, self).__init__(model_cfg, env_cfg)
+ mod = importlib.import_module(__name__)
+ self.preprocessor = create_operators(model_cfg["PreProcess"], mod)
+ self.postprocessor = create_operators(model_cfg["PostProcess"], mod)
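+        # each image keeps its own resized shape, so DB inference runs
+        # with a fixed batch size of 1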
+ self.batch_size = 1
+
+ @classmethod
+ def get_output_keys(cls):
+ return ["dt_polys", "dt_scores"]
+
+ def preprocess(self, inputs):
+ outputs = inputs
+ for ops in self.preprocessor:
+ outputs = ops(outputs)
+ return outputs
+
+ def postprocess(self, result, shape_list):
+ outputs = result
+ for idx, ops in enumerate(self.postprocessor):
+ if idx == len(self.postprocessor) - 1:
+ outputs = ops(outputs, shape_list, self.output_keys)
+ else:
+ outputs = ops(outputs, shape_list)
+ return outputs
+
+ def infer(self, image_list):
+ batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
+ results = []
+ for i in range(batch_loop_cnt):
+ start_index = i * self.batch_size
+ end_index = min((i + 1) * self.batch_size, len(image_list))
+ batch_image_list = image_list[start_index:end_index]
+ # preprocess
+ inputs, shape_list = self.preprocess({
+ "image": batch_image_list[0]
+ })
+ shape_list = np.expand_dims(shape_list, axis=0)
+ # model inference
+ result = self.predictor.run(inputs)[0]
+ # postprocess
+ result = self.postprocess(result, shape_list)
+ results.append(result)
+ return results
+
+ def __call__(self, inputs):
+ """
+        step 1: parse inputs
+        step 2: run model inference
+        step 3: merge results
+        input: a list of dict
+ """
+ key = self.input_keys[0]
+ is_list = False
+ if isinstance(inputs[0][key], (list, tuple)):
+ inputs = [input[key] for input in inputs]
+ is_list = True
+ else:
+ inputs = [[input[key]] for input in inputs]
+ sub_index_list = [len(input) for input in inputs]
+ inputs = reduce(lambda x, y: x.extend(y) or x, inputs)
+
+ # step2: run
+ outputs = self.infer(inputs)
+
+ # step3: merge
+        curr_offset_id = 0
+ pipe_outputs = []
+ for idx in range(len(sub_index_list)):
+            sub_start_idx = curr_offset_id
+            sub_end_idx = curr_offset_id + sub_index_list[idx]
+ output = outputs[sub_start_idx:sub_end_idx]
+ output = {k: [o[k] for o in output] for k in output[0]}
+ if is_list is not True:
+ output = {k: output[k][0] for k in output}
+
+ pipe_outputs.append(output)
+
+            curr_offset_id = sub_end_idx
+ return pipe_outputs
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_db_detection/postprocess.py b/paddlecv/ppcv/ops/models/ocr/ocr_db_detection/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee1a7395c2404b1f8e07ec221f00be74eb3818b0
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_db_detection/postprocess.py
@@ -0,0 +1,245 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import cv2
+import numpy as np
+from shapely.geometry import Polygon
+import pyclipper
+
+
+class DBPostProcess(object):
+ """
+ The post process for Differentiable Binarization (DB).
+ """
+
+ def __init__(self,
+ thresh=0.3,
+ box_thresh=0.7,
+ max_candidates=1000,
+ unclip_ratio=2.0,
+ use_dilation=False,
+ score_mode="fast",
+ box_type='quad',
+ **kwargs):
+ self.thresh = thresh
+ self.box_thresh = box_thresh
+ self.max_candidates = max_candidates
+ self.unclip_ratio = unclip_ratio
+ self.min_size = 3
+ self.score_mode = score_mode
+ self.box_type = box_type
+ assert score_mode in [
+ "slow", "fast"
+ ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
+
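+        # optional 2x2 dilation slightly enlarges the binarized text mask
+        # before contours are extracted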
+ self.dilation_kernel = None if not use_dilation else np.array(
+ [[1, 1], [1, 1]])
+
+ def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+ '''
+ _bitmap: single map with shape (1, H, W),
+ whose values are binarized as {0, 1}
+ '''
+
+ bitmap = _bitmap
+ height, width = bitmap.shape
+
+ boxes = []
+ scores = []
+
+ contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
+ cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+
+ for contour in contours[:self.max_candidates]:
+ epsilon = 0.002 * cv2.arcLength(contour, True)
+ approx = cv2.approxPolyDP(contour, epsilon, True)
+ points = approx.reshape((-1, 2))
+ if points.shape[0] < 4:
+ continue
+
+ score = self.box_score_fast(pred, points.reshape(-1, 2))
+ if self.box_thresh > score:
+ continue
+
+ if points.shape[0] > 2:
+ box = self.unclip(points, self.unclip_ratio)
+ if len(box) > 1:
+ continue
+ else:
+ continue
+ box = box.reshape(-1, 2)
+
+ _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
+ if sside < self.min_size + 2:
+ continue
+
+ box = np.array(box)
+ box[:, 0] = np.clip(
+ np.round(box[:, 0] / width * dest_width), 0, dest_width)
+ box[:, 1] = np.clip(
+ np.round(box[:, 1] / height * dest_height), 0, dest_height)
+ boxes.append(box.tolist())
+ scores.append(score)
+ return boxes, scores
+
+ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+ '''
+ _bitmap: single map with shape (1, H, W),
+ whose values are binarized as {0, 1}
+ '''
+
+ bitmap = _bitmap
+ height, width = bitmap.shape
+
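+        # cv2.findContours returns (image, contours, hierarchy) in OpenCV 3.x
+        # and (contours, hierarchy) in OpenCV 4.x; handle both layouts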
+ outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
+ cv2.CHAIN_APPROX_SIMPLE)
+ if len(outs) == 3:
+ img, contours, _ = outs[0], outs[1], outs[2]
+ elif len(outs) == 2:
+ contours, _ = outs[0], outs[1]
+
+ num_contours = min(len(contours), self.max_candidates)
+
+ boxes = []
+ scores = []
+ for index in range(num_contours):
+ contour = contours[index]
+ points, sside = self.get_mini_boxes(contour)
+ if sside < self.min_size:
+ continue
+ points = np.array(points)
+ if self.score_mode == "fast":
+ score = self.box_score_fast(pred, points.reshape(-1, 2))
+ else:
+ score = self.box_score_slow(pred, contour)
+ if self.box_thresh > score:
+ continue
+
+ box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
+ box, sside = self.get_mini_boxes(box)
+ if sside < self.min_size + 2:
+ continue
+ box = np.array(box)
+
+ box[:, 0] = np.clip(
+ np.round(box[:, 0] / width * dest_width), 0, dest_width)
+ box[:, 1] = np.clip(
+ np.round(box[:, 1] / height * dest_height), 0, dest_height)
+ boxes.append(box.astype(np.int16))
+ scores.append(score)
+ return np.array(boxes, dtype=np.int16), scores
+
+ def unclip(self, box, unclip_ratio):
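+        # expand the shrunk text polygon by an offset distance of
+        # area * unclip_ratio / perimeter, inverting the shrink used
+        # when training DB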
+ poly = Polygon(box)
+ distance = poly.area * unclip_ratio / poly.length
+ offset = pyclipper.PyclipperOffset()
+ offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+ expanded = np.array(offset.Execute(distance))
+ return expanded
+
+ def get_mini_boxes(self, contour):
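+        # order the minimum-area rectangle corners clockwise starting from
+        # the top-left point, and return the rectangle's shorter side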
+ bounding_box = cv2.minAreaRect(contour)
+ points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+ index_1, index_2, index_3, index_4 = 0, 1, 2, 3
+ if points[1][1] > points[0][1]:
+ index_1 = 0
+ index_4 = 1
+ else:
+ index_1 = 1
+ index_4 = 0
+ if points[3][1] > points[2][1]:
+ index_2 = 2
+ index_3 = 3
+ else:
+ index_2 = 3
+ index_3 = 2
+
+ box = [
+ points[index_1], points[index_2], points[index_3], points[index_4]
+ ]
+ return box, min(bounding_box[1])
+
+ def box_score_fast(self, bitmap, _box):
+ '''
+        box_score_fast: use the mean score inside the box's axis-aligned bounding rectangle as the box score
+ '''
+ h, w = bitmap.shape[:2]
+ box = _box.copy()
+ xmin = np.clip(np.floor(box[:, 0].min()).astype("int"), 0, w - 1)
+ xmax = np.clip(np.ceil(box[:, 0].max()).astype("int"), 0, w - 1)
+ ymin = np.clip(np.floor(box[:, 1].min()).astype("int"), 0, h - 1)
+ ymax = np.clip(np.ceil(box[:, 1].max()).astype("int"), 0, h - 1)
+
+ mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+ box[:, 0] = box[:, 0] - xmin
+ box[:, 1] = box[:, 1] - ymin
+ cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
+ return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+
+ def box_score_slow(self, bitmap, contour):
+ '''
+        box_score_slow: use the mean score inside the exact polygon as the box score (slower, more precise)
+ '''
+ h, w = bitmap.shape[:2]
+ contour = contour.copy()
+ contour = np.reshape(contour, (-1, 2))
+
+ xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
+ xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
+ ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
+ ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
+
+ mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+
+ contour[:, 0] = contour[:, 0] - xmin
+ contour[:, 1] = contour[:, 1] - ymin
+
+ cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
+ return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+
+ def __call__(self, pred, shape_list, output_keys):
+ pred = pred[:, 0, :, :]
+ segmentation = pred > self.thresh
+
+ boxes_batch = []
+ for batch_index in range(pred.shape[0]):
+ src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
+ if self.dilation_kernel is not None:
+ mask = cv2.dilate(
+ np.array(segmentation[batch_index]).astype(np.uint8),
+ self.dilation_kernel)
+ else:
+ mask = segmentation[batch_index]
+ if self.box_type == 'poly':
+ boxes, scores = self.polygons_from_bitmap(pred[batch_index],
+ mask, src_w, src_h)
+ elif self.box_type == 'quad':
+ boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
+ src_w, src_h)
+ else:
+ raise ValueError(
+ "box_type can only be one of ['quad', 'poly']")
+
+ boxes_batch.append({
+ output_keys[0]: boxes,
+ output_keys[1]: scores,
+ })
+
+ return boxes_batch[0]
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_db_detection/preprocess.py b/paddlecv/ppcv/ops/models/ocr/ocr_db_detection/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..b78fd32856ff39c8aa4e97a0cc7481010c8520db
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_db_detection/preprocess.py
@@ -0,0 +1,224 @@
+"""
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import math
+import sys
+
+import cv2
+import numpy as np
+from PIL import Image
+
+
+class DetResizeForTest(object):
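+    """Resize an image for detection inference; the strategy is picked from
+    the config keys (image_shape / limit_side_len / resize_long)."""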
+ def __init__(self, **kwargs):
+ super(DetResizeForTest, self).__init__()
+ self.resize_type = 0
+ self.keep_ratio = False
+ if 'image_shape' in kwargs:
+ self.image_shape = kwargs['image_shape']
+ self.resize_type = 1
+ if 'keep_ratio' in kwargs:
+ self.keep_ratio = kwargs['keep_ratio']
+ elif 'limit_side_len' in kwargs:
+ self.limit_side_len = kwargs['limit_side_len']
+ self.limit_type = kwargs.get('limit_type', 'min')
+ elif 'resize_long' in kwargs:
+ self.resize_type = 2
+ self.resize_long = kwargs.get('resize_long', 960)
+ else:
+ self.limit_side_len = 736
+ self.limit_type = 'min'
+ self.resize_type = 1
+
+ def __call__(self, data):
+ img = data['image']
+ src_h, src_w, _ = img.shape
+ if sum([src_h, src_w]) < 64:
+ img = self.image_padding(img)
+
+        func = getattr(self, f"resize_image_type{self.resize_type}")
+        img, [ratio_h, ratio_w] = func(img)
+ data['image'] = img
+ data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+ return data
+
+ def image_padding(self, im, value=0):
+ h, w, c = im.shape
+ im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
+ im_pad[:h, :w, :] = im
+ return im_pad
+
+ def resize_image_type1(self, img):
+ resize_h, resize_w = self.image_shape
+ ori_h, ori_w = img.shape[:2] # (h, w, c)
+ if self.keep_ratio is True:
+ resize_w = ori_w * resize_h / ori_h
+ N = math.ceil(resize_w / 32)
+ resize_w = N * 32
+ ratio_h = float(resize_h) / ori_h
+ ratio_w = float(resize_w) / ori_w
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        return img, [ratio_h, ratio_w]
+
+ def resize_image_type0(self, img):
+ """
+ resize image to a size multiple of 32 which is required by the network
+ args:
+ img(array): array with shape [h, w, c]
+ return(tuple):
+ img, (ratio_h, ratio_w)
+ """
+ limit_side_len = self.limit_side_len
+ h, w, c = img.shape
+
+ # limit the max side
+ if self.limit_type == 'max':
+ if max(h, w) > limit_side_len:
+ if h > w:
+ ratio = float(limit_side_len) / h
+ else:
+ ratio = float(limit_side_len) / w
+ else:
+ ratio = 1.
+ elif self.limit_type == 'min':
+ if min(h, w) < limit_side_len:
+ if h < w:
+ ratio = float(limit_side_len) / h
+ else:
+ ratio = float(limit_side_len) / w
+ else:
+ ratio = 1.
+ elif self.limit_type == 'resize_long':
+ ratio = float(limit_side_len) / max(h, w)
+ else:
+            raise ValueError('unsupported limit_type: {}'.format(
+                self.limit_type))
+ resize_h = int(h * ratio)
+ resize_w = int(w * ratio)
+
+ resize_h = max(int(round(resize_h / 32) * 32), 32)
+ resize_w = max(int(round(resize_w / 32) * 32), 32)
+
+ try:
+ if int(resize_w) <= 0 or int(resize_h) <= 0:
+ return None, (None, None)
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        except Exception:
+ print(img.shape, resize_w, resize_h)
+ sys.exit(0)
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+ return img, [ratio_h, ratio_w]
+
+ def resize_image_type2(self, img):
+ h, w, _ = img.shape
+
+ resize_w = w
+ resize_h = h
+
+ if resize_h > resize_w:
+ ratio = float(self.resize_long) / resize_h
+ else:
+ ratio = float(self.resize_long) / resize_w
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+
+ return img, [ratio_h, ratio_w]
+
+
+class NormalizeImage(object):
+ """ normalize image such as substract mean, divide std
+ """
+
+ def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
+ if isinstance(scale, str):
+ scale = eval(scale)
+ self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+ mean = mean if mean is not None else [0.485, 0.456, 0.406]
+ std = std if std is not None else [0.229, 0.224, 0.225]
+
+ shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+ self.mean = np.array(mean).reshape(shape).astype('float32')
+ self.std = np.array(std).reshape(shape).astype('float32')
+
+ def __call__(self, data):
+ img = data['image']
+ if isinstance(img, Image.Image):
+ img = np.array(img)
+ assert isinstance(img,
+ np.ndarray), "invalid input 'img' in NormalizeImage"
+ data['image'] = (
+ img.astype('float32') * self.scale - self.mean) / self.std
+ return data
+
+
+class ToCHWImage(object):
+ """ convert hwc image to chw image
+ """
+
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, data):
+ img = data["image"]
+ if isinstance(img, Image.Image):
+ img = np.array(img)
+ img = img.transpose((2, 0, 1))
+ data["image"] = img
+ return data
+
+
+class KeepKeys(object):
+ def __init__(self, keep_keys, **kwargs):
+ self.keep_keys = keep_keys
+
+ def __call__(self, data):
+ data_list = []
+ for key in self.keep_keys:
+ data_list.append(data[key])
+ return data_list
+
+
+class ExpandDim(object):
+ def __init__(self, axis=0):
+ self.axis = axis
+
+ def __call__(self, data):
+ data["image"] = np.expand_dims(data["image"], axis=self.axis)
+ return data
+
+
+class RGB2BGR(object):
+ def __call__(self, data):
+ data["image"] = data["image"][:, :, ::-1]
+ return data
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_kie/__init__.py b/paddlecv/ppcv/ops/models/ocr/ocr_kie/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..144d2ad1466705b86c85008ed4ce25c9507608bd
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_kie/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .inference import PPStructureKieSerOp, PPStructureKieReOp
+
+__all__ = ['PPStructureKieSerOp', 'PPStructureKieReOp']
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_kie/inference.py b/paddlecv/ppcv/ops/models/ocr/ocr_kie/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..27079cd376fef6447da65afa711790fb6cac59ab
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_kie/inference.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+import importlib
+import math
+
+import numpy as np
+
+from ppcv.ops.base import create_operators
+from ppcv.core.workspace import register
+from ppcv.ops.models.base import ModelBaseOp
+
+from ppcv.ops.models.classification.preprocess import ResizeImage
+from ppcv.ops.models.ocr.ocr_db_detection.preprocess import NormalizeImage, ToCHWImage, KeepKeys
+from ppcv.ops.models.ocr.ocr_kie.preprocess import *
+from ppcv.ops.models.ocr.ocr_kie.postprocess import *
+
+
+@register
+class PPStructureKieSerOp(ModelBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(PPStructureKieSerOp, self).__init__(model_cfg, env_cfg)
+ mod = importlib.import_module(__name__)
+ self.preprocessor = create_operators(model_cfg["PreProcess"], mod)
+ self.postprocessor = create_operators(model_cfg["PostProcess"], mod)
+ self.batch_size = model_cfg["batch_size"]
+ self.use_visual_backbone = model_cfg.get('use_visual_backbone', False)
+
+ @classmethod
+ def get_output_keys(cls):
+ return ["pred_id", "pred", "dt_polys", "rec_text", "inputs"]
+
+ def preprocess(self, inputs):
+ outputs = inputs
+ for ops in self.preprocessor:
+ outputs = ops(outputs)
+ return outputs
+
+ def postprocess(self, result, segment_offset_ids, ocr_infos):
+ outputs = result
+ for idx, ops in enumerate(self.postprocessor):
+ if idx == len(self.postprocessor) - 1:
+ outputs = ops(outputs,
+ segment_offset_ids=segment_offset_ids,
+ ocr_infos=ocr_infos)
+ else:
+ outputs = ops(outputs)
+ return outputs
+
+ def infer(self, data_list):
+ batch_loop_cnt = math.ceil(float(len(data_list)) / self.batch_size)
+ results = []
+ ser_inputs = []
+ for i in range(batch_loop_cnt):
+ start_index = i * self.batch_size
+ end_index = min((i + 1) * self.batch_size, len(data_list))
+ batch_data_list = data_list[start_index:end_index]
+ # preprocess
+ inputs = [
+ self.preprocess({
+ 'image': data[self.input_keys[0]],
+ 'ocr': {
+ 'dt_polys': data[self.input_keys[1]],
+ 'rec_text': data[self.input_keys[2]]
+ }
+ }) for data in batch_data_list
+ ]
+ ser_inputs.extend(inputs)
+ # concat to batch
+ model_inputs = []
+ for i in range(len(inputs[0])):
+ x = [x[i] for x in inputs]
+ if isinstance(x[0], np.ndarray):
+ x = np.stack(x)
+ model_inputs.append(x)
+ # model inference
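+            # the image tensor (5th input) is only fed to models that keep a
+            # visual backbone; VI-style models drop it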
+ if self.use_visual_backbone:
+ result = self.predictor.run(model_inputs[:5])
+ else:
+ result = self.predictor.run(model_inputs[:4])
+ # postprocess
+ result = self.postprocess(
+ result[0],
+ segment_offset_ids=model_inputs[6],
+ ocr_infos=model_inputs[7])
+ results.extend(result)
+ return results, ser_inputs
+
+ def __call__(self, inputs):
+ """
+        step 1: parse inputs
+        step 2: run model inference
+        step 3: merge results
+        input: a list of dict
+ """
+ # step2: run
+ outputs, ser_inputs = self.infer(inputs)
+ # step3: merge
+ pipe_outputs = []
+ for output, ser_input in zip(outputs, ser_inputs):
+ d = defaultdict(list)
+ for res in output:
+ d[self.output_keys[0]].append(res['pred_id'])
+ d[self.output_keys[1]].append(res['pred'])
+ d[self.output_keys[2]].append(res['points'])
+ d[self.output_keys[3]].append(res['transcription'])
+ d[self.output_keys[4]] = ser_input
+ pipe_outputs.append(d)
+ return pipe_outputs
+
+
+@register
+class PPStructureKieReOp(ModelBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(PPStructureKieReOp, self).__init__(model_cfg, env_cfg)
+ mod = importlib.import_module(__name__)
+ self.preprocessor = create_operators(model_cfg["PreProcess"], mod)
+ self.postprocessor = create_operators(model_cfg["PostProcess"], mod)
+ self.batch_size = model_cfg["batch_size"]
+ self.use_visual_backbone = model_cfg.get('use_visual_backbone', False)
+
+ @classmethod
+ def get_output_keys(cls):
+ return ["head", "tail"]
+
+ def preprocess(self, inputs):
+ outputs = inputs
+ for ops in self.preprocessor:
+ outputs = ops(outputs)
+ return outputs
+
+ def postprocess(self, result, **kwargs):
+ outputs = result
+ for idx, ops in enumerate(self.postprocessor):
+ if idx == len(self.postprocessor) - 1:
+ outputs = ops(outputs, **kwargs)
+ else:
+ outputs = ops(outputs)
+ return outputs
+
+ def infer(self, data_list):
+ batch_loop_cnt = math.ceil(float(len(data_list)) / self.batch_size)
+ results = []
+ for i in range(batch_loop_cnt):
+ start_index = i * self.batch_size
+ end_index = min((i + 1) * self.batch_size, len(data_list))
+ batch_data_list = data_list[start_index:end_index]
+ # preprocess
+ inputs = [
+ self.preprocess({
+ 'ser_inputs': data['ser.inputs'],
+ 'ser_preds': data['ser.pred']
+ }) for data in batch_data_list
+ ]
+ # concat to batch
+ model_inputs = []
+ for i in range(len(inputs[0])):
+ x = [x[i] for x in inputs]
+ if isinstance(x[0], np.ndarray):
+ x = np.stack(x)
+ model_inputs.append(x)
+ # model inference
+ if not self.use_visual_backbone:
+ model_inputs.pop(4)
+
+ result = self.predictor.run(model_inputs[:-1])
+
+ preds = dict(
+ loss=result[1],
+ pred_relations=result[2],
+ hidden_states=result[0])
+
+ # postprocess
+ result = self.postprocess(
+ preds,
+ ser_results=batch_data_list,
+ entity_idx_dict_batch=model_inputs[-1])
+ results.extend(result)
+ return results
+
+ def __call__(self, inputs):
+ """
+        step 1: parse inputs
+        step 2: run model inference
+        step 3: merge results
+        input: a list of dict
+ """
+ # step2: run
+ outputs = self.infer(inputs)
+ # step3: merge
+ pipe_outputs = []
+ for output in outputs:
+ d = defaultdict(list)
+ for res in output:
+ d[self.output_keys[0]].append(res[0])
+ d[self.output_keys[1]].append(res[1])
+ pipe_outputs.append(d)
+ return pipe_outputs
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_kie/postprocess.py b/paddlecv/ppcv/ops/models/ocr/ocr_kie/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a4f9844f28d9bdd72f6f5dc6c26cbb0799a71ea
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_kie/postprocess.py
@@ -0,0 +1,178 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+
+from ppcv.utils.download import get_dict_path
+
+from .preprocess import load_vqa_bio_label_maps
+
+
+class VQASerTokenLayoutLMPostProcess(object):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, class_path, **kwargs):
+ super(VQASerTokenLayoutLMPostProcess, self).__init__()
+ class_path = get_dict_path(class_path)
+ label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)
+
+ self.label2id_map_for_draw = dict()
+ for key in label2id_map:
+ if key.startswith("I-"):
+ self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
+ else:
+ self.label2id_map_for_draw[key] = label2id_map[key]
+
+ self.id2label_map_for_show = dict()
+ for key in self.label2id_map_for_draw:
+ val = self.label2id_map_for_draw[key]
+ if key == "O":
+ self.id2label_map_for_show[val] = key
+ if key.startswith("B-") or key.startswith("I-"):
+ self.id2label_map_for_show[val] = key[2:]
+ else:
+ self.id2label_map_for_show[val] = key
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+ if isinstance(preds, tuple):
+ preds = preds[0]
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+
+ if batch is not None:
+ return self._metric(preds, batch[5])
+ else:
+ return self._infer(preds, **kwargs)
+
+ def _metric(self, preds, label):
+ pred_idxs = preds.argmax(axis=2)
+ decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
+ label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
+
+ for i in range(pred_idxs.shape[0]):
+ for j in range(pred_idxs.shape[1]):
+ if label[i, j] != -100:
+ label_decode_out_list[i].append(self.id2label_map[label[
+ i, j]])
+ decode_out_list[i].append(self.id2label_map[pred_idxs[i,
+ j]])
+ return decode_out_list, label_decode_out_list
+
+ def _infer(self, preds, segment_offset_ids, ocr_infos):
+ results = []
+
+ for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
+ ocr_infos):
+ pred = np.argmax(pred, axis=1)
+ pred = [self.id2label_map[idx] for idx in pred]
+
+ for idx in range(len(segment_offset_id)):
+ if idx == 0:
+ start_id = 0
+ else:
+ start_id = segment_offset_id[idx - 1]
+
+ end_id = segment_offset_id[idx]
+
+ curr_pred = pred[start_id:end_id]
+ curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]
+
+ if len(curr_pred) <= 0:
+ pred_id = 0
+ else:
+ counts = np.bincount(curr_pred)
+ pred_id = np.argmax(counts)
+ ocr_info[idx]["pred_id"] = int(pred_id)
+ ocr_info[idx]["pred"] = self.id2label_map_for_show[int(
+ pred_id)]
+ results.append(ocr_info)
+ return results
+
+
+class VQAReTokenLayoutLMPostProcess(object):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, **kwargs):
+ super(VQAReTokenLayoutLMPostProcess, self).__init__()
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ pred_relations = preds['pred_relations']
+ if isinstance(preds['pred_relations'], paddle.Tensor):
+ pred_relations = pred_relations.numpy()
+ pred_relations = self.decode_pred(pred_relations)
+
+ if label is not None:
+ return self._metric(pred_relations, label)
+ else:
+ return self._infer(pred_relations, *args, **kwargs)
+
+ def _metric(self, pred_relations, label):
+ return pred_relations, label[-1], label[-2]
+
+ def _infer(self, pred_relations, *args, **kwargs):
+ ser_results = kwargs['ser_results']
+ entity_idx_dict_batch = kwargs['entity_idx_dict_batch']
+
+ # merge relations and ocr info
+ results = []
+ for pred_relation, ser_result, entity_idx_dict in zip(
+ pred_relations, ser_results, entity_idx_dict_batch):
+ result = []
+ used_tail_id = []
+ for relation in pred_relation:
+ if relation['tail_id'] in used_tail_id:
+ continue
+ used_tail_id.append(relation['tail_id'])
+ head_idx = entity_idx_dict[relation['head_id']]
+ ocr_info_head = {
+ 'dt_polys': ser_result['ser.dt_polys'][head_idx].tolist(),
+ 'rec_text': ser_result['ser.rec_text'][head_idx],
+ 'pred': ser_result['ser.pred'][head_idx],
+ }
+
+ tail_idx = entity_idx_dict[relation['tail_id']]
+ ocr_info_tail = {
+ 'dt_polys': ser_result['ser.dt_polys'][tail_idx].tolist(),
+ 'rec_text': ser_result['ser.rec_text'][tail_idx],
+ 'pred': ser_result['ser.pred'][tail_idx],
+ }
+ result.append((ocr_info_head, ocr_info_tail))
+ results.append(result)
+ return results
+
+ def decode_pred(self, pred_relations):
+ pred_relations_new = []
+ for pred_relation in pred_relations:
+ pred_relation_new = []
+ pred_relation = pred_relation[1:pred_relation[0, 0, 0] + 1]
+ for relation in pred_relation:
+ relation_new = dict()
+ relation_new['head_id'] = relation[0, 0]
+ relation_new['head'] = tuple(relation[1])
+ relation_new['head_type'] = relation[2, 0]
+ relation_new['tail_id'] = relation[3, 0]
+ relation_new['tail'] = tuple(relation[4])
+ relation_new['tail_type'] = relation[5, 0]
+ relation_new['type'] = relation[6, 0]
+ pred_relation_new.append(relation_new)
+ pred_relations_new.append(pred_relation_new)
+ return pred_relations_new
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_kie/preprocess.py b/paddlecv/ppcv/ops/models/ocr/ocr_kie/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..09897b0912529b4de44299daff648aa9563b6556
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_kie/preprocess.py
@@ -0,0 +1,598 @@
+"""
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+import json
+import copy
+import cv2
+import paddle
+import numpy as np
+from collections import defaultdict
+from ppcv.utils.download import get_dict_path
+
+
+def load_vqa_bio_label_maps(label_map_path):
+ label_map_path = get_dict_path(label_map_path)
+ with open(label_map_path, "r", encoding='utf-8') as fin:
+ lines = fin.readlines()
+ old_lines = [line.strip() for line in lines]
+ lines = ["O"]
+ for line in old_lines:
+ # "O" has already been in lines
+ if line.upper() in ["OTHER", "OTHERS", "IGNORE"]:
+ continue
+ lines.append(line)
+ labels = ["O"]
+ for line in lines[1:]:
+ labels.append("B-" + line)
+ labels.append("I-" + line)
+ label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
+ id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
+ return label2id_map, id2label_map
+
+
+def order_by_tbyx(ocr_info):
+ res = sorted(ocr_info, key=lambda r: (r["bbox"][1], r["bbox"][0]))
+ for i in range(len(res) - 1):
+ for j in range(i, 0, -1):
+ if abs(res[j + 1]["bbox"][1] - res[j]["bbox"][1]) < 20 and \
+ (res[j + 1]["bbox"][0] < res[j]["bbox"][0]):
+ tmp = copy.deepcopy(res[j])
+ res[j] = copy.deepcopy(res[j + 1])
+ res[j + 1] = copy.deepcopy(tmp)
+ else:
+ break
+ return res
+
+
+class VQATokenLabelEncode(object):
+ """
+ Label encode for NLP VQA methods
+ """
+
+ def __init__(self,
+ class_path,
+ contains_re=False,
+ add_special_ids=False,
+ algorithm='LayoutXLM',
+ use_textline_bbox_info=True,
+ order_method=None,
+ infer_mode=True,
+ **kwargs):
+ super(VQATokenLabelEncode, self).__init__()
+ from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer, LayoutLMv2Tokenizer
+ tokenizer_dict = {
+ 'LayoutXLM': {
+ 'class': LayoutXLMTokenizer,
+ 'pretrained_model': 'layoutxlm-base-uncased'
+ },
+ 'LayoutLM': {
+ 'class': LayoutLMTokenizer,
+ 'pretrained_model': 'layoutlm-base-uncased'
+ },
+ 'LayoutLMv2': {
+ 'class': LayoutLMv2Tokenizer,
+ 'pretrained_model': 'layoutlmv2-base-uncased'
+ }
+ }
+ self.contains_re = contains_re
+ tokenizer_config = tokenizer_dict[algorithm]
+ self.tokenizer = tokenizer_config['class'].from_pretrained(
+ tokenizer_config['pretrained_model'])
+ self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
+ self.add_special_ids = add_special_ids
+ self.infer_mode = infer_mode
+ self.use_textline_bbox_info = use_textline_bbox_info
+ self.order_method = order_method
+ assert self.order_method in [None, "tb-yx"]
+
+ def split_bbox(self, bbox, text, tokenizer):
+ words = text.split()
+ token_bboxes = []
+        x1, y1, x2, y2 = bbox
+        unit_w = (x2 - x1) / len(text)
+        for word in words:
+ curr_w = len(word) * unit_w
+ word_bbox = [x1, y1, x1 + curr_w, y2]
+ token_bboxes.extend([word_bbox] * len(tokenizer.tokenize(word)))
+ x1 += (len(word) + 1) * unit_w
+ return token_bboxes
+
+ def filter_empty_contents(self, ocr_info):
+ """
+ find out the empty texts and remove the links
+ """
+ new_ocr_info = []
+ empty_index = []
+ for idx, info in enumerate(ocr_info):
+ if len(info["transcription"]) > 0:
+ new_ocr_info.append(copy.deepcopy(info))
+ else:
+ empty_index.append(info["id"])
+
+ for idx, info in enumerate(new_ocr_info):
+ new_link = []
+ for link in info["linking"]:
+ if link[0] in empty_index or link[1] in empty_index:
+ continue
+ new_link.append(link)
+ new_ocr_info[idx]["linking"] = new_link
+ return new_ocr_info
+
+ def __call__(self, data):
+ # load bbox and label info
+ ocr_info = self._load_ocr_info(data)
+
+ for idx in range(len(ocr_info)):
+ if "bbox" not in ocr_info[idx]:
+ ocr_info[idx]["bbox"] = self.trans_poly_to_bbox(ocr_info[idx][
+ "points"])
+
+ if self.order_method == "tb-yx":
+ ocr_info = order_by_tbyx(ocr_info)
+
+ # for re
+ train_re = self.contains_re and not self.infer_mode
+ if train_re:
+ ocr_info = self.filter_empty_contents(ocr_info)
+
+ height, width, _ = data['image'].shape
+
+ words_list = []
+ bbox_list = []
+ input_ids_list = []
+ token_type_ids_list = []
+ segment_offset_id = []
+ gt_label_list = []
+
+ entities = []
+
+ if train_re:
+ relations = []
+ id2label = {}
+ entity_id_to_index_map = {}
+ empty_entity = set()
+
+ data['ocr_info'] = copy.deepcopy(ocr_info)
+
+ for info in ocr_info:
+ text = info["transcription"]
+ if len(text) <= 0:
+ continue
+ if train_re:
+ # for re
+ if len(text) == 0:
+ empty_entity.add(info["id"])
+ continue
+ id2label[info["id"]] = info["label"]
+ relations.extend([tuple(sorted(l)) for l in info["linking"]])
+ # smooth_box
+ info["bbox"] = self.trans_poly_to_bbox(info["points"])
+
+ encode_res = self.tokenizer.encode(
+ text,
+ pad_to_max_seq_len=False,
+ return_attention_mask=True,
+ return_token_type_ids=True)
+
+ if not self.add_special_ids:
+ # TODO: use tok.all_special_ids to remove
+ encode_res["input_ids"] = encode_res["input_ids"][1:-1]
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
+ -1]
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:
+ -1]
+
+ if self.use_textline_bbox_info:
+ bbox = [info["bbox"]] * len(encode_res["input_ids"])
+ else:
+ bbox = self.split_bbox(info["bbox"], info["transcription"],
+ self.tokenizer)
+ if len(bbox) <= 0:
+ continue
+ bbox = self._smooth_box(bbox, height, width)
+ if self.add_special_ids:
+ bbox.insert(0, [0, 0, 0, 0])
+ bbox.append([0, 0, 0, 0])
+
+ # parse label
+ if not self.infer_mode:
+ label = info['label']
+ gt_label = self._parse_label(label, encode_res)
+
+ # construct entities for re
+ if train_re:
+ if gt_label[0] != self.label2id_map["O"]:
+ entity_id_to_index_map[info["id"]] = len(entities)
+ label = label.upper()
+ entities.append({
+ "start": len(input_ids_list),
+ "end":
+ len(input_ids_list) + len(encode_res["input_ids"]),
+ "label": label.upper(),
+ })
+ else:
+ entities.append({
+ "start": len(input_ids_list),
+ "end": len(input_ids_list) + len(encode_res["input_ids"]),
+ "label": 'O',
+ })
+ input_ids_list.extend(encode_res["input_ids"])
+ token_type_ids_list.extend(encode_res["token_type_ids"])
+ bbox_list.extend(bbox)
+ words_list.append(text)
+ segment_offset_id.append(len(input_ids_list))
+ if not self.infer_mode:
+ gt_label_list.extend(gt_label)
+
+ data['input_ids'] = input_ids_list
+ data['token_type_ids'] = token_type_ids_list
+ data['bbox'] = bbox_list
+ data['attention_mask'] = [1] * len(input_ids_list)
+ data['labels'] = gt_label_list
+ data['segment_offset_id'] = segment_offset_id
+ data['tokenizer_params'] = dict(
+ padding_side=self.tokenizer.padding_side,
+ pad_token_type_id=self.tokenizer.pad_token_type_id,
+ pad_token_id=self.tokenizer.pad_token_id)
+ data['entities'] = entities
+
+ if train_re:
+ data['relations'] = relations
+ data['id2label'] = id2label
+ data['empty_entity'] = empty_entity
+ data['entity_id_to_index_map'] = entity_id_to_index_map
+ return data
+
+ def trans_poly_to_bbox(self, poly):
+ x1 = int(np.min([p[0] for p in poly]))
+ x2 = int(np.max([p[0] for p in poly]))
+ y1 = int(np.min([p[1] for p in poly]))
+ y2 = int(np.max([p[1] for p in poly]))
+ return [x1, y1, x2, y2]
+
+ def _load_ocr_info(self, data):
+ if self.infer_mode:
+ ocr_result = data['ocr']
+ bboxes = ocr_result['dt_polys']
+ txts = ocr_result['rec_text']
+
+ ocr_info = []
+ for box, txt in zip(bboxes, txts):
+ ocr_info.append({
+ "transcription": txt,
+ "bbox": self.trans_poly_to_bbox(box),
+ "points": box,
+ })
+ return ocr_info
+ else:
+ info = data['label']
+ # read text info
+ info_dict = json.loads(info)
+ return info_dict
+
+ def _smooth_box(self, bboxes, height, width):
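+        # scale pixel boxes to the 0-1000 coordinate grid expected by
+        # LayoutLM-family models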
+ bboxes = np.array(bboxes)
+ bboxes[:, 0] = bboxes[:, 0] * 1000 / width
+ bboxes[:, 2] = bboxes[:, 2] * 1000 / width
+ bboxes[:, 1] = bboxes[:, 1] * 1000 / height
+ bboxes[:, 3] = bboxes[:, 3] * 1000 / height
+ bboxes = bboxes.astype("int64").tolist()
+ return bboxes
+
+ def _parse_label(self, label, encode_res):
+ gt_label = []
+ if label.lower() in ["other", "others", "ignore"]:
+ gt_label.extend([0] * len(encode_res["input_ids"]))
+ else:
+ gt_label.append(self.label2id_map[("b-" + label).upper()])
+ gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
+ (len(encode_res["input_ids"]) - 1))
+        return gt_label
+
+
+class VQATokenPad(object):
+ def __init__(self,
+ max_seq_len=512,
+ pad_to_max_seq_len=True,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ truncation_strategy="longest_first",
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False,
+ infer_mode=True,
+ **kwargs):
+ self.max_seq_len = max_seq_len
+        self.pad_to_max_seq_len = pad_to_max_seq_len
+ self.return_attention_mask = return_attention_mask
+ self.return_token_type_ids = return_token_type_ids
+ self.truncation_strategy = truncation_strategy
+ self.return_overflowing_tokens = return_overflowing_tokens
+ self.return_special_tokens_mask = return_special_tokens_mask
+ self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+ self.infer_mode = infer_mode
+
+ def __call__(self, data):
+ needs_to_be_padded = self.pad_to_max_seq_len and len(data[
+ "input_ids"]) < self.max_seq_len
+
+ if needs_to_be_padded:
+ if 'tokenizer_params' in data:
+ tokenizer_params = data.pop('tokenizer_params')
+ else:
+ tokenizer_params = dict(
+ padding_side='right', pad_token_type_id=0, pad_token_id=1)
+
+ difference = self.max_seq_len - len(data["input_ids"])
+ if tokenizer_params['padding_side'] == 'right':
+ if self.return_attention_mask:
+ data["attention_mask"] = [1] * len(data[
+ "input_ids"]) + [0] * difference
+ if self.return_token_type_ids:
+ data["token_type_ids"] = (
+ data["token_type_ids"] +
+ [tokenizer_params['pad_token_type_id']] * difference)
+ if self.return_special_tokens_mask:
+ data["special_tokens_mask"] = data[
+ "special_tokens_mask"] + [1] * difference
+ data["input_ids"] = data["input_ids"] + [
+ tokenizer_params['pad_token_id']
+ ] * difference
+ if not self.infer_mode:
+ data["labels"] = data[
+ "labels"] + [self.pad_token_label_id] * difference
+ data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
+ elif tokenizer_params['padding_side'] == 'left':
+ if self.return_attention_mask:
+ data["attention_mask"] = [0] * difference + [
+ 1
+ ] * len(data["input_ids"])
+ if self.return_token_type_ids:
+ data["token_type_ids"] = (
+ [tokenizer_params['pad_token_type_id']] * difference +
+ data["token_type_ids"])
+ if self.return_special_tokens_mask:
+ data["special_tokens_mask"] = [
+ 1
+ ] * difference + data["special_tokens_mask"]
+ data["input_ids"] = [tokenizer_params['pad_token_id']
+ ] * difference + data["input_ids"]
+ if not self.infer_mode:
+ data["labels"] = [self.pad_token_label_id
+ ] * difference + data["labels"]
+ data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
+ else:
+ if self.return_attention_mask:
+ data["attention_mask"] = [1] * len(data["input_ids"])
+
+ for key in data:
+ if key in [
+ 'input_ids', 'labels', 'token_type_ids', 'bbox',
+ 'attention_mask'
+ ]:
+ if self.infer_mode:
+ if key != 'labels':
+ length = min(len(data[key]), self.max_seq_len)
+ data[key] = data[key][:length]
+ else:
+ continue
+ data[key] = np.array(data[key], dtype='int64')
+ return data
+
+
+class VQASerTokenChunk(object):
+ def __init__(self, max_seq_len=512, infer_mode=True, **kwargs):
+ self.max_seq_len = max_seq_len
+ self.infer_mode = infer_mode
+
+ def __call__(self, data):
+ encoded_inputs_all = []
+ seq_len = len(data['input_ids'])
+ for index in range(0, seq_len, self.max_seq_len):
+ chunk_beg = index
+ chunk_end = min(index + self.max_seq_len, seq_len)
+ encoded_inputs_example = {}
+ for key in data:
+ if key in [
+ 'label', 'input_ids', 'labels', 'token_type_ids',
+ 'bbox', 'attention_mask'
+ ]:
+ if self.infer_mode and key == 'labels':
+ encoded_inputs_example[key] = data[key]
+ else:
+ encoded_inputs_example[key] = data[key][chunk_beg:
+ chunk_end]
+ else:
+ encoded_inputs_example[key] = data[key]
+
+ encoded_inputs_all.append(encoded_inputs_example)
+ if len(encoded_inputs_all) == 0:
+ return None
+ return encoded_inputs_all[0]
+
+
+class VQAReTokenChunk(object):
+ def __init__(self,
+ max_seq_len=512,
+ entities_labels=None,
+ infer_mode=True,
+ **kwargs):
+ self.max_seq_len = max_seq_len
+ self.entities_labels = {
+ 'HEADER': 0,
+ 'QUESTION': 1,
+ 'ANSWER': 2
+ } if entities_labels is None else entities_labels
+ self.infer_mode = infer_mode
+
+ def __call__(self, data):
+ # prepare data
+ entities = data.pop('entities')
+ relations = data.pop('relations')
+ encoded_inputs_all = []
+ for index in range(0, len(data["input_ids"]), self.max_seq_len):
+ item = {}
+ for key in data:
+ if key in [
+ 'label', 'input_ids', 'labels', 'token_type_ids',
+ 'bbox', 'attention_mask'
+ ]:
+ if self.infer_mode and key == 'labels':
+ item[key] = data[key]
+ else:
+ item[key] = data[key][index:index + self.max_seq_len]
+ else:
+ item[key] = data[key]
+ # select entity in current chunk
+ entities_in_this_span = []
+ global_to_local_map = {} #
+ for entity_id, entity in enumerate(entities):
+ if (index <= entity["start"] < index + self.max_seq_len and
+ index <= entity["end"] < index + self.max_seq_len):
+ entity["start"] = entity["start"] - index
+ entity["end"] = entity["end"] - index
+ global_to_local_map[entity_id] = len(entities_in_this_span)
+ entities_in_this_span.append(entity)
+
+ # select relations in current chunk
+ relations_in_this_span = []
+ for relation in relations:
+ if (index <= relation["start_index"] < index + self.max_seq_len
+ and index <= relation["end_index"] <
+ index + self.max_seq_len):
+ relations_in_this_span.append({
+ "head": global_to_local_map[relation["head"]],
+ "tail": global_to_local_map[relation["tail"]],
+ "start_index": relation["start_index"] - index,
+ "end_index": relation["end_index"] - index,
+ })
+ item.update({
+ "entities": self.reformat(entities_in_this_span),
+ "relations": self.reformat(relations_in_this_span),
+ })
+ if len(item['entities']) > 0:
+ item['entities']['label'] = [
+ self.entities_labels[x] for x in item['entities']['label']
+ ]
+ encoded_inputs_all.append(item)
+ if len(encoded_inputs_all) == 0:
+ return None
+ return encoded_inputs_all[0]
+
+ def reformat(self, data):
+ new_data = defaultdict(list)
+ for item in data:
+ for k, v in item.items():
+ new_data[k].append(v)
+ return new_data
+
+
+class ExpandDim(object):
+ def __call__(self, data):
+ for idx in range(len(data)):
+ if isinstance(data[idx], np.ndarray):
+ data[idx] = np.expand_dims(data[idx], axis=0)
+ else:
+ data[idx] = [data[idx]]
+ return data
+
+
+class Resize(object):
+ def __init__(self, size=(640, 640), **kwargs):
+ self.size = size
+
+ def resize_image(self, img):
+ resize_h, resize_w = self.size
+ ori_h, ori_w = img.shape[:2] # (h, w, c)
+ ratio_h = float(resize_h) / ori_h
+ ratio_w = float(resize_w) / ori_w
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
+ return img, [ratio_h, ratio_w]
+
+ def __call__(self, data):
+ img = data['image']
+ if 'polys' in data:
+ text_polys = data['polys']
+
+ img_resize, [ratio_h, ratio_w] = self.resize_image(img)
+ if 'polys' in data:
+ new_boxes = []
+ for box in text_polys:
+ new_box = []
+ for cord in box:
+ new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
+ new_boxes.append(new_box)
+ data['polys'] = np.array(new_boxes, dtype=np.float32)
+ data['image'] = img_resize
+ return data
+
+
+class ReInput(object):
+ def __init__(self, entities_labels=None, **kwargs):
+ if entities_labels is None:
+ self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
+ else:
+ self.entities_labels = entities_labels
+
+ def __call__(self, data):
+ ser_inputs = data['ser_inputs']
+ ser_preds = data['ser_preds']
+ max_seq_len = ser_inputs[0].shape[0]
+ entities = ser_inputs[8]
+ assert len(entities) == len(ser_preds)
+
+ # entities
+ start = []
+ end = []
+ label = []
+ entity_idx_dict = {}
+ for i, (pred, entity) in enumerate(zip(ser_preds, entities)):
+ if pred == 'O':
+ continue
+ entity_idx_dict[len(start)] = i
+ start.append(entity['start'])
+ end.append(entity['end'])
+ label.append(self.entities_labels[pred])
+
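+        # pack entities into a fixed-size array: row 0 holds the counts,
+        # the following rows hold start / end / label per entity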
+ entities = np.full([max_seq_len + 1, 3], fill_value=-1)
+ entities[0, 0] = len(start)
+ entities[1:len(start) + 1, 0] = start
+ entities[0, 1] = len(end)
+ entities[1:len(end) + 1, 1] = end
+ entities[0, 2] = len(label)
+ entities[1:len(label) + 1, 2] = label
+
+ # relations
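+        # candidate pairs link every QUESTION entity (label 1) to every
+        # ANSWER entity (label 2); the RE model scores these candidates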
+ head = []
+ tail = []
+ for i in range(len(label)):
+ for j in range(len(label)):
+ if label[i] == 1 and label[j] == 2:
+ head.append(i)
+ tail.append(j)
+
+ relations = np.full([len(head) + 1, 2], fill_value=-1)
+ relations[0, 0] = len(head)
+ relations[1:len(head) + 1, 0] = head
+ relations[0, 1] = len(tail)
+ relations[1:len(tail) + 1, 1] = tail
+
+ # remove ocr_info segment_offset_id and label in ser input
+ if isinstance(ser_inputs[0], paddle.Tensor):
+ entities = paddle.to_tensor(entities)
+ relations = paddle.to_tensor(relations)
+ ser_inputs = ser_inputs[:5] + [entities, relations, entity_idx_dict]
+
+ return ser_inputs
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_table_recognition/__init__.py b/paddlecv/ppcv/ops/models/ocr/ocr_table_recognition/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b304060c09bb7fff5fafae44a215471c9069a35a
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_table_recognition/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .inference import PPStructureTableStructureOp
+
+__all__ = ['PPStructureTableStructureOp']
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_table_recognition/inference.py b/paddlecv/ppcv/ops/models/ocr/ocr_table_recognition/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..f49b8220ff920304289399570f0dc8d161c314fd
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_table_recognition/inference.py
@@ -0,0 +1,117 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import reduce
+import importlib
+import math
+
+import numpy as np
+
+from ppcv.ops.base import create_operators
+from ppcv.core.workspace import register
+from ppcv.ops.models.base import ModelBaseOp
+
+from ppcv.ops.models.ocr.ocr_db_detection.preprocess import NormalizeImage, ToCHWImage, KeepKeys, ExpandDim, RGB2BGR
+from ppcv.ops.models.ocr.ocr_table_recognition.preprocess import *
+from ppcv.ops.models.ocr.ocr_table_recognition.postprocess import *
+
+
+@register
+class PPStructureTableStructureOp(ModelBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(PPStructureTableStructureOp, self).__init__(model_cfg, env_cfg)
+ mod = importlib.import_module(__name__)
+ self.preprocessor = create_operators(model_cfg["PreProcess"], mod)
+ self.postprocessor = create_operators(model_cfg["PostProcess"], mod)
+ self.batch_size = model_cfg["batch_size"]
+
+ @classmethod
+ def get_output_keys(cls):
+ return ["structures", "dt_bboxes", "scores"]
+
+ def preprocess(self, inputs):
+ outputs = inputs
+ for ops in self.preprocessor:
+ outputs = ops(outputs)
+ return outputs
+
+ def postprocess(self, result, shape_list):
+ outputs = result
+ for idx, ops in enumerate(self.postprocessor):
+ if idx == len(self.postprocessor) - 1:
+ outputs = ops(outputs, shape_list, self.output_keys)
+ else:
+ outputs = ops(outputs)
+ return outputs
+
+ def infer(self, image_list):
+ batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
+ results = []
+ for i in range(batch_loop_cnt):
+ start_index = i * self.batch_size
+ end_index = min((i + 1) * self.batch_size, len(image_list))
+ batch_image_list = image_list[start_index:end_index]
+ # preprocess
+ inputs = [
+ self.preprocess({
+ 'image': img
+ }) for img in batch_image_list
+ ]
+ shape_list = np.stack([x['shape'] for x in inputs])
+ inputs = np.concatenate([x['image'] for x in inputs], axis=0)
+ # model inference
+ result = self.predictor.run(inputs)
+ # postprocess
+ result = self.postprocess(result, shape_list)
+ results.extend(result)
+ return results
+
+ def __call__(self, inputs):
+ """
+        step 1: parse inputs
+        step 2: run model inference
+        step 3: merge results
+        input: a list of dict
+ """
+ key = self.input_keys[0]
+ is_list = False
+ if isinstance(inputs[0][key], (list, tuple)):
+ inputs = [input[key] for input in inputs]
+ is_list = True
+ else:
+ inputs = [[input[key]] for input in inputs]
+ sub_index_list = [len(input) for input in inputs]
+ inputs = reduce(lambda x, y: x.extend(y) or x, inputs)
+
+ pipe_outputs = []
+ if len(inputs) == 0:
+ pipe_outputs.append({
+ self.output_keys[0]: [],
+ self.output_keys[1]: [],
+ self.output_keys[2]: [],
+ })
+ return pipe_outputs
+ # step2: run
+ outputs = self.infer(inputs)
+ # step3: merge
+        curr_offset_id = 0
+ for idx in range(len(sub_index_list)):
+            sub_start_idx = curr_offset_id
+            sub_end_idx = curr_offset_id + sub_index_list[idx]
+ output = outputs[sub_start_idx:sub_end_idx]
+ output = {k: [o[k] for o in output] for k in output[0]}
+ if is_list is not True:
+ output = {k: output[k][0] for k in output}
+ pipe_outputs.append(output)
+
+            curr_offset_id = sub_end_idx
+ return pipe_outputs
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_table_recognition/postprocess.py b/paddlecv/ppcv/ops/models/ocr/ocr_table_recognition/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..dec11ddc55d3220a921c9e010ada0b97c672bc19
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_table_recognition/postprocess.py
@@ -0,0 +1,219 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+import numpy as np
+import paddle
+
+from ppcv.utils.download import get_dict_path
+
+
+class TableLabelDecode(object):
+ """ """
+
+ def __init__(self,
+ character_dict_path,
+ merge_no_span_structure=False,
+ **kwargs):
+ dict_character = []
+ character_dict_path = get_dict_path(character_dict_path)
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
+ dict_character.append(line)
+
+        if merge_no_span_structure:
+            if "<td></td>" not in dict_character:
+                dict_character.append("<td></td>")
+            if "<td>" in dict_character:
+                dict_character.remove("<td>")
+
+ dict_character = self.add_special_char(dict_character)
+ self.dict = {}
+ for i, char in enumerate(dict_character):
+ self.dict[char] = i
+ self.character = dict_character
+        self.td_token = ['<td>', '<td></td>']
+
+ def add_special_char(self, dict_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
+ return dict_character
+
+ def get_ignored_tokens(self):
+ beg_idx = self.get_beg_end_flag_idx("beg")
+ end_idx = self.get_beg_end_flag_idx("end")
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ if beg_or_end == "beg":
+ idx = np.array(self.dict[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict[self.end_str])
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx" \
+ % beg_or_end
+ return idx
+
+ def __call__(self, result, shape_list, output_keys):
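+        # result = [bbox_preds, structure_probs] from the structure head;
+        # one output dict is produced per image in the batch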
+ structure_probs = result[1]
+ bbox_preds = result[0]
+ if isinstance(structure_probs, paddle.Tensor):
+ structure_probs = structure_probs.numpy()
+ if isinstance(bbox_preds, paddle.Tensor):
+ bbox_preds = bbox_preds.numpy()
+ result = self.decode(structure_probs, bbox_preds, shape_list)
+ result_list = []
+ for i in range(len(shape_list)):
+ result_list.append({
+ output_keys[0]: result['structure_batch_list'][i][0],
+ output_keys[1]: result['bbox_batch_list'][i],
+ output_keys[2]: result['structure_batch_list'][i][1],
+ })
+ return result_list
+
+ def decode(self, structure_probs, bbox_preds, shape_list):
+ """convert text-label into text-index.
+ """
+ ignored_tokens = self.get_ignored_tokens()
+ end_idx = self.dict[self.end_str]
+
+ structure_idx = structure_probs.argmax(axis=2)
+ structure_probs = structure_probs.max(axis=2)
+
+ structure_batch_list = []
+ bbox_batch_list = []
+ batch_size = len(structure_idx)
+ for batch_idx in range(batch_size):
+ structure_list = []
+ bbox_list = []
+ score_list = []
+ for idx in range(len(structure_idx[batch_idx])):
+ char_idx = int(structure_idx[batch_idx][idx])
+ if idx > 0 and char_idx == end_idx:
+ break
+ if char_idx in ignored_tokens:
+ continue
+ text = self.character[char_idx]
+ if text in self.td_token:
+ bbox = bbox_preds[batch_idx, idx]
+ bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+ bbox_list.append(bbox)
+ structure_list.append(text)
+ score_list.append(structure_probs[batch_idx, idx])
+ structure_batch_list.append(
+ [structure_list, np.mean(score_list).tolist()])
+ bbox_batch_list.append(np.array(bbox_list).tolist())
+ result = {
+ 'bbox_batch_list': bbox_batch_list,
+ 'structure_batch_list': structure_batch_list,
+ }
+ return result
+
+ def decode_label(self, batch):
+ """convert text-label into text-index.
+ """
+ structure_idx = batch[1]
+ gt_bbox_list = batch[2]
+ shape_list = batch[-1]
+ ignored_tokens = self.get_ignored_tokens()
+ end_idx = self.dict[self.end_str]
+
+ structure_batch_list = []
+ bbox_batch_list = []
+ batch_size = len(structure_idx)
+ for batch_idx in range(batch_size):
+ structure_list = []
+ bbox_list = []
+ for idx in range(len(structure_idx[batch_idx])):
+ char_idx = int(structure_idx[batch_idx][idx])
+ if idx > 0 and char_idx == end_idx:
+ break
+ if char_idx in ignored_tokens:
+ continue
+ structure_list.append(self.character[char_idx])
+
+ bbox = gt_bbox_list[batch_idx][idx]
+ if bbox.sum() != 0:
+ bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+ bbox_list.append(bbox)
+ structure_batch_list.append(structure_list)
+ bbox_batch_list.append(bbox_list)
+ result = {
+ 'bbox_batch_list': bbox_batch_list,
+ 'structure_batch_list': structure_batch_list,
+ }
+ return result
+
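+    # scale normalized (0~1) bbox coordinates back to the original image size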
+ def _bbox_decode(self, bbox, shape):
+ h, w = shape[:2]
+ bbox[0::2] *= w
+ bbox[1::2] *= h
+ return bbox
+
+
+class TableMasterLabelDecode(TableLabelDecode):
+ """ """
+
+ def __init__(self,
+ character_dict_path,
+ box_shape='ori',
+ merge_no_span_structure=True,
+ **kwargs):
+ super(TableMasterLabelDecode, self).__init__(character_dict_path,
+ merge_no_span_structure)
+ self.box_shape = box_shape
+ assert box_shape in [
+ 'ori', 'pad'
+ ], 'The shape used for box normalization must be ori or pad'
+
+ def add_special_char(self, dict_character):
+        self.beg_str = '<SOS>'
+        self.end_str = '<EOS>'
+        self.unknown_str = '<UKN>'
+        self.pad_str = '<PAD>'
+        dict_character = dict_character + [
+            self.unknown_str, self.beg_str, self.end_str, self.pad_str
+        ]
+ return dict_character
+
+ def get_ignored_tokens(self):
+ pad_idx = self.dict[self.pad_str]
+ start_idx = self.dict[self.beg_str]
+ end_idx = self.dict[self.end_str]
+ unknown_idx = self.dict[self.unknown_str]
+ return [start_idx, end_idx, pad_idx, unknown_idx]
+
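+    # rescale (cx, cy, w, h) predictions by pad/ori shape, then convert to corners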
+ def _bbox_decode(self, bbox, shape):
+ h, w, ratio_h, ratio_w, pad_h, pad_w = shape
+ if self.box_shape == 'pad':
+ h, w = pad_h, pad_w
+ bbox[0::2] *= w
+ bbox[1::2] *= h
+ bbox[0::2] /= ratio_w
+ bbox[1::2] /= ratio_h
+ x, y, w, h = bbox
+ x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
+ bbox = np.array([x1, y1, x2, y2])
+        return bbox
diff --git a/paddlecv/ppcv/ops/models/ocr/ocr_table_recognition/preprocess.py b/paddlecv/ppcv/ops/models/ocr/ocr_table_recognition/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..62d44e2a73810cffe6235b80243c7e33d2f578de
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/ocr/ocr_table_recognition/preprocess.py
@@ -0,0 +1,59 @@
+"""
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import cv2
+import numpy as np
+
+
+class ResizeTableImage(object):
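+    """Resize so the longer side equals max_len, keeping the aspect ratio;
+    records 'shape' as [h, w, ratio, ratio] for later bbox decoding."""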
+ def __init__(self, max_len, **kwargs):
+ super(ResizeTableImage, self).__init__()
+ self.max_len = max_len
+
+ def __call__(self, data):
+ img = data['image']
+ height, width = img.shape[0:2]
+ ratio = self.max_len / (max(height, width) * 1.0)
+ resize_h = int(height * ratio)
+ resize_w = int(width * ratio)
+ resize_img = cv2.resize(img, (resize_w, resize_h))
+ data['image'] = resize_img
+ data['shape'] = np.array([height, width, ratio, ratio])
+ data['max_len'] = self.max_len
+ return data
+
+
+class PaddingTableImage(object):
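+    """Zero-pad the resized image to a fixed (pad_h, pad_w) canvas and
+    append the pad size to the recorded 'shape'."""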
+ def __init__(self, size, **kwargs):
+ super(PaddingTableImage, self).__init__()
+ self.size = size
+
+ def __call__(self, data):
+ img = data['image']
+ pad_h, pad_w = self.size
+ padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
+ height, width = img.shape[0:2]
+ padding_img[0:height, 0:width, :] = img.copy()
+ data['image'] = padding_img
+ shape = data['shape'].tolist()
+ shape.extend([pad_h, pad_w])
+ data['shape'] = np.array(shape)
+ return data
diff --git a/paddlecv/ppcv/ops/models/segmentation/__init__.py b/paddlecv/ppcv/ops/models/segmentation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..acaa7824eb8a6d705314262af490b54538b530d3
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/segmentation/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .inference import SegmentationOp
+
+__all__ = ['SegmentationOp']
diff --git a/paddlecv/ppcv/ops/models/segmentation/inference.py b/paddlecv/ppcv/ops/models/segmentation/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..32adb76f308111e436482a2d9bedaf578bfc144e
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/segmentation/inference.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+from functools import reduce
+import os
+import numpy as np
+import math
+import paddle
+from ..base import ModelBaseOp
+
+from ppcv.ops.base import create_operators
+from ppcv.core.workspace import register
+
+from .preprocess import *
+from .postprocess import *
+
+
+@register
+class SegmentationOp(ModelBaseOp):
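+    """Semantic segmentation op: configured preprocess -> batched inference ->
+    postprocess, emitting one "seg_map" per input image."""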
+ def __init__(self, model_cfg, env_cfg):
+ super(SegmentationOp, self).__init__(model_cfg, env_cfg)
+ mod = importlib.import_module(__name__)
+ self.preprocessor = create_operators(model_cfg["PreProcess"], mod)
+ self.postprocessor = create_operators(model_cfg["PostProcess"], mod)
+
+ @classmethod
+ def get_output_keys(cls):
+ return ["seg_map"]
+
+ def preprocess(self, inputs):
+ outputs = inputs
+ for ops in self.preprocessor:
+ outputs = ops(outputs)
+ return outputs
+
+ def postprocess(self, inputs):
+ outputs = inputs
+ for idx, ops in enumerate(self.postprocessor):
+ if idx == len(self.postprocessor) - 1:
+ outputs = ops(outputs, self.output_keys)
+ else:
+ outputs = ops(outputs)
+ return outputs
+
+ def infer(self, image_list):
+ batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
+ results = []
+ for i in range(batch_loop_cnt):
+ start_index = i * self.batch_size
+ end_index = min((i + 1) * self.batch_size, len(image_list))
+ batch_image_list = image_list[start_index:end_index]
+ # preprocess
+ inputs = [self.preprocess(img) for img in batch_image_list]
+ inputs = np.concatenate(inputs, axis=0)
+ # model inference
+ result = self.predictor.run(inputs)[0]
+ # postprocess
+ result = self.postprocess(result)
+ results.extend(result)
+ return results
+
+ def __call__(self, inputs):
+ """
+ step1: parser inputs
+ step2: run
+ step3: merge results
+ input: a list of dict
+ """
+ key = self.input_keys[0]
+ is_list = False
+ if isinstance(inputs[0][key], (list, tuple)):
+ inputs = [input[key] for input in inputs]
+ is_list = True
+ else:
+ inputs = [[input[key]] for input in inputs]
+ sub_index_list = [len(input) for input in inputs]
+ inputs = reduce(lambda x, y: x.extend(y) or x, inputs)
+
+ # step2: run
+ outputs = self.infer(inputs)
+
+ # step3: merge
+        curr_offset_id = 0
+        pipe_outputs = []
+        for idx in range(len(sub_index_list)):
+            sub_start_idx = curr_offset_id
+            sub_end_idx = curr_offset_id + sub_index_list[idx]
+            output = outputs[sub_start_idx:sub_end_idx]
+            output = {k: [o[k] for o in output] for k in output[0]}
+            if not is_list:
+                output = {k: output[k][0] for k in output}
+            pipe_outputs.append(output)
+
+            curr_offset_id = sub_end_idx
+
+ return pipe_outputs
diff --git a/paddlecv/ppcv/ops/models/segmentation/postprocess.py b/paddlecv/ppcv/ops/models/segmentation/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..215c880f2f74313fef7ea8d417b43a8685079b53
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/segmentation/postprocess.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+
+__all__ = ["SegPostProcess"]
+
+
+class SegPostProcess(object):
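+    """Wrap each per-image segmentation map into a dict keyed by the op's output key."""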
+ def __init__(self):
+ super().__init__()
+
+ def __call__(self, inputs, output_keys):
+ outputs = [{output_keys[0]: seg_map} for seg_map in inputs]
+ return outputs
diff --git a/paddlecv/ppcv/ops/models/segmentation/preprocess.py b/paddlecv/ppcv/ops/models/segmentation/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..b518a66699b2a24f21e092343119bf003351f9cd
--- /dev/null
+++ b/paddlecv/ppcv/ops/models/segmentation/preprocess.py
@@ -0,0 +1,239 @@
+"""
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import math
+
+import cv2
+import numpy as np
+from PIL import Image
+
+__all__ = ["Resize", "ResizeByShort", "Normalize", "ToCHWImage", "ExpandDim"]
+
+
+class ResizeBase(object):
+ """
+ The base class of resize.
+ """
+ # The interpolation mode
+ interp_dict = {
+ 'NEAREST': cv2.INTER_NEAREST,
+ 'LINEAR': cv2.INTER_LINEAR,
+ 'CUBIC': cv2.INTER_CUBIC,
+ 'AREA': cv2.INTER_AREA,
+ 'LANCZOS4': cv2.INTER_LANCZOS4
+ }
+
+ def __init__(self, size_divisor=None, interp='LINEAR'):
+ if size_divisor is not None:
+ assert isinstance(size_divisor,
+ int), "size_divisor should be None or int"
+ if interp not in self.interp_dict:
+ raise ValueError("`interp` should be one of {}".format(
+ self.interp_dict.keys()))
+
+ self.size_divisor = size_divisor
+ self.interp = interp
+
+ @staticmethod
+ def resize(im, target_size, interp):
+ if isinstance(target_size, (list, tuple)):
+ w = target_size[0]
+ h = target_size[1]
+ elif isinstance(target_size, int):
+ w = target_size
+ h = target_size
+ else:
+ raise ValueError(
+ "target_size should be int (wh, wh), list (w, h) or tuple (w, h)"
+ )
+ im = cv2.resize(im, (w, h), interpolation=interp)
+ return im
+
+ @staticmethod
+ def rescale_size(img_size, target_size):
+ scale = min(
+ max(target_size) / max(img_size), min(target_size) / min(img_size))
+ rescaled_size = [round(i * scale) for i in img_size]
+ return rescaled_size, scale
+
+ def __call__(self, img):
+ raise NotImplementedError
+
+
+class Resize(ResizeBase):
+ """
+ Resize an image.
+
+ Args:
+ target_size (list|tuple, optional): The target size (w, h) of image. Default: (512, 512).
+ keep_ratio (bool, optional): Whether to keep the same ratio for width and height in resizing.
+ Default: False.
+ size_divisor (int, optional): If size_divisor is not None, make the width and height be the times
+ of size_divisor. Default: None.
+        interp (str, optional): The interpolation mode of resize, consistent with opencv.
+            One of ['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4']. Default: "LINEAR".
+ """
+
+ def __init__(self,
+ target_size=(512, 512),
+ keep_ratio=False,
+ size_divisor=None,
+ interp='LINEAR'):
+ super().__init__(size_divisor=size_divisor, interp=interp)
+
+ if isinstance(target_size, list) or isinstance(target_size, tuple):
+ if len(target_size) != 2:
+ raise ValueError(
+ '`target_size` should include 2 elements, but it is {}'.
+ format(target_size))
+ else:
+ raise TypeError(
+ "Type of `target_size` is invalid. It should be list or tuple, but it is {}"
+ .format(type(target_size)))
+
+ self.target_size = target_size
+ self.keep_ratio = keep_ratio
+
+ def __call__(self, img):
+ target_size = self.target_size
+ if self.keep_ratio:
+ h, w = img.shape[0:2]
+ target_size, _ = self.rescale_size((w, h), self.target_size)
+ if self.size_divisor:
+ target_size = [
+ math.ceil(i / self.size_divisor) * self.size_divisor
+ for i in target_size
+ ]
+
+ img = self.resize(img, target_size, self.interp_dict[self.interp])
+ return img
+
+
+class ResizeByShort(ResizeBase):
+ """
+ Resize an image by short.
+
+ Args:
+ target_size (list|tuple, optional): The target size (w, h) of image. Default: (512, 512).
+ size_divisor (int, optional): If size_divisor is not None, make the width and height be the times
+ of size_divisor. Default: None.
+ interp (str, optional): The interpolation mode of resize is consistent with opencv.
+ ['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM']. Note that when it is
+ 'RANDOM', a random interpolation mode would be specified. Default: "LINEAR".
+ """
+
+ def __init__(self, resize_short=512, size_divisor=None, interp='LINEAR'):
+ super().__init__(size_divisor=size_divisor, interp=interp)
+
+ self.resize_short = resize_short
+
+ def __call__(self, img):
+ h, w = img.shape[:2]
+ scale = self.resize_short / min(h, w)
+ h_resize = round(h * scale)
+ w_resize = round(w * scale)
+ if self.size_divisor is not None:
+ h_resize = math.ceil(h_resize /
+ self.size_divisor) * self.size_divisor
+ w_resize = math.ceil(w_resize /
+ self.size_divisor) * self.size_divisor
+
+ img = self.resize(img, (w_resize, h_resize),
+ self.interp_dict[self.interp])
+ return img
+
+
+class Normalize(object):
+ """ normalize image such as substract mean, divide std
+ """
+
+ def __init__(self,
+ scale=None,
+ mean=None,
+ std=None,
+ order='chw',
+ output_fp16=False,
+ channel_num=3):
+ if isinstance(scale, str):
+ scale = eval(scale)
+ assert channel_num in [
+ 3, 4
+ ], "channel number of input image should be set to 3 or 4."
+ self.channel_num = channel_num
+ self.output_dtype = 'float16' if output_fp16 else 'float32'
+ self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+ self.order = order
+ mean = mean if mean is not None else [0.485, 0.456, 0.406]
+ std = std if std is not None else [0.229, 0.224, 0.225]
+
+ shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3)
+ self.mean = np.array(mean).reshape(shape).astype('float32')
+ self.std = np.array(std).reshape(shape).astype('float32')
+
+ def __call__(self, img):
+        if isinstance(img, Image.Image):
+            img = np.array(img)
+
+ assert isinstance(img,
+ np.ndarray), "invalid input 'img' in NormalizeImage"
+
+ img = (img.astype('float32') * self.scale - self.mean) / self.std
+
+ if self.channel_num == 4:
+ img_h = img.shape[1] if self.order == 'chw' else img.shape[0]
+ img_w = img.shape[2] if self.order == 'chw' else img.shape[1]
+ pad_zeros = np.zeros(
+ (1, img_h, img_w)) if self.order == 'chw' else np.zeros(
+ (img_h, img_w, 1))
+ img = (np.concatenate(
+ (img, pad_zeros), axis=0)
+ if self.order == 'chw' else np.concatenate(
+ (img, pad_zeros), axis=2))
+ return img.astype(self.output_dtype)
+
+
+class ToCHWImage(object):
+ """ convert hwc image to chw image
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, img):
+        if isinstance(img, Image.Image):
+            img = np.array(img)
+
+ return img.transpose((2, 0, 1))
+
+
+class ExpandDim(object):
+ def __init__(self, axis=0):
+ self.axis = axis
+
+ def __call__(self, img):
+ img = np.expand_dims(img, axis=self.axis)
+ return img
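+
+# Illustrative only: a typical preprocess chain as it would be assembled from a
+# YAML PreProcess list (the parameter values below are assumptions, not fixed):
+#   img = ResizeByShort(512, size_divisor=32)(img)
+#   img = Normalize()(img)
+#   img = ToCHWImage()(img)
+#   batch = ExpandDim(0)(img)  # -> (1, C, H, W), ready for the predictor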
diff --git a/paddlecv/ppcv/ops/output/__init__.py b/paddlecv/ppcv/ops/output/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cbfa93eca3e24853b9f14ca88ee93dd91edc14e
--- /dev/null
+++ b/paddlecv/ppcv/ops/output/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import OutputBaseOp
+from .classification import ClasOutput
+from .feature_extraction import FeatureOutput
+from .detection import DetOutput
+from .keypoint import KptOutput
+from .ocr import OCRTableOutput, OCROutput, PPStructureOutput, PPStructureReOutput, PPStructureSerOutput
+from .segmentation import SegOutput, HumanSegOutput, MattingOutput
+from .tracker import TrackerOutput
+
+__all__ = [
+ 'OutputBaseOp', 'ClasOutput', 'FeatureOutput', 'DetOutput', 'KptOutput',
+ 'SegOutput', 'HumanSegOutput', 'MattingOutput', 'OCROutput',
+ 'OCRTableOutput', 'PPStructureOutput', 'PPStructureReOutput',
+ 'PPStructureSerOutput', 'TrackerOutput'
+]
diff --git a/paddlecv/ppcv/ops/output/base.py b/paddlecv/ppcv/ops/output/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..3228df9f0f477e74c7a57a817805246464728aa7
--- /dev/null
+++ b/paddlecv/ppcv/ops/output/base.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import math
+import glob
+import paddle
+import cv2
+from collections import defaultdict
+
+from ppcv.ops.base import BaseOp
+
+
+class OutputBaseOp(BaseOp):
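+    """Base class for output ops; reads the save_img/save_res/return_res/
+    print_res switches and the output_dir from the ENV config."""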
+ def __init__(self, model_cfg, env_cfg):
+ super(OutputBaseOp, self).__init__(model_cfg, env_cfg)
+ self.output_dir = self.env_cfg.get('output_dir', 'output')
+ if not os.path.exists(self.output_dir):
+ os.makedirs(self.output_dir, exist_ok=True)
+ self.save_img = self.env_cfg.get('save_img', False)
+ self.save_res = self.env_cfg.get('save_res', False)
+ self.return_res = self.env_cfg.get('return_res', False)
+ self.print_res = self.env_cfg.get('print_res', False)
+
+    @classmethod
+    def type(cls):
+ return 'OUTPUT'
+
+ def __call__(self, inputs):
+ return
diff --git a/paddlecv/ppcv/ops/output/classification.py b/paddlecv/ppcv/ops/output/classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4760b48cc317c832ed5ff0c7fc193150d9c3c77
--- /dev/null
+++ b/paddlecv/ppcv/ops/output/classification.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import math
+import glob
+import paddle
+import cv2
+import json
+from collections import defaultdict
+from .base import OutputBaseOp
+from ppcv.utils.logger import setup_logger
+from ppcv.core.workspace import register
+
+logger = setup_logger('ClasOutput')
+
+
+@register
+class ClasOutput(OutputBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(ClasOutput, self).__init__(model_cfg, env_cfg)
+
+ def __call__(self, inputs):
+ total_res = []
+ for res in inputs:
+ fn, image, class_ids, scores, label_names = res.values()
+ image = res.pop('input.image')
+ if self.frame_id != -1:
+                res.update({'frame_id': self.frame_id})
+ logger.info(res)
+ if self.save_img:
+ image = image[:, :, ::-1]
+ file_name = os.path.split(fn)[-1]
+ out_path = os.path.join(self.output_dir, file_name)
+ logger.info('Save output image to {}'.format(out_path))
+ cv2.imwrite(out_path, image)
+ if self.save_res or self.return_res:
+ total_res.append(res)
+ if self.save_res:
+ res_file_name = 'clas_output.json'
+ out_path = os.path.join(self.output_dir, res_file_name)
+ logger.info('Save output result to {}'.format(out_path))
+ with open(out_path, 'w') as f:
+ json.dump(total_res, f)
+ if self.return_res:
+ return total_res
+ return
diff --git a/paddlecv/ppcv/ops/output/detection.py b/paddlecv/ppcv/ops/output/detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..8baeefdc9f99c5ae9d5587ce7759b9e0fd1316fd
--- /dev/null
+++ b/paddlecv/ppcv/ops/output/detection.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import math
+import glob
+import paddle
+import cv2
+import json
+from collections import defaultdict
+from .base import OutputBaseOp
+from ppcv.utils.logger import setup_logger
+from ppcv.core.workspace import register
+from PIL import Image, ImageDraw, ImageFile
+
+logger = setup_logger('DetOutput')
+
+
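+# map an object/track id to a deterministic color so the same id keeps its color across frames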
+def get_id_color(idx):
+ idx = idx * 3
+ color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
+ return color
+
+
+def get_color_map_list(num_classes):
+ """
+ Args:
+ num_classes (int): number of class
+ Returns:
+ color_map (list): RGB color list
+ """
+ color_map = num_classes * [0, 0, 0]
+ for i in range(0, num_classes):
+ j = 0
+ lab = i
+ while lab:
+ color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
+ color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
+ color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
+ j += 1
+ lab >>= 3
+ color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
+ return color_map
+
+
+def draw_det(image, dt_bboxes, dt_scores, dt_cls_names, input_id=None):
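+    """Draw detection boxes with '<name> <score>' labels; colors are assigned
+    per class, or per id when input_id is given (e.g. for tracking)."""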
+ im = Image.fromarray(image[:, :, ::-1])
+ draw_thickness = min(im.size) // 320
+ draw = ImageDraw.Draw(im)
+ name_set = sorted(set(dt_cls_names))
+ name2clsid = {name: i for i, name in enumerate(name_set)}
+ clsid2color = {}
+ color_list = get_color_map_list(len(name_set))
+
+ for i in range(len(dt_bboxes)):
+ box, score, name = dt_bboxes[i], dt_scores[i], dt_cls_names[i]
+ if input_id is None:
+ color = tuple(color_list[name2clsid[name]])
+ else:
+ color = get_id_color(input_id[i])
+
+ xmin, ymin, xmax, ymax = box
+ # draw bbox
+ draw.line(
+ [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
+ (xmin, ymin)],
+ width=draw_thickness,
+ fill=color)
+
+ # draw label
+ text = "{} {:.4f}".format(name, score)
+ box = draw.textbbox((xmin, ymin), text, anchor='lt')
+ draw.rectangle(box, fill=color)
+ draw.text((box[0], box[1]), text, fill=(255, 255, 255))
+ image = np.array(im)
+ return image
+
+
+@register
+class DetOutput(OutputBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(DetOutput, self).__init__(model_cfg, env_cfg)
+
+ def __call__(self, inputs):
+ total_res = []
+ for res in inputs:
+ fn, image, dt_bboxes, dt_scores, dt_cls_names = res.values()
+ image = draw_det(image, dt_bboxes, dt_scores, dt_cls_names)
+ res.pop('input.image')
+ if self.frame_id != -1:
+                res.update({'frame_id': self.frame_id})
+ logger.info(res)
+ if self.save_img:
+ file_name = os.path.split(fn)[-1]
+ out_path = os.path.join(self.output_dir, file_name)
+ logger.info('Save output image to {}'.format(out_path))
+ cv2.imwrite(out_path, image)
+ if self.save_res or self.return_res:
+ total_res.append(res)
+ if self.save_res:
+ res_file_name = 'det_output.json'
+ out_path = os.path.join(self.output_dir, res_file_name)
+ logger.info('Save output result to {}'.format(out_path))
+ with open(out_path, 'w') as f:
+ json.dump(total_res, f)
+ if self.return_res:
+ return total_res
+ return
diff --git a/paddlecv/ppcv/ops/output/feature_extraction.py b/paddlecv/ppcv/ops/output/feature_extraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..6336fa0adc8410b1b3c7f6579968659b6d60920f
--- /dev/null
+++ b/paddlecv/ppcv/ops/output/feature_extraction.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import glob
+import copy
+import math
+import json
+
+import numpy as np
+import paddle
+import cv2
+
+from collections import defaultdict
+from .base import OutputBaseOp
+from ppcv.utils.logger import setup_logger
+from ppcv.core.workspace import register
+
+logger = setup_logger('FeatureOutput')
+
+
+@register
+class FeatureOutput(OutputBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super().__init__(model_cfg, env_cfg)
+
+ def __call__(self, inputs):
+ total_res = []
+ for res in inputs:
+ # TODO(gaotingquan): video input is not tested
+ if self.frame_id != -1:
+                res.update({'frame_id': self.frame_id})
+ if self.print_res:
+ msg = " ".join([f"{key}: {res[key]}" for key in res])
+ logger.info(msg)
+ if self.return_res:
+ total_res.append(res)
+ if self.return_res:
+ return total_res
+ return
diff --git a/paddlecv/ppcv/ops/output/keypoint.py b/paddlecv/ppcv/ops/output/keypoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..f804f64e778ef2907025b3c99c017d134f8411fd
--- /dev/null
+++ b/paddlecv/ppcv/ops/output/keypoint.py
@@ -0,0 +1,149 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import math
+import glob
+import paddle
+import cv2
+import json
+from collections import defaultdict
+from .base import OutputBaseOp
+from ppcv.utils.logger import setup_logger
+from ppcv.core.workspace import register
+
+logger = setup_logger('KptOutput')
+
+
+def get_color(idx):
+ idx = idx * 3
+ color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
+ return color
+
+
+def draw_kpt(image, keypoints, visual_thresh=0.6, ids=None):
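+    """Visualize COCO (17-point) or MPII keypoints and limb connections,
+    skipping points whose confidence is below visual_thresh."""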
+ try:
+ import matplotlib.pyplot as plt
+ import matplotlib
+ plt.switch_backend('agg')
+ except Exception as e:
+        print('Matplotlib not found, please install matplotlib, '
+              'for example: `pip install matplotlib`.')
+ raise e
+ image = image[:, :, ::-1]
+ skeletons = np.array(keypoints)[0]
+ kpt_nums = 17
+ if len(skeletons) > 0:
+ kpt_nums = skeletons.shape[1]
+ if kpt_nums == 17: #plot coco keypoint
+ EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7),
+ (6, 8), (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14),
+ (13, 15), (14, 16), (11, 12)]
+ else: #plot mpii keypoint
+ EDGES = [(0, 1), (1, 2), (3, 4), (4, 5), (2, 6), (3, 6), (6, 7),
+ (7, 8), (8, 9), (10, 11), (11, 12), (13, 14), (14, 15),
+ (8, 12), (8, 13)]
+ NUM_EDGES = len(EDGES)
+
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+ cmap = matplotlib.cm.get_cmap('hsv')
+ plt.figure()
+
+ color_set = None
+
+ canvas = image.copy()
+ for i in range(kpt_nums):
+ for j in range(len(skeletons)):
+ if skeletons[j][i, 2] < visual_thresh:
+ continue
+ if ids is None:
+ color = colors[i] if color_set is None else colors[color_set[j]
+ %
+ len(colors)]
+ else:
+ color = get_color(ids[j])
+
+ cv2.circle(
+ canvas,
+ tuple(skeletons[j][i, 0:2].astype('int32')),
+ 2,
+ color,
+ thickness=-1)
+
+ to_plot = cv2.addWeighted(image, 0.3, canvas, 0.7, 0)
+ fig = matplotlib.pyplot.gcf()
+
+ stickwidth = 2
+
+ for i in range(NUM_EDGES):
+ for j in range(len(skeletons)):
+ edge = EDGES[i]
+ if skeletons[j][edge[0], 2] < visual_thresh or skeletons[j][edge[
+ 1], 2] < visual_thresh:
+ continue
+
+ cur_canvas = canvas.copy()
+ X = [skeletons[j][edge[0], 1], skeletons[j][edge[1], 1]]
+ Y = [skeletons[j][edge[0], 0], skeletons[j][edge[1], 0]]
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)),
+ (int(length / 2), stickwidth),
+ int(angle), 0, 360, 1)
+ if ids is None:
+ color = colors[i] if color_set is None else colors[color_set[j]
+ %
+ len(colors)]
+ else:
+ color = get_color(ids[j])
+ cv2.fillConvexPoly(cur_canvas, polygon, color)
+ canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
+ return canvas
+
+
+@register
+class KptOutput(OutputBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(KptOutput, self).__init__(model_cfg, env_cfg)
+
+ def __call__(self, inputs):
+ total_res = []
+ for res in inputs:
+ fn, image, keypoints, kpt_scores = res.values()
+ res.pop('input.image')
+ image = draw_kpt(image, keypoints)
+ if self.frame_id != -1:
+                res.update({'frame_id': self.frame_id})
+ logger.info(res)
+ if self.save_img:
+ file_name = os.path.split(fn)[-1]
+ out_path = os.path.join(self.output_dir, file_name)
+ logger.info('Save output image to {}'.format(out_path))
+ cv2.imwrite(out_path, image)
+ if self.save_res or self.return_res:
+ total_res.append(res)
+ if self.save_res:
+ res_file_name = 'kpt_output.json'
+ out_path = os.path.join(self.output_dir, res_file_name)
+ logger.info('Save output result to {}'.format(out_path))
+ with open(out_path, 'w') as f:
+ json.dump(total_res, f)
+ if self.return_res:
+ return total_res
+ return
diff --git a/paddlecv/ppcv/ops/output/ocr.py b/paddlecv/ppcv/ops/output/ocr.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f19bee9c4bb2972e82e34cebd8b908f51771365
--- /dev/null
+++ b/paddlecv/ppcv/ops/output/ocr.py
@@ -0,0 +1,392 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import cv2
+import json
+import random
+import math
+from PIL import Image, ImageDraw, ImageFont
+from .base import OutputBaseOp
+from ppcv.utils.logger import setup_logger
+from ppcv.core.workspace import register
+from ppcv.utils.download import get_font_path
+
+logger = setup_logger('OCROutput')
+
+
+def draw_boxes(img, boxes):
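+    """Draw 4-value boxes as axis-aligned rectangles and longer boxes as
+    closed polygons, on a copy of img."""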
+ boxes = np.array(boxes)
+ img_show = img.copy()
+ for box in boxes.astype(int):
+ if len(box) == 4:
+ x1, y1, x2, y2 = box
+ cv2.rectangle(img_show, (x1, y1), (x2, y2), (0, 0, 255), 2)
+ else:
+ box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
+ cv2.polylines(img_show, [box], True, (0, 0, 255), 2)
+ return img_show
+
+
+@register
+class OCRTableOutput(OutputBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(OCRTableOutput, self).__init__(model_cfg, env_cfg)
+
+ def __call__(self, inputs):
+ total_res = []
+ for input in inputs:
+ fn, image, dt_bboxes, structures, scores = list(input.values())[:5]
+ res = dict(
+ filename=fn,
+ dt_bboxes=dt_bboxes,
+ structures=structures,
+ scores=scores)
+ if 'Matcher.html' in input:
+ res.update(html=input['Matcher.html'])
+ if self.frame_id != -1:
+                res.update({'frame_id': self.frame_id})
+ logger.info(res)
+ if self.save_img:
+ image = draw_boxes(image[:, :, ::-1], dt_bboxes)
+ file_name = os.path.split(fn)[-1]
+ out_path = os.path.join(self.output_dir, file_name)
+ logger.info('Save output image to {}'.format(out_path))
+ cv2.imwrite(out_path, image)
+ if self.save_res or self.return_res:
+ total_res.append(res)
+ if self.save_res:
+ res_file_name = 'output.json'
+ out_path = os.path.join(self.output_dir, res_file_name)
+ logger.info('Save output result to {}'.format(out_path))
+ with open(out_path, 'w') as f:
+ json.dump(total_res, f)
+ if self.return_res:
+ return total_res
+ return
+
+
+@register
+class OCROutput(OutputBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(OCROutput, self).__init__(model_cfg, env_cfg)
+ font_path = model_cfg.get('font_path', None)
+ self.font_path = get_font_path(font_path)
+
+ def __call__(self, inputs):
+ total_res = []
+ for input in inputs:
+ fn, image, dt_polys = list(input.values())[:3]
+ rec_text = input.get('rec.rec_text', None)
+ rec_score = input.get('rec.rec_score', None)
+ res = dict(
+ filename=fn,
+ dt_polys=dt_polys.tolist(),
+ rec_text=rec_text,
+ rec_score=rec_score)
+ if self.frame_id != -1:
+                res.update({'frame_id': self.frame_id})
+ logger.info(res)
+ if self.save_img:
+ image = image[:, :, ::-1]
+ if rec_text is not None:
+ image = self.draw_ocr_box_txt(
+ Image.fromarray(image), dt_polys, rec_text, rec_score)
+ else:
+ image = draw_boxes(image, dt_polys.reshape([-1, 8]))
+ file_name = os.path.split(fn)[-1]
+ out_path = os.path.join(self.output_dir, file_name)
+ logger.info('Save output image to {}'.format(out_path))
+ cv2.imwrite(out_path, image)
+ if self.save_res or self.return_res:
+ total_res.append(res)
+ if self.save_res:
+ res_file_name = 'output.json'
+ out_path = os.path.join(self.output_dir, res_file_name)
+ logger.info('Save output result to {}'.format(out_path))
+ with open(out_path, 'w') as f:
+ json.dump(total_res, f)
+ if self.return_res:
+ return total_res
+ return
+
+ def draw_ocr_box_txt(self,
+ image,
+ boxes,
+ txts=None,
+ scores=None,
+ drop_score=0.5):
+ h, w = image.height, image.width
+ img_left = image.copy()
+ img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
+ random.seed(0)
+
+ draw_left = ImageDraw.Draw(img_left)
+ if txts is None or len(txts) != len(boxes):
+ txts = [None] * len(boxes)
+ for idx, (box, txt) in enumerate(zip(boxes, txts)):
+ if scores is not None and scores[idx] < drop_score:
+ continue
+ color = (random.randint(0, 255), random.randint(0, 255),
+ random.randint(0, 255))
+ draw_left.polygon(box, fill=color)
+ img_right_text = self.draw_box_txt_fine((w, h), box, txt)
+ pts = np.array(box, np.int32).reshape((-1, 1, 2))
+ cv2.polylines(img_right_text, [pts], True, color, 1)
+ img_right = cv2.bitwise_and(img_right, img_right_text)
+ img_left = Image.blend(image, img_left, 0.5)
+ img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
+ img_show.paste(img_left, (0, 0, w, h))
+ img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
+ return np.array(img_show)
+
+ def draw_box_txt_fine(self, img_size, box, txt):
+ box_height = int(
+ math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][1])**2))
+ box_width = int(
+ math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][1])**2))
+
+ if box_height > 2 * box_width and box_height > 30:
+ img_text = Image.new('RGB', (box_height, box_width),
+ (255, 255, 255))
+ draw_text = ImageDraw.Draw(img_text)
+ if txt:
+ font = self.create_font(txt, (box_height, box_width))
+ draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+ img_text = img_text.transpose(Image.ROTATE_270)
+ else:
+ img_text = Image.new('RGB', (box_width, box_height),
+ (255, 255, 255))
+ draw_text = ImageDraw.Draw(img_text)
+ if txt:
+ font = self.create_font(txt, (box_width, box_height))
+ draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+
+ pts1 = np.float32(
+ [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]])
+ pts2 = np.array(box, dtype=np.float32)
+ M = cv2.getPerspectiveTransform(pts1, pts2)
+
+ img_text = np.array(img_text, dtype=np.uint8)
+ img_right_text = cv2.warpPerspective(
+ img_text,
+ M,
+ img_size,
+ flags=cv2.INTER_NEAREST,
+ borderMode=cv2.BORDER_CONSTANT,
+ borderValue=(255, 255, 255))
+ return img_right_text
+
+ def create_font(self, txt, sz):
+ font_size = int(sz[1] * 0.99)
+ font = ImageFont.truetype(self.font_path, font_size, encoding="utf-8")
+ length = font.getsize(txt)[0]
+ if length > sz[0]:
+ font_size = int(font_size * sz[0] / length)
+ font = ImageFont.truetype(
+ self.font_path, font_size, encoding="utf-8")
+ return font
+
+
+@register
+class PPStructureOutput(OutputBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(PPStructureOutput, self).__init__(model_cfg, env_cfg)
+
+ def __call__(self, inputs):
+ total_res = []
+ for res in inputs:
+ image = res.pop(self.input_keys[1])
+ res['concat.dt_polys'] = [
+ x.tolist() for x in res['concat.dt_polys']
+ ]
+ if self.frame_id != -1:
+                res.update({'frame_id': self.frame_id})
+ logger.info(res)
+ if self.save_res or self.return_res:
+ total_res.append(res)
+ if self.save_res:
+ res_file_name = 'output.json'
+ out_path = os.path.join(self.output_dir, res_file_name)
+ logger.info('Save output result to {}'.format(out_path))
+ with open(out_path, 'w') as f:
+ json.dump(total_res, f)
+ if self.return_res:
+ return total_res
+ return
+
+
+@register
+class PPStructureSerOutput(OutputBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(PPStructureSerOutput, self).__init__(model_cfg, env_cfg)
+ font_path = model_cfg.get('font_path', None)
+ self.font_path = get_font_path(font_path)
+
+ def __call__(self, inputs):
+ total_res = []
+ for input in inputs:
+ fn, image = list(input.values())[:2]
+ pred_ids = input.get(self.input_keys[2], None)
+ preds = input.get(self.input_keys[3], None)
+ dt_polys = input.get(self.input_keys[4], None)
+ rec_texts = input.get(self.input_keys[5], None)
+ res = dict(
+ filename=fn,
+ dt_polys=np.array(dt_polys).tolist(),
+ rec_text=rec_texts,
+ preds=preds,
+ pred_ids=pred_ids)
+ if self.frame_id != -1:
+                res.update({'frame_id': self.frame_id})
+ logger.info(res)
+ if self.save_img:
+ image = self.draw_ser_results(image, pred_ids, preds, dt_polys,
+ rec_texts)
+ file_name = os.path.split(fn)[-1]
+ out_path = os.path.join(self.output_dir, file_name)
+ logger.info('Save output image to {}'.format(out_path))
+ cv2.imwrite(out_path, image)
+ if self.save_res or self.return_res:
+ total_res.append(res)
+ if self.save_res:
+ res_file_name = 'output.json'
+ out_path = os.path.join(self.output_dir, res_file_name)
+ logger.info('Save output result to {}'.format(out_path))
+ with open(out_path, 'w') as f:
+ json.dump(total_res, f)
+ if self.return_res:
+ return total_res
+ return
+
+ def draw_ser_results(self, image, pred_ids, preds, dt_polys, rec_texts):
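+        """Overlay SER results on the image: one color per label id, with a
+        '<label>: <text>' caption above each box."""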
+ np.random.seed(2021)
+ color = (np.random.permutation(range(255)),
+ np.random.permutation(range(255)),
+ np.random.permutation(range(255)))
+ color_map = {
+ idx: (color[0][idx], color[1][idx], color[2][idx])
+ for idx in range(1, 255)
+ }
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+ elif isinstance(image, str) and os.path.isfile(image):
+ image = Image.open(image).convert('RGB')
+ img_new = image.copy()
+ draw = ImageDraw.Draw(img_new)
+
+ font = ImageFont.truetype(self.font_path, 14, encoding="utf-8")
+ for pred_id, pred, dt_poly, rec_text in zip(pred_ids, preds, dt_polys,
+ rec_texts):
+ if pred_id not in color_map:
+ continue
+ color = color_map[pred_id]
+ text = "{}: {}".format(pred, rec_text)
+
+ bbox = self.trans_poly_to_bbox(dt_poly)
+ self.draw_box_txt(bbox, text, draw, font, 14, color)
+
+ img_new = Image.blend(image, img_new, 0.7)
+ return np.array(img_new)
+
+ def draw_box_txt(self, bbox, text, draw, font, font_size, color):
+
+ # draw ocr results outline
+ bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
+ draw.rectangle(bbox, fill=color)
+
+ # draw ocr results
+ tw = font.getsize(text)[0]
+ th = font.getsize(text)[1]
+ start_y = max(0, bbox[0][1] - th)
+ draw.rectangle(
+ [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + th)],
+ fill=(0, 0, 255))
+ draw.text(
+ (bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
+
+ def trans_poly_to_bbox(self, poly):
+ x1 = np.min([p[0] for p in poly])
+ x2 = np.max([p[0] for p in poly])
+ y1 = np.min([p[1] for p in poly])
+ y2 = np.max([p[1] for p in poly])
+ return [x1, y1, x2, y2]
+
+
+@register
+class PPStructureReOutput(PPStructureSerOutput):
+ def __init__(self, model_cfg, env_cfg):
+ super(PPStructureReOutput, self).__init__(model_cfg, env_cfg)
+
+ def __call__(self, inputs):
+ total_res = []
+ for input in inputs:
+ fn, image = list(input.values())[:2]
+ heads = input.get(self.input_keys[2], None)
+ tails = input.get(self.input_keys[3], None)
+ res = dict(filename=fn, heads=heads, tails=tails)
+ if self.frame_id != -1:
+                res.update({'frame_id': self.frame_id})
+ logger.info(res)
+ if self.save_img:
+ image = self.draw_re_results(image, heads, tails)
+ file_name = os.path.split(fn)[-1]
+ out_path = os.path.join(self.output_dir, file_name)
+ logger.info('Save output image to {}'.format(out_path))
+ cv2.imwrite(out_path, image)
+ if self.save_res or self.return_res:
+ total_res.append(res)
+ if self.save_res:
+ res_file_name = 'output.json'
+ out_path = os.path.join(self.output_dir, res_file_name)
+ logger.info('Save output result to {}'.format(out_path))
+ with open(out_path, 'w') as f:
+ json.dump(total_res, f)
+ if self.return_res:
+ return total_res
+ return
+
+ def draw_re_results(self, image, heads, tails):
+ font_size = 18
+ np.random.seed(0)
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+ elif isinstance(image, str) and os.path.isfile(image):
+ image = Image.open(image).convert('RGB')
+ img_new = image.copy()
+ draw = ImageDraw.Draw(img_new)
+
+ font = ImageFont.truetype(self.font_path, font_size, encoding="utf-8")
+ color_head = (0, 0, 255)
+ color_tail = (255, 0, 0)
+ color_line = (0, 255, 0)
+
+ for ocr_info_head, ocr_info_tail in zip(heads, tails):
+ head_bbox = self.trans_poly_to_bbox(ocr_info_head["dt_polys"])
+ tail_bbox = self.trans_poly_to_bbox(ocr_info_tail["dt_polys"])
+ self.draw_box_txt(head_bbox, ocr_info_head["rec_text"], draw, font,
+ font_size, color_head)
+ self.draw_box_txt(tail_bbox, ocr_info_tail["rec_text"], draw, font,
+ font_size, color_tail)
+
+ center_head = ((head_bbox[0] + head_bbox[2]) // 2,
+ (head_bbox[1] + head_bbox[3]) // 2)
+ center_tail = ((tail_bbox[0] + tail_bbox[2]) // 2,
+ (tail_bbox[1] + tail_bbox[3]) // 2)
+
+ draw.line([center_head, center_tail], fill=color_line, width=5)
+
+ img_new = Image.blend(image, img_new, 0.5)
+ return np.array(img_new)
diff --git a/paddlecv/ppcv/ops/output/segmentation.py b/paddlecv/ppcv/ops/output/segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..80481ff12e588f67480023d06c16cd39d9ae6294
--- /dev/null
+++ b/paddlecv/ppcv/ops/output/segmentation.py
@@ -0,0 +1,183 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+import glob
+import json
+from collections import defaultdict
+
+import cv2
+import paddle
+import numpy as np
+from PIL import Image
+
+from ppcv.utils.logger import setup_logger
+from ppcv.core.workspace import register
+
+from .base import OutputBaseOp
+
+logger = setup_logger('SegOutput')
+
+
+@register
+class SegOutput(OutputBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super().__init__(model_cfg, env_cfg)
+
+ def __call__(self, inputs):
+ total_res = []
+ for input in inputs:
+ fn, _, seg_map = input.values()
+ res = dict(filename=fn, seg_map=seg_map.tolist())
+ if self.save_res or self.return_res:
+ total_res.append(res)
+
+ if self.save_img:
+ seg_map = get_pseudo_color_map(seg_map)
+ file_name = os.path.split(fn)[-1]
+ out_path = os.path.join(self.output_dir, file_name)
+ seg_map.save(out_path)
+ logger.info('Save output image to {}'.format(out_path))
+
+ if self.save_res:
+ res_file_name = 'seg_output.json'
+ out_path = os.path.join(self.output_dir, res_file_name)
+ with open(out_path, 'w') as f:
+ json.dump(total_res, f)
+ logger.info('Save output result to {}'.format(out_path))
+
+ if self.return_res:
+ return total_res
+
+
+@register
+class HumanSegOutput(OutputBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super().__init__(model_cfg, env_cfg)
+
+ def __call__(self, inputs):
+ total_res = []
+ for input in inputs:
+ fn, img, seg_map = input.values()
+ res = dict(filename=fn, seg_map=seg_map.tolist())
+ if self.save_res or self.return_res:
+ total_res.append(res)
+
+ if self.save_img:
+ alpha = seg_map[1]
+ alpha = cv2.resize(alpha, (img.shape[1], img.shape[0]))
+ alpha = (alpha * 255).astype('uint8')
+ img = img[:, :, ::-1]
+ res_img = np.concatenate(
+ [img, alpha[:, :, np.newaxis]], axis=-1)
+
+                filename = os.path.splitext(os.path.basename(fn))[0]
+ out_path = os.path.join(self.output_dir, filename + ".png")
+ cv2.imwrite(out_path, res_img)
+ logger.info('Save output image to {}'.format(out_path))
+
+ if self.save_res:
+ res_file_name = 'humanseg_output.json'
+ out_path = os.path.join(self.output_dir, res_file_name)
+ with open(out_path, 'w') as f:
+ json.dump(total_res, f)
+ logger.info('Save output result to {}'.format(out_path))
+
+ if self.return_res:
+ return total_res
+
+
+@register
+class MattingOutput(OutputBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super().__init__(model_cfg, env_cfg)
+
+ def __call__(self, inputs):
+ total_res = []
+ for input in inputs:
+ fn, img, seg_map = input.values()
+ res = dict(filename=fn, seg_map=seg_map.tolist())
+ if self.save_res or self.return_res:
+ total_res.append(res)
+
+ if self.save_img:
+ alpha = seg_map.squeeze()
+ alpha = cv2.resize(alpha, (img.shape[1], img.shape[0]))
+ alpha = (alpha * 255).astype('uint8')
+
+                filename = os.path.splitext(os.path.basename(fn))[0]
+ out_path = os.path.join(self.output_dir, filename + ".png")
+ cv2.imwrite(out_path, alpha)
+ logger.info('Save output image to {}'.format(out_path))
+
+ if self.save_res:
+ res_file_name = 'matting_output.json'
+ out_path = os.path.join(self.output_dir, res_file_name)
+ with open(out_path, 'w') as f:
+ json.dump(total_res, f)
+ logger.info('Save output result to {}'.format(out_path))
+
+ if self.return_res:
+ return total_res
+
+
+def get_pseudo_color_map(pred, color_map=None):
+ """
+ Get the pseudo color image.
+
+ Args:
+        pred (numpy.ndarray): the predicted segmentation label map.
+ color_map (list, optional): the palette color map. Default: None,
+ use paddleseg's default color map.
+
+    Returns:
+        (PIL.Image.Image): the pseudo color image.
+ """
+ pred_mask = Image.fromarray(pred.astype(np.uint8), mode='P')
+ if color_map is None:
+ color_map = get_color_map_list(256)
+ pred_mask.putpalette(color_map)
+ return pred_mask
+
+
+def get_color_map_list(num_classes, custom_color=None):
+ """
+ Returns the color map for visualizing the segmentation mask,
+ which can support arbitrary number of classes.
+
+ Args:
+ num_classes (int): Number of classes.
+ custom_color (list, optional): Save images with a custom color map. Default: None, use paddleseg's default color map.
+
+ Returns:
+ (list). The color map.
+ """
+
+ num_classes += 1
+ color_map = num_classes * [0, 0, 0]
+ for i in range(0, num_classes):
+ j = 0
+ lab = i
+ while lab:
+ color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
+ color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
+ color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
+ j += 1
+ lab >>= 3
+ color_map = color_map[3:]
+
+ if custom_color:
+ color_map[:len(custom_color)] = custom_color
+ return color_map
diff --git a/paddlecv/ppcv/ops/output/tracker.py b/paddlecv/ppcv/ops/output/tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..20334eb56d9a23cf9f9364fd0ae9156eb2666f73
--- /dev/null
+++ b/paddlecv/ppcv/ops/output/tracker.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import math
+import glob
+import paddle
+import cv2
+import json
+from collections import defaultdict
+from .base import OutputBaseOp
+from .detection import draw_det
+from ppcv.utils.logger import setup_logger
+from ppcv.core.workspace import register
+from PIL import Image, ImageDraw, ImageFile
+
+logger = setup_logger('TrackerOutput')
+
+
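+# Each line written for the 'mot'/'mcmot' data types follows the MOT
+# Challenge text format: frame,id,x1,y1,w,h,score,cls_id,-1,-1 (a line
+# might look like "1,3,604.0,245.0,55.0,129.0,0.92,0,-1,-1").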
+def write_mot_results(filename, results, data_type='mot', num_classes=1):
+ # support single and multi classes
+ if data_type in ['mot', 'mcmot']:
+ save_format = '{frame},{id},{x1},{y1},{w},{h},{score},{cls_id},-1,-1\n'
+ elif data_type == 'kitti':
+ save_format = '{frame} {id} car 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
+ else:
+ raise ValueError(data_type)
+
+ frame_id, tk_bboxes, tk_scores, tk_ids, tk_cls_ids = results
+ frame_id = -1 if data_type == 'kitti' else frame_id
+ with open(filename, 'w') as f:
+ for bbox, score, tk_id, cls_id in zip(tk_bboxes, tk_scores, tk_ids,
+ tk_cls_ids):
+ if tk_id < 0: continue
+ if data_type == 'mot':
+ cls_id = -1
+
+ x1, y1, x2, y2 = bbox
+ w, h = x2 - x1, y2 - y1
+ line = save_format.format(
+ frame=frame_id,
+ id=tk_id,
+ x1=x1,
+ y1=y1,
+ x2=x2,
+ y2=y2,
+ w=w,
+ h=h,
+ score=score,
+ cls_id=cls_id)
+ f.write(line)
+
+
+@register
+class TrackerOutput(OutputBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(TrackerOutput, self).__init__(model_cfg, env_cfg)
+
+ def __call__(self, inputs):
+ total_res = []
+ vis_images = []
+ for res in inputs:
+ fn, image, tk_bboxes, tk_scores, tk_ids, tk_cls_ids, tk_cls_names = res.values(
+ )
+ tk_names = [
+ '{} {}'.format(tk_cls_name, tk_id)
+ for tk_id, tk_cls_name in zip(tk_ids, tk_cls_names)
+ ]
+ image = draw_det(image, tk_bboxes, tk_scores, tk_names, tk_ids)
+ res.pop('input.image')
+ if self.frame_id != -1:
+ res.update({'frame_id': self.frame_id})
+ logger.info(res)
+ if self.save_img:
+ vis_images.append(image)
+ if self.save_res or self.return_res:
+ total_res.append(res)
+ if self.save_res:
+            video_name = os.path.splitext(os.path.basename(fn))[0]
+ output_dir = os.path.join(self.output_dir, video_name)
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir, exist_ok=True)
+ out_path = os.path.join(output_dir, '{}.txt'.format(self.frame_id))
+ logger.info('Save output result to {}'.format(out_path))
+ write_mot_results(
+ out_path,
+ [self.frame_id, tk_bboxes, tk_scores, tk_ids, tk_cls_ids])
+ if self.return_res:
+ if vis_images:
+ for i, vis_im in enumerate(vis_images):
+ total_res[i].update({'output': vis_im})
+ return total_res
+ return
diff --git a/paddlecv/ppcv/ops/predictor.py b/paddlecv/ppcv/ops/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cb61a06445ecdb4e8456755a78cf7d67cbaabbd
--- /dev/null
+++ b/paddlecv/ppcv/ops/predictor.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import math
+import paddle
+from paddle.inference import Config
+from paddle.inference import create_predictor
+
+
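+# A minimal usage sketch (the model/param paths and input shape below are
+# hypothetical; any exported Paddle inference model works):
+#
+#   cfg = {'batch_size': 1, 'min_subgraph_size': 3,
+#          'shape_info_filename': None, 'trt_calib_mode': False,
+#          'cpu_threads': 1, 'trt_use_static': False}
+#   predictor = PaddlePredictor('model.pdiparams', 'model.pdmodel', cfg)
+#   outputs = predictor.run(np.ones([1, 3, 640, 640], dtype=np.float32))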
+class PaddlePredictor(object):
+ def __init__(self, param_path, model_path, config, delete_pass=[]):
+ super().__init__()
+ self.predictor, self.inference_config, self.input_names, self.input_tensors, self.output_tensors = self.create_paddle_predictor(
+ param_path,
+ model_path,
+ batch_size=config['batch_size'],
+ run_mode=config.get("run_mode", "paddle"), # used trt or mkldnn
+ device=config.get("device", "CPU"),
+ min_subgraph_size=config["min_subgraph_size"],
+ shape_info_filename=config["shape_info_filename"],
+ trt_calib_mode=config["trt_calib_mode"],
+ cpu_threads=config["cpu_threads"],
+ trt_use_static=config["trt_use_static"],
+ delete_pass=delete_pass)
+
+ def create_paddle_predictor(self,
+ param_path,
+ model_path,
+ batch_size=1,
+ run_mode='paddle',
+ device='CPU',
+ min_subgraph_size=3,
+ shape_info_filename=None,
+ trt_calib_mode=False,
+ cpu_threads=6,
+ trt_use_static=False,
+ delete_pass=[]):
+ if not os.path.exists(model_path) or not os.path.exists(param_path):
+ raise ValueError(
+ f"inference model: {model_path} or param: {param_path} does not exist, please check again..."
+ )
+ assert run_mode in [
+ "paddle", "trt_fp32", "trt_fp16", "trt_int8", "mkldnn",
+ "mkldnn_bf16"
+ ], "The run_mode must be 'paddle', 'trt_fp32', 'trt_fp16', 'trt_int8', 'mkldnn', 'mkldnn_bf16', but received run_mode: {}".format(
+ run_mode)
+ config = Config(model_path, param_path)
+ if device == 'GPU':
+ config.enable_use_gpu(200, 0)
+ else:
+ config.disable_gpu()
+        if 'mkldnn' in run_mode:
+            try:
+                config.enable_mkldnn()
+                config.set_cpu_math_library_num_threads(cpu_threads)
+                if 'bf16' in run_mode:
+                    config.enable_mkldnn_bfloat16()
+            except Exception:
+                print(
+                    "The current environment does not support `mkldnn`, so MKL-DNN is disabled."
+                )
+
+ precision_map = {
+ 'trt_int8': Config.Precision.Int8,
+ 'trt_fp32': Config.Precision.Float32,
+ 'trt_fp16': Config.Precision.Half
+ }
+ if run_mode in precision_map.keys():
+ config.enable_tensorrt_engine(
+ workspace_size=(1 << 25) * batch_size,
+ max_batch_size=batch_size,
+ min_subgraph_size=min_subgraph_size,
+ precision_mode=precision_map[run_mode],
+ trt_use_static=trt_use_static,
+ use_calib_mode=trt_calib_mode)
+
+ if shape_info_filename is not None:
+ if not os.path.exists(shape_info_filename):
+ config.collect_shape_range_info(shape_info_filename)
+                    print(
+                        f"collecting dynamic shape info into: {shape_info_filename}"
+                    )
+                else:
+                    print(
+                        f"dynamic shape info file ({shape_info_filename}) already exists, no need to generate it again."
+                    )
+ config.enable_tuned_tensorrt_dynamic_shape(shape_info_filename,
+ True)
+
+ # disable print log when predict
+ config.disable_glog_info()
+ for del_p in delete_pass:
+ config.delete_pass(del_p)
+ # enable shared memory
+ config.enable_memory_optim()
+ config.switch_ir_optim(True)
+ # disable feed, fetch OP, needed by zero_copy_run
+ config.switch_use_feed_fetch_ops(False)
+ predictor = create_predictor(config)
+
+ # get input and output tensor property
+ input_names = predictor.get_input_names()
+ input_tensors = []
+ output_tensors = []
+ for input_name in input_names:
+ input_tensor = predictor.get_input_handle(input_name)
+ input_tensors.append(input_tensor)
+ output_names = predictor.get_output_names()
+ for output_name in output_names:
+ output_tensor = predictor.get_output_handle(output_name)
+ output_tensors.append(output_tensor)
+ return predictor, config, input_names, input_tensors, output_tensors
+
+ def get_input_names(self):
+ return self.input_names
+
+ def run(self, x):
+ if not isinstance(x, (list, tuple)):
+ x = [x]
+
+        for idx in range(len(x)):
+            self.input_tensors[idx].copy_from_cpu(x[idx])
+        # run inference once; outputs are fetched through zero-copy handles
+        self.predictor.run()
+
+        result = []
+        output_names = self.predictor.get_output_names()
+        for name in output_names:
+            output = self.predictor.get_output_handle(name).copy_to_cpu()
+            result.append(output)
+        return result
diff --git a/paddlecv/ppcv/utils/__init__.py b/paddlecv/ppcv/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..97043fd7ba6885aac81cad5a49924c23c67d4d47
--- /dev/null
+++ b/paddlecv/ppcv/utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlecv/ppcv/utils/download.py b/paddlecv/ppcv/utils/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8a1ceb885c71331eee66946519bae39ed149261
--- /dev/null
+++ b/paddlecv/ppcv/utils/download.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import sys
+import yaml
+import time
+import shutil
+import requests
+import tqdm
+import hashlib
+import base64
+import binascii
+import tarfile
+import zipfile
+
+from .logger import setup_logger
+
+logger = setup_logger(__name__)
+
+__all__ = [
+ 'get_model_path',
+ 'get_config_path',
+ 'get_dict_path',
+]
+
+WEIGHTS_HOME = osp.expanduser("~/.cache/paddlecv/models")
+CONFIGS_HOME = osp.expanduser("~/.cache/paddlecv/configs")
+DICTS_HOME = osp.expanduser("~/.cache/paddlecv/dicts")
+FONTS_HOME = osp.expanduser("~/.cache/paddlecv/fonts")
+
+DOWNLOAD_RETRY_LIMIT = 3
+
+PMP_DOWNLOAD_URL_PREFIX = 'https://bj.bcebos.com/v1/paddle-model-ecology/paddlecv/'
+
+
+def is_url(path):
+ """
+ Whether path is URL.
+ Args:
+ path (string): URL string or not.
+ """
+ return path.startswith('http://') \
+ or path.startswith('https://') \
+ or path.startswith('paddlecv://')
+
+
+def parse_url(url):
+ url = url.replace("paddlecv://", PMP_DOWNLOAD_URL_PREFIX)
+ return url
+
+
+def get_model_path(path):
+ """Get model path from WEIGHTS_HOME, if not exists,
+ download it from url.
+ """
+ if not is_url(path):
+ return path
+ url = parse_url(path)
+ path, _ = get_path(url, WEIGHTS_HOME, path_depth=2)
+ logger.info("The model path is {}".format(path))
+ return path
+
+
+def get_config_path(path):
+ """Get config path from CONFIGS_HOME, if not exists,
+ download it from url.
+ """
+ if not is_url(path):
+ return path
+ url = parse_url(path)
+ path, _ = get_path(url, CONFIGS_HOME)
+ logger.info("The config path is {}".format(path))
+ return path
+
+
+def get_dict_path(path):
+ """Get dict path from DICTS_HOME, if not exists,
+ download it from url.
+ """
+ if not is_url(path):
+ return path
+ url = parse_url(path)
+ path, _ = get_path(url, DICTS_HOME)
+ logger.info("The dict path is {}".format(path))
+ return path
+
+
+def get_font_path(path):
+ """Get config path from CONFIGS_HOME, if not exists,
+ download it from url.
+ """
+ if not is_url(path):
+ return path
+ url = parse_url(path)
+ path, _ = get_path(url, FONTS_HOME)
+ return path
+
+
+def map_path(url, root_dir, path_depth=1):
+ # parse path after download to decompress under root_dir
+ assert path_depth > 0, "path_depth should be a positive integer"
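+    # e.g. url='https://host/a/b/c.tar' with path_depth=1 keeps 'c.tar',
+    # while path_depth=2 keeps 'b/c.tar' under root_dir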
+ dirname = url
+ for _ in range(path_depth):
+ dirname = osp.dirname(dirname)
+ fpath = osp.relpath(url, dirname)
+ path = osp.join(root_dir, fpath)
+ dirname = osp.dirname(path)
+ return path, dirname
+
+
+def get_path(url, root_dir, md5sum=None, check_exist=True, path_depth=1):
+ """ Download from given url to root_dir.
+ if file or directory specified by url is exists under
+ root_dir, return the path directly, otherwise download
+ from url, return the path.
+ url (str): download url
+ root_dir (str): root dir for downloading, it should be
+ WEIGHTS_HOME
+ md5sum (str): md5 sum of download package
+ """
+ # parse path after download to decompress under root_dir
+ fullpath, dirname = map_path(url, root_dir, path_depth)
+
+ if osp.exists(fullpath) and check_exist:
+ if not osp.isfile(fullpath) or \
+ _check_exist_file_md5(fullpath, md5sum, url):
+ logger.debug("Found {}".format(fullpath))
+ return fullpath, True
+ else:
+ os.remove(fullpath)
+
+ fullname = _download(url, dirname, md5sum)
+ return fullpath, False
+
+
+def _download(url, path, md5sum=None):
+ """
+ Download from url, save to path.
+ url (str): download url
+ path (str): download to given path
+ """
+ if not osp.exists(path):
+ os.makedirs(path)
+
+ fname = osp.split(url)[-1]
+ fullname = osp.join(path, fname)
+ retry_cnt = 0
+
+ while not (osp.exists(fullname) and _check_exist_file_md5(fullname, md5sum,
+ url)):
+ if retry_cnt < DOWNLOAD_RETRY_LIMIT:
+ retry_cnt += 1
+ else:
+ raise RuntimeError("Download from {} failed. "
+ "Retry limit reached".format(url))
+
+ logger.info("Downloading {} from {}".format(fname, url))
+
+ # NOTE: windows path join may incur \, which is invalid in url
+ if sys.platform == "win32":
+ url = url.replace('\\', '/')
+
+ req = requests.get(url, stream=True)
+ if req.status_code != 200:
+ raise RuntimeError("Downloading from {} failed with code "
+ "{}!".format(url, req.status_code))
+
+        # NOTE: to protect against interrupted downloads, download to
+        # tmp_fullname first, then move tmp_fullname to fullname
+        # once the download has finished
+ tmp_fullname = fullname + "_tmp"
+ total_size = req.headers.get('content-length')
+ with open(tmp_fullname, 'wb') as f:
+ if total_size:
+ for chunk in tqdm.tqdm(
+ req.iter_content(chunk_size=1024),
+ total=(int(total_size) + 1023) // 1024,
+ unit='KB'):
+ f.write(chunk)
+ else:
+ for chunk in req.iter_content(chunk_size=1024):
+ if chunk:
+ f.write(chunk)
+ shutil.move(tmp_fullname, fullname)
+ return fullname
+
+
+def _check_exist_file_md5(filename, md5sum, url):
+    # if md5sum is None and the file to check is a model file,
+    # read the md5sum from the url and check it; else check md5sum directly
+ return _md5check_from_url(filename, url) if md5sum is None \
+ and filename.endswith('pdparams') \
+ else _md5check(filename, md5sum)
+
+
+def _md5check_from_url(filename, url):
+ # For model in bcebos URLs, MD5 value is contained
+ # in request header as 'content_md5'
+ req = requests.get(url, stream=True)
+ content_md5 = req.headers.get('content-md5')
+ req.close()
+ if not content_md5 or _md5check(
+ filename,
+ binascii.hexlify(base64.b64decode(content_md5.strip('"'))).decode(
+ )):
+ return True
+ else:
+ return False
+
+
+def _md5check(fullname, md5sum=None):
+ if md5sum is None:
+ return True
+
+ logger.debug("File {} md5 checking...".format(fullname))
+ md5 = hashlib.md5()
+ with open(fullname, 'rb') as f:
+ for chunk in iter(lambda: f.read(4096), b""):
+ md5.update(chunk)
+ calc_md5sum = md5.hexdigest()
+
+ if calc_md5sum != md5sum:
+ logger.warning("File {} md5 check failed, {}(calc) != "
+ "{}(base)".format(fullname, calc_md5sum, md5sum))
+ return False
+ return True
diff --git a/paddlecv/ppcv/utils/helper.py b/paddlecv/ppcv/utils/helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec947e3f55b629b355f41fe56eb2d30dc27a5144
--- /dev/null
+++ b/paddlecv/ppcv/utils/helper.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import numpy as np
+import math
+import glob
+
+import ppcv
+from ppcv.ops import *
+from ppcv.core.workspace import get_global_op
+
+
+def get_output_keys(cfg=None):
+ op_list = get_global_op()
+ if cfg is None:
+ output = dict()
+ for name, op in op_list.items():
+ if op.type() != 'OUTPUT':
+ keys = op.get_output_keys()
+ output.update({name: keys})
+ else:
+ output = {'input.image', 'input.video'}
+ for op in cfg:
+ op_arch = op_list[list(op.keys())[0]]
+ op_cfg = list(op.values())[0]
+ if op_arch.type() == 'OUTPUT': continue
+ for out_name in op_arch.get_output_keys():
+ name = op_cfg['name'] + '.' + out_name
+ output.add(name)
+ return output
+
+
+def gen_input_name(input_keys, last_ops, output_keys):
+ # generate input name according to input_keys and last_ops
+ # the name format is {last_ops}.{input_key}
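+    # e.g. input_keys=['image'] with last_ops=['det'] resolves to
+    # ['det.image'], provided 'det.image' appears in output_keys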
+ input_name = list()
+ for key in input_keys:
+ found = False
+ if key in output_keys:
+ found = True
+ input_name.append(key)
+ else:
+ for op in last_ops:
+ name = op + '.' + key
+ if name in input_name:
+ raise ValueError("Repeat input: {}".format(name))
+ if name in output_keys:
+ input_name.append(name)
+ found = True
+ break
+ if not found:
+ raise ValueError(
+ "Input: {} could not be found from the last ops: {}. The outputs of these last ops are {}".
+ format(key, last_ops, output_keys))
+ return input_name
diff --git a/paddlecv/ppcv/utils/logger.py b/paddlecv/ppcv/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..50b66265209b27fe18deb8d7688e3632c6ba99a0
--- /dev/null
+++ b/paddlecv/ppcv/utils/logger.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+
+import paddle.distributed as dist
+
+__all__ = ['setup_logger']
+
+logger_initialized = []
+
+
+def setup_logger(name="ppcv", output=None):
+ """
+ Initialize logger and set its verbosity level to INFO.
+ Args:
+ name (str): the root module name of this logger
+ output (str): a file name or a directory to save log. If None, will not save log file.
+ If ends with ".txt" or ".log", assumed to be a file name.
+ Otherwise, logs will be saved to `output/log.txt`.
+
+ Returns:
+ logging.Logger: a logger
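+
+    Example:
+        logger = setup_logger('my_module', output='./logs')
+        logger.info('ready')  # stdout (rank 0) plus ./logs/log.txt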
+ """
+ logger = logging.getLogger(name)
+ if name in logger_initialized:
+ return logger
+
+ logger.setLevel(logging.INFO)
+ logger.propagate = False
+
+ formatter = logging.Formatter(
+ "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
+ datefmt="%m/%d %H:%M:%S")
+ # stdout logging: master only
+ local_rank = dist.get_rank()
+ if local_rank == 0:
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.DEBUG)
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ # file logging: all workers
+ if output is not None:
+ if output.endswith(".txt") or output.endswith(".log"):
+ filename = output
+ else:
+ filename = os.path.join(output, "log.txt")
+ if local_rank > 0:
+ filename = filename + ".rank{}".format(local_rank)
+        log_dir = os.path.dirname(filename)
+        if log_dir:
+            os.makedirs(log_dir, exist_ok=True)
+        fh = logging.FileHandler(filename, mode='a')
+        fh.setLevel(logging.DEBUG)
+        fh.setFormatter(formatter)
+ logger.addHandler(fh)
+ logger_initialized.append(name)
+ return logger
diff --git a/paddlecv/ppcv/utils/timer.py b/paddlecv/ppcv/utils/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5be5b983c0f8a8ce886180d5dcbe887da89c7df
--- /dev/null
+++ b/paddlecv/ppcv/utils/timer.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+import os
+import ast
+import glob
+import yaml
+import copy
+import numpy as np
+
+
+class Times(object):
+ def __init__(self):
+ self.time = 0.
+ # start time
+ self.st = 0.
+ # end time
+ self.et = 0.
+
+ def start(self):
+ self.st = time.time()
+
+ def end(self, repeats=1, accumulative=True):
+ self.et = time.time()
+ if accumulative:
+ self.time += (self.et - self.st) / repeats
+ else:
+ self.time = (self.et - self.st) / repeats
+
+ def reset(self):
+ self.time = 0.
+ self.st = 0.
+ self.et = 0.
+
+ def value(self):
+ return round(self.time, 4)
+
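+# A minimal usage sketch for Times (do_work() stands in for any workload):
+#
+#   t = Times()
+#   t.start()
+#   do_work()
+#   t.end(accumulative=True)  # adds the elapsed seconds to t.time
+#   print(t.value())          # accumulated time, rounded to 4 decimals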
+
+class PipeTimer(Times):
+ def __init__(self, cfg):
+ super(PipeTimer, self).__init__()
+ self.total_time = Times()
+ self.module_time = dict()
+        for op in cfg:
+            # each op entry is a single-item dict of {op_class: op_cfg}
+            op_name = list(op.values())[0]['name']
+            self.module_time.update({op_name: Times()})
+
+ self.img_num = 0
+
+ def get_total_time(self):
+ total_time = self.total_time.value()
+ average_latency = total_time / max(1, self.img_num)
+ qps = 0
+ if total_time > 0:
+ qps = 1 / average_latency
+ return total_time, average_latency, qps
+
+ def info(self):
+ total_time, average_latency, qps = self.get_total_time()
+ print("------------------ Inference Time Info ----------------------")
+ print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
+ self.img_num))
+
+ for k, v in self.module_time.items():
+ v_time = round(v.value(), 4)
+ if v_time > 0:
+ print("{} time(ms): {}; per frame average time(ms): {}".format(
+ k, v_time * 1000, v_time * 1000 / self.img_num))
+ print("average latency time(ms): {:.2f}, QPS: {:2f}".format(
+ average_latency * 1000, qps))
+ return qps
+
+ def report(self, average=False):
+ dic = {}
+        for m, t in self.module_time.items():
+            dic[m] = round(t.value() / max(1, self.img_num),
+                           4) if average else t.value()
+ dic['total'] = round(self.total_time.value() / max(1, self.img_num),
+ 4) if average else self.total_time.value()
+ dic['img_num'] = self.img_num
+ return dic
diff --git a/paddlecv/requirements.txt b/paddlecv/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a8c8fcd9379815a432c462acc151c633bb67e995
--- /dev/null
+++ b/paddlecv/requirements.txt
@@ -0,0 +1,6 @@
+scipy>=1.0.0
+opencv-python
+opencv-contrib-python
+PyYAML>=5.1
+Pillow
+faiss-cpu==1.7.1.post2
diff --git a/paddlecv/scripts/build_wheel.sh b/paddlecv/scripts/build_wheel.sh
new file mode 100644
index 0000000000000000000000000000000000000000..66725834a7e1ca55a6a34280bc20a851f5460238
--- /dev/null
+++ b/paddlecv/scripts/build_wheel.sh
@@ -0,0 +1,155 @@
+#!/usr/bin/env bash
+
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+#=================================================
+# Utils
+#=================================================
+
+
+# directory config
+DIST_DIR="dist"
+BUILD_DIR="build"
+EGG_DIR="paddlecv.egg-info"
+
+CFG_DIR="configs"
+TEST_DIR="unittests"
+DATA_DIR="demo"
+
+# command line log config
+RED='\033[0;31m'
+BLUE='\033[0;34m'
+GREEN='\033[1;32m'
+BOLD='\033[1m'
+NONE='\033[0m'
+
+function python_version_check() {
+ PY_MAIN_VERSION=`python -V 2>&1 | awk '{print $2}' | awk -F '.' '{print $1}'`
+ PY_SUB_VERSION=`python -V 2>&1 | awk '{print $2}' | awk -F '.' '{print $2}'`
+ echo -e "find python version ${PY_MAIN_VERSION}.${PY_SUB_VERSION}"
+ if [ $PY_MAIN_VERSION -ne "3" -o $PY_SUB_VERSION -lt "5" ]; then
+ echo -e "${RED}FAIL:${NONE} please use Python >= 3.5 !"
+ exit 1
+ fi
+}
+
+function init() {
+ echo -e "${BLUE}[init]${NONE} removing building directory..."
+ rm -rf $DIST_DIR $BUILD_DIR $EGG_DIR $TEST_DIR
+    if [ `pip list | grep paddlecv | wc -l` -gt 0 ]; then
+        echo -e "${BLUE}[init]${NONE} uninstalling paddlecv..."
+        pip uninstall -y paddlecv
+    fi
+ echo -e "${BLUE}[init]${NONE} ${GREEN}init success\n"
+}
+
+function build_and_install() {
+ echo -e "${BLUE}[build]${NONE} building paddlecv wheel..."
+ python setup.py sdist bdist_wheel
+ if [ $? -ne 0 ]; then
+ echo -e "${RED}[FAIL]${NONE} build paddlecv wheel failed !"
+ exit 1
+ fi
+ echo -e "${BLUE}[build]${NONE} ${GREEN}build paddlecv wheel success\n"
+
+ echo -e "${BLUE}[install]${NONE} installing paddlecv..."
+ cd $DIST_DIR
+ find . -name "paddlecv*.whl" | xargs pip install
+ if [ $? -ne 0 ]; then
+ cd ..
+ echo -e "${RED}[FAIL]${NONE} install paddlecv wheel failed !"
+ exit 1
+ fi
+ echo -e "${BLUE}[install]${NONE} ${GREEN}paddlecv install success\n"
+ cd ..
+}
+
+function unittest() {
+ if [ -d $TEST_DIR ]; then
+ rm -rf $TEST_DIR
+ fi;
+
+ echo -e "${BLUE}[unittest]${NONE} run unittests..."
+
+ # NOTE: perform unittests under TEST_DIR to
+ # make sure installed paddlecv is used
+ mkdir $TEST_DIR
+ cp -r $CFG_DIR $TEST_DIR
+ cp -r $DATA_DIR $TEST_DIR
+ cd $TEST_DIR
+ if [ $? != 0 ]; then
+ exit 1
+ fi
+ find "../" -wholename '*tests/test_*' -type f -print0 | \
+ xargs -0 -I{} -n1 -t bash -c 'python -u -s {}'
+
+ # clean TEST_DIR
+ cd ..
+ rm -rf $TEST_DIR
+ echo -e "${BLUE}[unittest]${NONE} ${GREEN}unittests success\n${NONE}"
+}
+
+function cleanup() {
+ if [ -d $TEST_DIR ]; then
+ rm -rf $TEST_DIR
+ fi
+
+ rm -rf $BUILD_DIR $EGG_DIR
+ pip uninstall -y paddlecv
+}
+
+function abort() {
+ echo -e "${RED}[FAIL]${NONE} build wheel and unittest failed !
+ please check your code" 1>&2
+
+ cur_dir=`basename "$pwd"`
+ if [ cur_dir==$TEST_DIR -o cur_dir==$DIST_DIR ]; then
+ cd ..
+ fi
+
+ rm -rf $BUILD_DIR $EGG_DIR $DIST_DIR $TEST_DIR
+ pip uninstall -y paddlecv
+}
+
+python_version_check
+
+trap 'abort' 0
+set -e
+
+init
+build_and_install
+unittest
+cleanup
+
+# get Paddle version
+PADDLE_VERSION=`python -c "import paddle; print(paddle.version.full_version)"`
+PADDLE_COMMIT=`python -c "import paddle; print(paddle.version.commit)"`
+PADDLE_COMMIT=`git rev-parse --short $PADDLE_COMMIT`
+
+# get PaddleCV branch
+PPCV_BRANCH=`git rev-parse --abbrev-ref HEAD`
+PPCV_COMMIT=`git rev-parse --short HEAD`
+
+# get Python version
+PYTHON_VERSION=`python -c "import platform; print(platform.python_version())"`
+
+echo -e "\n${GREEN}paddlecv wheel compiled and checked success !${NONE}
+ ${BLUE}Python version:${NONE} $PYTHON_VERSION
+ ${BLUE}Paddle version:${NONE} $PADDLE_VERSION ($PADDLE_COMMIT)
+ ${BLUE}PaddleCV branch:${NONE} $PPCV_BRANCH ($PPCV_COMMIT)\n"
+
+echo -e "${GREEN}wheel saved under${NONE} ${RED}${BOLD}./dist"
+
+trap : 0
diff --git a/paddlecv/setup.py b/paddlecv/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..450611e98745ff861ee5ec6df8e4349855370ba1
--- /dev/null
+++ b/paddlecv/setup.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import glob
+import setuptools
+from setuptools import setup, find_packages
+
+VERSION = '0.1.0'
+
+with open('requirements.txt', encoding="utf-8-sig") as f:
+ requirements = f.readlines()
+
+
+def readme():
+ with open('./README.md', encoding="utf-8-sig") as f:
+ README = f.read()
+ return README
+
+
+def get_package_data_files(package, data, package_dir=None):
+ """
+ Helps to list all specified files in package including files in directories
+ since `package_data` ignores directories.
+ """
+ if package_dir is None:
+ package_dir = os.path.join(*package.split('.'))
+ all_files = []
+ for f in data:
+ path = os.path.join(package_dir, f)
+ if os.path.isfile(path):
+ all_files.append(f)
+ continue
+ for root, _dirs, files in os.walk(path, followlinks=True):
+ root = os.path.relpath(root, package_dir)
+ for file in files:
+ file = os.path.join(root, file)
+ if file not in all_files:
+ all_files.append(file)
+ return all_files
+
+
+def get_package_model_zoo():
+ cur_dir = osp.dirname(osp.realpath(__file__))
+ cfg_dir = osp.join(cur_dir, "configs")
+ cfgs = glob.glob(osp.join(cfg_dir, '*/*.yml'))
+
+ valid_cfgs = []
+ for cfg in cfgs:
+        # exclude unittest configs
+ if osp.split(osp.split(cfg)[0])[1] not in ['unittest']:
+ valid_cfgs.append(cfg)
+ model_names = [
+ osp.relpath(cfg, cfg_dir).replace(".yml", "") for cfg in valid_cfgs
+ ]
+
+ model_zoo_file = osp.join(cur_dir, 'ppcv', 'model_zoo', 'MODEL_ZOO')
+ with open(model_zoo_file, 'w') as wf:
+ for model_name in model_names:
+ wf.write("{}\n".format(model_name))
+
+ return [model_zoo_file]
+
+
+setup(
+ name='paddlecv',
+ packages=['paddlecv'],
+ package_dir={'paddlecv': ''},
+ package_data={
+ 'configs': get_package_data_files('configs', ['unittest', ]),
+ 'ppcv.model_zoo': get_package_model_zoo(),
+ },
+ include_package_data=True,
+ version=VERSION,
+ install_requires=requirements,
+ license='Apache License 2.0',
+ description='A tool for building model pipeline powered by PaddlePaddle.',
+ long_description=readme(),
+ long_description_content_type='text/markdown',
+ url='https://github.com/PaddlePaddle/models',
+ download_url='https://github.com/PaddlePaddle/models.git',
+ keywords=['paddle-model-pipeline', 'PP-OCR', 'PP-ShiTu', 'PP-Human'],
+ classifiers=[
+ 'Intended Audience :: Developers',
+ 'Operating System :: OS Independent',
+ 'Natural Language :: Chinese (Simplified)',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.6',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ 'Programming Language :: Python :: 3.9',
+ 'Programming Language :: Python :: 3.10',
+ 'Topic :: Utilities',
+ ], )
diff --git a/paddlecv/tests/__init__.py b/paddlecv/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..97043fd7ba6885aac81cad5a49924c23c67d4d47
--- /dev/null
+++ b/paddlecv/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlecv/tests/test_classification.py b/paddlecv/tests/test_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..150f7167a57c51b57fb4bf0357597c00693300bb
--- /dev/null
+++ b/paddlecv/tests/test_classification.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+parent = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(parent, '../')))
+
+import cv2
+import unittest
+import yaml
+import argparse
+
+from ppcv.core.workspace import global_config
+from ppcv.core.config import ConfigParser
+
+
+class TestClassification(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_classification.yml'
+ self.input = 'demo/ILSVRC2012_val_00020010.jpeg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_classification(self):
+ img = cv2.imread(self.input)[:, :, ::-1]
+ inputs = [
+ {
+ "input.image": [img]
+ },
+ {
+ "input.image": [img, img]
+ },
+ ]
+ op_name = list(self.model_cfg[0].keys())[0]
+ cls_op = global_config[op_name](self.model_cfg[0][op_name],
+ self.env_cfg)
+ result = cls_op(inputs)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddlecv/tests/test_connector.py b/paddlecv/tests/test_connector.py
new file mode 100644
index 0000000000000000000000000000000000000000..034cf5aacbe52bc0632b4db80f6827d420c09e22
--- /dev/null
+++ b/paddlecv/tests/test_connector.py
@@ -0,0 +1,337 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+parent = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(parent, '../')))
+
+import numpy as np
+import cv2
+import copy
+import unittest
+import yaml
+import argparse
+
+from ppcv.core.workspace import global_config
+from ppcv.core.config import ConfigParser
+
+
+class TestClsCorrection(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_cls_connector.yml'
+ self.input = 'demo/ILSVRC2012_val_00020010.jpeg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_cls_correction(self):
+ img = cv2.imread(self.input)[:, :, ::-1]
+ inputs = [
+ {
+ "input.image": [img],
+ "input.class_ids": [[0], ],
+ "input.scores": [[0.95]],
+ },
+ {
+ "input.image": [img, img],
+ "input.class_ids": [[1], [3]],
+ "input.scores": [[0.95], [0.85]],
+ },
+ ]
+ op_name = "ClsCorrectionOp"
+ op = global_config[op_name](self.model_cfg[0][op_name])
+ result = op(inputs)
+ self.assert_equal(img, result)
+
+ def assert_equal(self, img, result):
+ diff = np.sum(np.abs(img - result[0][0]))
+ self.assertEqual(diff, 0)
+
+ corr_img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
+ diff = np.sum(np.abs(corr_img - result[1][0]))
+ self.assertEqual(diff, 0)
+
+ diff = np.sum(np.abs(img - result[1][1]))
+ self.assertEqual(diff, 0)
+
+
+class TestBboxCropOp(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_bbox_crop.yml'
+ self.input = 'demo/ILSVRC2012_val_00020010.jpeg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_bbox_crop(self):
+ img = cv2.imread(self.input)[:, :, ::-1]
+ inputs = [
+ {
+ "input.image": img,
+ "input.bbox": np.array([[1, 1, 2, 5]]),
+ },
+ {
+ "input.image": img,
+ "input.bbox": np.array([[1, 1, 10, 20]]),
+ },
+ ]
+
+ op_name = list(self.model_cfg[0].keys())[0]
+ op = global_config[op_name](self.model_cfg[0][op_name])
+ result = op(inputs)
+ self.assert_equal(result)
+
+ def assert_equal(self, result):
+ gt_res = 55398.0
+ sums = 0
+ for idx, r in enumerate(result):
+ for k, poly in r.items():
+ sums += np.sum(np.abs(poly))
+ self.assertEqual(gt_res, sums)
+
+
+class TestPolyCropOp(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_poly_crop.yml'
+ self.input = 'demo/ILSVRC2012_val_00020010.jpeg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_poly_crop(self):
+ img = cv2.imread(self.input)[:, :, ::-1]
+ inputs = [
+ {
+ "input.image": img,
+ "input.poly": np.array(
+ [[[1, 1], [2, 1], [2, 3], [1, 3]]]).astype(np.float32),
+ },
+ {
+ "input.image": img,
+ "input.poly": np.array(
+ [[[1, 1], [2, 1], [2, 3], [0, 3]]]).astype(np.float32),
+ },
+ ]
+ op_name = list(self.model_cfg[0].keys())[0]
+ op = global_config[op_name](self.model_cfg[0][op_name])
+ result = op(inputs)
+ self.assert_equal(result)
+
+ def assert_equal(self, result):
+ gt_res = 3620.0
+ sums = 0
+ for idx, r in enumerate(result):
+ for poly in r['poly_crop.crop_image']:
+ sums += np.sum(np.abs(poly))
+ self.assertEqual(gt_res, sums)
+
+
+class TestFragmentCompositionOp(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_fragment_composition.yml'
+ self.input = 'demo/ILSVRC2012_val_00020010.jpeg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_fragment_composition(self):
+ inputs = [
+ {
+ "input.text": ["hello", "world"]
+ },
+ {
+ "input.text": ["paddle", "paddle"]
+ },
+ ]
+ op_name = list(self.model_cfg[0].keys())[0]
+ op = global_config[op_name](self.model_cfg[0][op_name])
+ result = op(inputs)
+
+ self.assert_equal(result)
+
+ def assert_equal(self, result):
+ gt_res = ["hello world", "paddle paddle"]
+ for idx, r in enumerate(result):
+ self.assertEqual(gt_res[idx], r)
+
+
+class TestKeyFrameExtractionOp(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_key_frame_extraction.yml'
+ self.input = 'demo/pikachu.mp4'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_key_frame_extraction(self):
+ inputs = [{"input.video_path": "demo/pikachu.mp4", }]
+ op_name = list(self.model_cfg[0].keys())[0]
+ op = global_config[op_name](self.model_cfg[0][op_name])
+ result = op(inputs)
+
+ self.assert_equal(result)
+
+ def assert_equal(self, result):
+ key_frames_id = [
+ 7, 80, 102, 139, 200, 234, 271, 320, 378, 437, 509, 592, 619, 684,
+ 749, 791, 843, 872, 934, 976, 1028, 1063, 1156, 1179, 1249, 1356,
+ 1400, 1461, 1516, 1630, 1668, 1718, 1768
+ ]
+ gt_abs_sum = 2312581162.0
+ abs_sum = sum([np.sum(np.abs(k)) for k in result[0][0]])
+ self.assertEqual(result[0][1], key_frames_id)
+ self.assertAlmostEqual(gt_abs_sum, abs_sum)
+
+
+class TestTableMatcherOp(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_table_matcher.yml'
+ self.input = './demo/table_demo.npy'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_table_matcher(self):
+ inputs = np.load(self.input, allow_pickle=True).item()
+
+ gt_res = inputs['outputs']
+ inputs = inputs['inputs']
+ pipe_inputs = [{}]
+ for k, v in inputs[0].items():
+ pipe_inputs[0].update({'input.' + k: v})
+
+ op_name = list(self.model_cfg[0].keys())[0]
+ op = global_config[op_name](self.model_cfg[0][op_name])
+ result = op(pipe_inputs)
+
+ self.assertEqual(gt_res[0]['Matcher.html'],
+ result[0]['tablematcher.html'])
+
+
+class TestPPStructureFilterOp(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_ppstructure_filter.yml'
+ self.input = 'demo/ILSVRC2012_val_00020010.jpeg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_ppstructure_filter(self):
+ inputs = [
+ {
+ "input.dt_cls_names": ['table', 'txt'],
+ "input.crop_image": ['', '1'],
+ "input.dt_polys": [1, 0],
+ "input.rec_text": ['a', 'b']
+ },
+ {
+ "input.dt_cls_names": ['table', 'txt'],
+ "input.crop_image": ['1', ''],
+ "input.dt_polys": [0, 1],
+ "input.rec_text": ['b', 'a']
+ },
+ ]
+ gt_res = [
+ {
+ "filter.image": [''],
+ "filter.dt_polys": [1],
+ "filter.rec_text": ['a']
+ },
+ {
+ "filter.image": ['1'],
+ "filter.dt_polys": [0],
+ "filter.rec_text": ['b']
+ },
+ ]
+
+ op_name = list(self.model_cfg[0].keys())[0]
+ op = global_config[op_name](self.model_cfg[0][op_name])
+ result = op(inputs)
+
+ self.assertEqual(gt_res, result)
+
+
+class TestPPStructureResultConcatOp(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_ppstructure_result_concat.yml'
+ self.input = 'demo/ILSVRC2012_val_00020010.jpeg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_ppstructure_result_concat(self):
+ inputs = [
+ {
+ "input.table.structures": [['td'], ['td']],
+ "input.Matcher.html": ['html', 'html'],
+ "input.layout.dt_bboxes": [0, 1, 2, 3, 4],
+ "input.table.dt_bboxes": [[0], [1]],
+ "input.filter_table.dt_polys": [[0], [1]],
+ "input.filter_table.rec_text": [['0'], ['1']],
+ "input.filter_txts.dt_polys": [[2], [3], [4]],
+ "input.filter_txts.rec_text": [['2'], ['3'], ['4']],
+ },
+ {
+ "input.table.structures": [['td'], ['td']],
+ "input.Matcher.html": ['html', 'html'],
+ "input.layout.dt_bboxes": [5, 1, 2, 3, 4],
+ "input.table.dt_bboxes": [[5], [1]],
+ "input.filter_table.dt_polys": [[5], [1]],
+ "input.filter_table.rec_text": [['5'], ['1']],
+ "input.filter_txts.dt_polys": [[2], [3], [4]],
+ "input.filter_txts.rec_text": [['2'], ['3'], ['4']],
+ },
+ ]
+ gt_res = [{
+ 'concat.dt_polys': [[2], [3], [4], [0], [1]],
+ 'concat.rec_text': [['2'], ['3'], ['4'], ['0'], ['1']],
+ 'concat.dt_bboxes': [0, 1, 2, 3, 4],
+ 'concat.html': ['', '', '', 'html', 'html'],
+ 'concat.cell_bbox': [[], [], [], [0], [1]],
+ 'concat.structures': [[], [], [], ['td'], ['td']]
+ }, {
+ 'concat.dt_polys': [[2], [3], [4], [5], [1]],
+ 'concat.rec_text': [['2'], ['3'], ['4'], ['5'], ['1']],
+ 'concat.dt_bboxes': [5, 1, 2, 3, 4],
+ 'concat.html': ['', '', '', 'html', 'html'],
+ 'concat.cell_bbox': [[], [], [], [5], [1]],
+ 'concat.structures': [[], [], [], ['td'], ['td']]
+ }]
+
+ op_name = list(self.model_cfg[0].keys())[0]
+ op = global_config[op_name](self.model_cfg[0][op_name])
+ result = op(inputs)
+
+ self.assertEqual(gt_res, result)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddlecv/tests/test_custom_op.py b/paddlecv/tests/test_custom_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6ec537dcaeed948f19e2db883191ea2aef567fa
--- /dev/null
+++ b/paddlecv/tests/test_custom_op.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+parent = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(parent, '../')))
+
+import cv2
+import unittest
+import yaml
+import argparse
+
+import numpy as np
+import math
+import paddle
+from ppcv.engine.pipeline import Pipeline
+from ppcv.ops.models.base import ModelBaseOp
+from ppcv.core.workspace import register
+from ppcv.core.config import ConfigParser
+
+
+@register
+class BlankOp(ModelBaseOp):
+ def __init__(self, model_cfg, env_cfg):
+ super(BlankOp, self).__init__(model_cfg, env_cfg)
+
+ @classmethod
+ def get_output_keys(cls):
+ return ['blank_output']
+
+ def __call__(self, inputs):
+ output = []
+ for input in inputs:
+ output.append({self.output_keys[0]: input})
+ return output
+
+
+class TestCustomOp(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_custom_op.yml'
+ self.input = 'demo/ILSVRC2012_val_00020010.jpeg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+
+ def test_custom_op(self):
+ cfg = argparse.Namespace(**self.cfg_dict)
+ input = os.path.abspath(self.input)
+ pipeline = Pipeline(cfg)
+ pipeline.run(input)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddlecv/tests/test_detection.py b/paddlecv/tests/test_detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f2307f7405610779b97cc226968f5970ccc7659
--- /dev/null
+++ b/paddlecv/tests/test_detection.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+parent = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(parent, '../')))
+
+import cv2
+import unittest
+import yaml
+import argparse
+
+from ppcv.core.workspace import global_config
+from ppcv.core.config import ConfigParser
+
+
+class TestDetection(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_detection.yml'
+ self.input = 'demo/ILSVRC2012_val_00020010.jpeg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_detection(self):
+ img = cv2.imread(self.input)[:, :, ::-1]
+ inputs = [
+ {
+ "input.image": img,
+ },
+ {
+ "input.image": img,
+ },
+ ]
+ op_name = list(self.model_cfg[0].keys())[0]
+ det_op = global_config[op_name](self.model_cfg[0][op_name],
+ self.env_cfg)
+ result = det_op(inputs)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddlecv/tests/test_get_model.py b/paddlecv/tests/test_get_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a441e9f0169302d2ace781bc136e9c296b05be63
--- /dev/null
+++ b/paddlecv/tests/test_get_model.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+parent = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(parent, '../')))
+
+import paddle
+import ppcv
+import unittest
+
+TASK_NAME = 'detection'
+MODEL_NAME = 'paddlecv://models/PPLCNet_x1_0_infer/inference.pdiparams'
+
+
+class TestGetConfigFile(unittest.TestCase):
+    def test_main(self):
+        cfg_file = ppcv.model_zoo.get_config_file(TASK_NAME)
+        model_file = ppcv.model_zoo.get_model_file(MODEL_NAME)
+        self.assertTrue(os.path.isfile(cfg_file))
+        self.assertTrue(os.path.isfile(model_file))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddlecv/tests/test_keypoint.py b/paddlecv/tests/test_keypoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ac4a16627be7e49e4a7bb0064cfb5952fdff0e4
--- /dev/null
+++ b/paddlecv/tests/test_keypoint.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+parent = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(parent, '../')))
+
+import cv2
+import unittest
+import yaml
+import argparse
+
+from ppcv.ops import KeypointOp
+from ppcv.core.config import ConfigParser
+
+
+class TestKeypoint(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_keypoint.yml'
+ self.input = 'demo/hrnet_demo.jpg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_detection(self):
+ img = cv2.imread(self.input)[:, :, ::-1]
+ inputs = [{"input.image": [img, img, img]}, ]
+ kpt_op = KeypointOp(self.model_cfg[0]["KeypointOp"], self.env_cfg)
+
+ result = kpt_op(inputs)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddlecv/tests/test_list_model.py b/paddlecv/tests/test_list_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f81055fe58d2c416b492e35f694bfdd9d9413ef3
--- /dev/null
+++ b/paddlecv/tests/test_list_model.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+parent = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(parent, '../')))
+
+import unittest
+import ppcv
+
+
+class TestListModel(unittest.TestCase):
+ def setUp(self):
+ self._filter = []
+
+    def test_main(self):
+        # list_model should not raise for valid filters
+        ppcv.model_zoo.list_model(self._filter)
+
+
+class TestListModelClas(TestListModel):
+ def setUp(self):
+ self._filter = ['system']
+
+
+class TestListModelPPLCNet(TestListModel):
+ def setUp(self):
+ self._filter = ['PP-LCNet']
+
+
+class TestListModelError(unittest.TestCase):
+ def setUp(self):
+ self._filter = ['xxx']
+
+    def test_main(self):
+        with self.assertRaises(ValueError):
+            ppcv.model_zoo.list_model(self._filter)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddlecv/tests/test_ocr.py b/paddlecv/tests/test_ocr.py
new file mode 100644
index 0000000000000000000000000000000000000000..283cd3f17102d2b67c816ba1f55112f0e7be2caf
--- /dev/null
+++ b/paddlecv/tests/test_ocr.py
@@ -0,0 +1,200 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+parent = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(parent, '../')))
+
+import cv2
+import unittest
+import yaml
+import argparse
+
+from ppcv.core.workspace import global_config
+from ppcv.core.config import ConfigParser
+from ppcv.engine.pipeline import Pipeline
+
+
+class TestOcrDbDetOp(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_ocr_db_det.yml'
+ self.input = 'demo/00056221.jpg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_detection(self):
+ img = cv2.imread(self.input)
+ inputs = [{"input.image": img, }]
+ op_name = list(self.model_cfg[0].keys())[0]
+ op = global_config[op_name](self.model_cfg[0][op_name], self.env_cfg)
+ result = op(inputs)
+
+
+class TestOcrCrnnRecOp(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_ocr_crnn_rec.yml'
+ self.input = 'demo/word_1.jpg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_recognition(self):
+ img = cv2.imread(self.input)
+ inputs = [
+ {
+ "input.image": [img, img],
+ },
+ {
+ "input.image": [img, img, img],
+ },
+ {
+ "input.image": [img],
+ },
+ ]
+ op_name = list(self.model_cfg[0].keys())[0]
+ op = global_config[op_name](self.model_cfg[0][op_name], self.env_cfg)
+ result = op(inputs)
+
+
+class TestPPOCRv2(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/system/PP-OCRv2.yml'
+ self.input = 'demo/00056221.jpg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+
+ def test_ppocrv2(self):
+ cfg = argparse.Namespace(**self.cfg_dict)
+ input = os.path.abspath(self.input)
+ pipeline = Pipeline(cfg)
+ pipeline.run(input)
+
+
+class TestPPOCRv3(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/system/PP-OCRv3.yml'
+ self.input = 'demo/00056221.jpg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+
+ def test_ppocrv3(self):
+ cfg = argparse.Namespace(**self.cfg_dict)
+ input = os.path.abspath(self.input)
+ pipeline = Pipeline(cfg)
+ pipeline.run(input)
+
+
+class TestPPStructureTableStructureOp(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_ocr_table_structure.yml'
+ self.input = 'demo/table.jpg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_table_structure(self):
+ img = cv2.imread(self.input)
+ inputs = [{"input.image": img, }]
+ op_name = list(self.model_cfg[0].keys())[0]
+ op = global_config[op_name](self.model_cfg[0][op_name], self.env_cfg)
+        result = op(inputs)
+        self.assertIsNotNone(result)
+
+
+class TestPPStructureTable(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/system/PP-Structure-table.yml'
+ self.input = 'demo/table.jpg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+
+ def test_structure_table(self):
+ cfg = argparse.Namespace(**self.cfg_dict)
+        input_path = os.path.abspath(self.input)
+        pipeline = Pipeline(cfg)
+        pipeline.run(input_path)
+
+
+class TestLayoutDetectionOp(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_ocr_layout.yml'
+ self.input = 'demo/pp_structure_demo.png'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_layout_detection(self):
+ img = cv2.imread(self.input)
+ inputs = [{"input.image": img, }]
+ op_name = list(self.model_cfg[0].keys())[0]
+ op = global_config[op_name](self.model_cfg[0][op_name], self.env_cfg)
+        result = op(inputs)
+        self.assertIsNotNone(result)
+
+
+class TestPPStructure(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/system/PP-Structure.yml'
+ self.input = 'demo/pp_structure_demo.png'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+
+ def test_ppstructure(self):
+ cfg = argparse.Namespace(**self.cfg_dict)
+        input_path = os.path.abspath(self.input)
+        pipeline = Pipeline(cfg)
+        pipeline.run(input_path)
+
+
+class TestPPStructureKieSerOp(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/system/PP-Structure-ser.yml'
+ self.input = 'demo/kie_demo.jpg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_ppstructure_kie_ser(self):
+ img = cv2.imread(self.input)
+ inputs = [{"input.image": img, }]
+ op_name = list(self.model_cfg[0].keys())[0]
+ op = global_config[op_name](self.model_cfg[0][op_name], self.env_cfg)
+        result = op(inputs)
+        self.assertIsNotNone(result)
+
+
+class TestPPStructureKieReOp(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/system/PP-Structure-re.yml'
+ self.input = 'demo/kie_demo.jpg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_ppstructure_kie_re(self):
+ img = cv2.imread(self.input)
+ inputs = [{"input.image": img, }]
+ op_name = list(self.model_cfg[0].keys())[0]
+ op = global_config[op_name](self.model_cfg[0][op_name], self.env_cfg)
+        result = op(inputs)
+        self.assertIsNotNone(result)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddlecv/tests/test_pipeline.py b/paddlecv/tests/test_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..294c3f4e94a2a9b2ae133a1b2f96c9a6a2078cff
--- /dev/null
+++ b/paddlecv/tests/test_pipeline.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+parent = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(parent, '../')))
+
+import argparse
+import unittest
+
+from ppcv.engine.pipeline import Pipeline
+
+
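+# Smoke test for the end-to-end Pipeline: build it from the unittest config
+# and make sure a run on the demo image completes without raising.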
+class TestPipeline(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_pipeline.yml'
+ self.input = 'demo/ILSVRC2012_val_00020010.jpeg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+
+ def test_pipeline(self):
+ cfg = argparse.Namespace(**self.cfg_dict)
+        input_path = os.path.abspath(self.input)
+        pipeline = Pipeline(cfg)
+        pipeline.run(input_path)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddlecv/tests/test_segmentation.py b/paddlecv/tests/test_segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3ae8a0e7854788c4b03a0a2e3d0275283e323ae
--- /dev/null
+++ b/paddlecv/tests/test_segmentation.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+parent = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(parent, '../')))
+
+import argparse
+import unittest
+
+import cv2
+
+from ppcv.core.workspace import global_config
+from ppcv.core.config import ConfigParser
+
+
+class TestSegmentation(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_segmentation.yml'
+ self.input = 'demo/segmentation_cityscapes_img.png'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_segmentation(self):
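+        # cv2.imread returns BGR; reversing the channel axis yields the RGB
+        # input expected by the segmentation model.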
+ img = cv2.imread(self.input)[:, :, ::-1]
+ inputs = [
+ {
+ "input.image": [img]
+ },
+ {
+ "input.image": [img, img]
+ },
+ ]
+
+        op_name = list(self.model_cfg[0].keys())[0]
+        seg_op = global_config[op_name](self.model_cfg[0][op_name],
+                                        self.env_cfg)
+        result = seg_op(inputs)
+        self.assertIsNotNone(result)
+
+
+class TestPPHumanSegV2(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_seg_pphumansegv2.yml'
+ self.input = 'demo/pp_humansegv2_demo.jpg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_segmentation(self):
+ img = cv2.imread(self.input)[:, :, ::-1]
+ inputs = [
+ {
+ "input.image": [img]
+ },
+ {
+ "input.image": [img, img]
+ },
+ ]
+
+        op_name = list(self.model_cfg[0].keys())[0]
+        seg_op = global_config[op_name](self.model_cfg[0][op_name],
+                                        self.env_cfg)
+        result = seg_op(inputs)
+        self.assertIsNotNone(result)
+
+
+class TestPPMattingV1(unittest.TestCase):
+ def setUp(self):
+ self.config = 'configs/unittest/test_seg_ppmattingv1.yml'
+ self.input = 'demo/pp_mattingv1_demo.jpg'
+ self.cfg_dict = dict(config=self.config, input=self.input)
+ cfg = argparse.Namespace(**self.cfg_dict)
+ config = ConfigParser(cfg)
+ config.print_cfg()
+ self.model_cfg, self.env_cfg = config.parse()
+
+ def test_segmentation(self):
+ img = cv2.imread(self.input)[:, :, ::-1]
+ inputs = [
+ {
+ "input.image": [img]
+ },
+ {
+ "input.image": [img, img]
+ },
+ ]
+
+        op_name = list(self.model_cfg[0].keys())[0]
+        seg_op = global_config[op_name](self.model_cfg[0][op_name],
+                                        self.env_cfg)
+        result = seg_op(inputs)
+        self.assertIsNotNone(result)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddlecv/tools/check_name.py b/paddlecv/tools/check_name.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef46d99eea1494e7b029c9f538992878519335b7
--- /dev/null
+++ b/paddlecv/tools/check_name.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
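+# Utility that prints the output field names of every registered op and,
+# when --config is given, checks that the config only wires op Inputs to
+# known output names. Example:
+#   python tools/check_name.py --config configs/system/PP-OCRv2.yml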
+import os
+import sys
+
+parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
+sys.path.insert(0, parent_path)
+
+from argparse import ArgumentParser
+import ppcv
+from ppcv.ops import *
+from ppcv.utils.helper import get_output_keys
+import yaml
+
+
+def argsparser():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--config", type=str, default=None, help=("Path of configure"))
+ return parser
+
+
+def check_cfg_output(cfg, output_dict):
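+    # Collect every output field any registered op can produce, plus the
+    # pipeline-level fields (image/video/fn), then verify that each op's
+    # Inputs entries reference only those known field names.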
+ with open(cfg) as f:
+ cfg = yaml.safe_load(f)
+ model_cfg = cfg['MODEL']
+ output_set = {'image', 'video', 'fn'}
+ for v in output_dict.values():
+ for name in v:
+ output_set.add(name)
+ for ops in model_cfg:
+ op_name = list(ops.keys())[0]
+ cfg_dict = list(ops.values())[0]
+ cfg_input = cfg_dict['Inputs']
+ for key in cfg_input:
+ key = key.split('.')[-1]
+ assert key in output_set, "Illegal input: {} in {}.".format(
+ key, op_name)
+
+
+def check_name(cfg):
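+    # Dump all registered op output names; if a config path was supplied,
+    # additionally validate its Inputs wiring against them.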
+    config = vars(cfg)['config']
+ output_dict = get_output_keys()
+ buffer = yaml.dump(output_dict)
+ print('----------- Op output names ---------')
+ print(buffer)
+ if config is not None:
+ check_cfg_output(config, output_dict)
+
+
+if __name__ == '__main__':
+ parser = argsparser()
+ FLAGS = parser.parse_args()
+ check_name(FLAGS)
diff --git a/paddlecv/tools/predict.py b/paddlecv/tools/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..341ebf4fb51d48192354489388bc6111ac9fc501
--- /dev/null
+++ b/paddlecv/tools/predict.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
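+# CLI entry point for running a pipeline config on an input, e.g.:
+#   python tools/predict.py --config=<pipeline.yml> --input=<image, dir or video>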
+import os
+import sys
+
+parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
+sys.path.insert(0, parent_path)
+
+from ppcv.engine.pipeline import Pipeline
+from ppcv.utils.logger import setup_logger
+from ppcv.core.config import ArgsParser
+
+
+def argsparser():
+ parser = ArgsParser()
+
+ parser.add_argument(
+ "--config",
+ type=str,
+ default=None,
+ help=("Path of configure"),
+ required=True)
+ parser.add_argument(
+ "--input",
+ type=str,
+ default=None,
+ help="Path of input, suport image file, image directory and video file.",
+ required=True)
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="output",
+ help="Directory of output visualization files.")
+ parser.add_argument(
+ "--run_mode",
+ type=str,
+ default='paddle',
+ help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)")
+ parser.add_argument(
+ "--device",
+ type=str,
+ default='CPU',
+ help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU."
+ )
+ return parser
+
+
+if __name__ == '__main__':
+ parser = argsparser()
+ FLAGS = parser.parse_args()
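+    # Resolve the input to an absolute path before handing it to the
+    # Pipeline, so downstream ops are independent of the working directory.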
+    input_path = os.path.abspath(FLAGS.input)
+    pipeline = Pipeline(FLAGS)
+    result = pipeline.run(input_path)